[
  {
    "path": ".gitattributes",
    "content": "# Normalize Python files to LF line endings\n*.py text eol=lf\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "# Inspired from https://github.com/vllm-project/vllm/blob/main/.github/CODEOWNERS\n\n/unsloth/models/loader.py @danielhanchen @mmathew23\n/unsloth/models/llama.py @Datta0 @danielhanchen @mmathew23\n/unsloth/models/rl.py @Datta0 @pluesclues @danielhanchen\n/unsloth/models/rl_replacements.py @Datta0 @pluesclues @danielhanchen\n/unsloth/trainer.py @danielhanchen\n/unsloth/models/sentence_transformer.py @Etherll @danielhanchen\n/unsloth/save.py @rolandtannous @danielhanchen\n/unsloth/tokenizer_utils.py @mmathew23 @danielhanchen\n/unsloth/chat_templates.py @rolandtannous @danielhanchen\n/unsloth/ollama_template_mappers.py @rolandtannous @danielhanchen\n/unsloth/kernels/moe/*.py @Datta0\n/unsloth/import_fixes.py @danielhanchen\n/unsloth/device_type.py @danielhanchen\n/unsloth/_auto_install.py @danielhanchen\n/unsloth/dataprep/*.py @danielhanchen\n/unsloth/kernels/cross_entropy_loss.py @danielhanchen\n/unsloth/kernels/fast_lora.py @danielhanchen\n/unsloth/kernels/flex_attention.py @danielhanchen\n/unsloth/kernels/fp8.py @Datta0\n/unsloth/kernels/geglu.py @danielhanchen\n/unsloth/kernels/layernorm.py @danielhanchen\n/unsloth/kernels/rms_layernorm.py @danielhanchen\n/unsloth/kernels/rope_embedding.py @danielhanchen\n/unsloth/kernels/swiglu.py @danielhanchen\n/unsloth/kernels/utils.py @danielhanchen @Datta0\n/unsloth/models/_utils.py @danielhanchen @mmathew23\n/unsloth/models/cohere.py @danielhanchen\n/unsloth/models/dpo.py @danielhanchen\n/unsloth/models/falcon_h1.py @danielhanchen\n/unsloth/models/gemma.py @danielhanchen\n/unsloth/models/gemma2.py @danielhanchen\n/unsloth/models/glm4_moe.py @Datta0\n/unsloth/models/granite.py @danielhanchen\n/unsloth/models/llama4.py @danielhanchen\n/unsloth/models/loader_utils.py @Datta0 @danielhanchen\n/unsloth/models/mapper.py @danielhanchen\n/unsloth/models/mistral.py @danielhanchen\n/unsloth/models/qwen2.py @danielhanchen\n/unsloth/models/qwen3.py @Datta0\n/unsloth/models/qwen3_moe.py @Datta0\n/unsloth/models/vision.py @mmathew23 @danielhanchen\n/unsloth/utils/attention_dispatch.py @mmathew23\n/unsloth/utils/hf_hub.py @mmathew23\n/unsloth/utils/packing.py @mmathew23\n\n/cli/ @rolandtannous @Manan17\n/studio/frontend/ @Shine1i @rolandtannous @Manan17\n/studio/frontend/public/ @Shine1i\n/studio/backend/ @rolandtannous\n/studio/backend/core/data_recipe/ @rolandtannous\n/studio/backend/tests/ @rolandtannous @danielhanchen\n/tests/ @rolandtannous @danielhanchen\n/scripts/ @rolandtannous @danielhanchen\n"
  },
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: unslothai\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with a single Open Collective username\nko_fi: # unsloth\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\notechie: # Replace with a single Otechie username\nlfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry\ncustom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug---issue.md",
    "content": "---\nname: Bug / Issue\nabout: Bug / Issue\ntitle: \"[Bug] Please fill in your issue title here.\"\nlabels: bug\nassignees: ''\n\n---\n\n1. Did you update? `pip install --upgrade unsloth unsloth_zoo`\n2. `Colab` or `Kaggle` or local / cloud\n3. Number GPUs used, use `nvidia-smi`\n4. Which notebook? Please link!\n5. Which Unsloth version, TRL version, transformers version, PyTorch version?\n6. Which trainer? `SFTTrainer`, `GRPOTrainer` etc\n\n```python\nPut Minimal code to reproduce error here ###Remove Hugging Face token###\n```\n\n🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.md",
    "content": "---\nname: Feature Request\nabout: New features, model support, ideas\ntitle: \"[Feature]\"\nlabels: feature request\nassignees: ''\n\n---\n\nFor new models, have you tried:\n```python\nfrom unsloth import FastModel\nmodel, tokenizer = FastModel.from_pretrained(\n    \"microsoft/Phi-4-multimodal-instruct\",\n    trust_remote_code = True,\n)\nfrom transformers import AutoModelForSequenceClassification\nmodel, tokenizer = FastModel.from_pretrained(\n    auto_model = AutoModelForSequenceClassification,\n)\n```\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: 'Inactive Issue Pinger'\n\non:\n  schedule:\n    - cron: '30 5 * * *' # Runs at 5:30 UTC every day\n\njobs:\n  stale:\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n\n    steps:\n      - uses: actions/stale@v10\n        with:\n          # The message to post on stale issues.\n          # This message will ping the issue author.\n          # Note: The stale bot action does not currently support a direct placeholder for the last commenter.\n          # As a workaround, this message encourages any participant to reply.\n          stale-issue-message: >\n            Is this issue still important to you?\n            Apologies in advance we might have missed this issue as well.\n            For faster response times, please post on our Reddit server - https://www.reddit.com/r/unsloth or our Discord - https://discord.com/invite/unsloth \n\n          # The number of days of inactivity before an issue is considered stale.\n          days-before-issue-stale: 9999\n\n          # Set to -1 to never close stale issues.\n          days-before-issue-close: -1\n\n          # A label to apply to stale issues.\n          stale-issue-label: 'inactive'\n\n          # The number of operations to perform per run to avoid rate limiting.\n          operations-per-run: 500\n\n          enable-statistics: false\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*.class\nunsloth_compiled_cache/\n# ML artifacts (large files)\nfeature/\noutputs/\nexports/\n/datasets/\nstudio/backend/assets/datasets/\nunsloth_training_checkpoints/\n*.gguf\n*.safetensors\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n.venv_overlay/\n.venv_t5/\nenvironment.yaml\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# Ruff stuff:\n.ruff_cache/\n.pre-commit-cache/\n\n# PyPI configuration file and IDE/Editors\n.pypirc\n.vscode\n.idea/\n.claude/\n*.swp\n*.swo\n\n# oh-my-codex\n.omx/\n\n# Firebase\nfirebase-debug.log\n\n# Other\nresources/\ntmp/\n**/node_modules/\nauth.db\n\n# Local working docs\n**/CLAUDE.md\n**/claude.md\n**/AGENT.md\n**/agent.md\ndocs/canvas-lab-architecture.md\nlog_rtx.txt\nlog.txt\nsetup_leo.sh\nserver.pid\n*.log\npackage-lock.json\n"
  },
  {
    "path": ".pre-commit-ci.yaml",
    "content": "ci:\n  autofix_prs: true\n  autofix_prs_limit: 5\n  autoupdate_schedule: monthly\n  autoupdate_commit_msg: \"chore: pre-commit autoupdate\"\n  skip: []\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.15.6\n    hooks:\n      - id: ruff\n        args:\n          - --fix\n          - --exit-non-zero-on-fix\n  - repo: local\n    hooks:\n      - id: ruff-format-with-kwargs\n        name: Ruff format with kwarg spacing\n        entry: scripts/run_ruff_format.py\n        language: python\n        types: [python]\n        additional_dependencies:\n          - ruff==0.6.9\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, caste, color, religion, or sexual\nidentity and orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the overall\n  community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or advances of\n  any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email address,\n  without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at support@unsloth.ai.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. 
Warning\n\n**Community Impact**: A violation through a single incident or series of\nactions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or permanent\nban.\n\n### 3. Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior, harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within the\ncommunity.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.1, available at\n[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].\n\nCommunity Impact Guidelines were inspired by\n[Mozilla's code of conduct enforcement ladder][Mozilla CoC].\n\nFor answers to common questions about this code of conduct, see the FAQ at\n[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at\n[https://www.contributor-covenant.org/translations][translations].\n\n[homepage]: https://www.contributor-covenant.org\n[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html\n[Mozilla CoC]: https://github.com/mozilla/diversity\n[FAQ]: https://www.contributor-covenant.org/faq\n[translations]: https://www.contributor-covenant.org/translations\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# 🦥 Contributing to Unsloth\n\nThank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others or just by simply spreading the word of Unsloth! 💕\n\n- **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions.\n- **Fix Bugs**: Identify and resolve issues with the existing codebase.\n- **Submit Ideas**: Request new features or share enhancements you'd like to see.\n- **Develop Features**: Implement new functionality or improve existing tools which can be done via PRs.\n- **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity.\n\nOne of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟\n\n## Submitting Issues\nIf you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out:\n\n### Reporting Bugs\n1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues.\n2. **Details Matter**: Is this on Google Colab, Kaggle, or on another platform service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful.\n3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution.\n\n## Spread the Word\nYour support extends beyond code:\n- Spread the word by writing about Unsloth in blogs or social media.\n- Share how Unsloth powers your projects.\n- Star our repository to show your appreciation.\n\nFinally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/blob/main/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone.\n\nThank you so much for reading and we hope you have lots of fun using Unsloth! 🦥\n"
  },
  {
    "path": "COPYING",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  
\"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. 
Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  
This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  
But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  
If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  
If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  
\"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  
This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published\n    by the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>.\n\nFiles under unsloth/*, tests/*, scripts/* are Apache 2.0 licensed.\nFiles under studio/*, unsloth_cli/*, which are optional to install, are AGPLv3 licensed."
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2024-] [Unsloth AI. 
Inc team, Daniel Han-Chen & Michael Han-Chen]\n   Files under unsloth/*, tests/*, scripts/* are Apache 2.0 licensed.\n   Files under studio/*, unsloth_cli/*, which are optional to install, are AGPLv3 licensed.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\" style=\"margin:0;\">\n  <a href=\"https://unsloth.ai/docs\"><picture>\n    <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20WHITE%20LOGO.png\">\n    <source media=\"(prefers-color-scheme: light)\" srcset=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png\">\n    <img alt=\"Unsloth logo\" src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png\" height=\"60\" style=\"max-width:100%;\">\n  </picture></a>\n</h1>\n<h3 align=\"center\" style=\"margin: 0; margin-top: 0;\">\nRun and train AI models with a unified local interface.\n</h3>\n\n<p align=\"center\">\n  <a href=\"#-features\">Features</a> •\n  <a href=\"#-quickstart\">Quickstart</a> •\n  <a href=\"#-free-notebooks\">Notebooks</a> •\n  <a href=\"https://unsloth.ai/docs\">Documentation</a> •\n  <a href=\"https://discord.com/invite/unsloth\">Discord</a>\n</p>\n <a href=\"https://unsloth.ai/docs/new/studio\">\n<img alt=\"unsloth studio ui homepage\" src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png\" style=\"max-width: 100%; margin-bottom: 0;\"></a>\n\nUnsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS.\n\n## ⭐ Features\nUnsloth provides several key features for both inference and training:\n### Inference\n* **Search + download + run models** including GGUF, LoRA adapters, safetensors\n* **Export models**: [Save or export](https://unsloth.ai/docs/new/studio/export) models to GGUF, 16-bit safetensors and other formats.\n* **Tool calling**: Support for [self-healing tool calling](https://unsloth.ai/docs/new/studio/chat#auto-healing-tool-calling) and web search\n* **[Code execution](https://unsloth.ai/docs/new/studio/chat#code-execution)**: lets LLMs test code in Claude artifacts and sandbox environments\n* [Auto-tune inference parameters](https://unsloth.ai/docs/new/studio/chat#auto-parameter-tuning) and customize chat templates.\n* Upload images, audio, PDFs, code, DOCX and more file types to chat with.\n### Training\n* Train **500+ models** up to **2x faster** with up to **70% less VRAM**, with no accuracy loss.\n* Supports full fine-tuning, pretraining, 4-bit, 16-bit and, FP8 training.\n* **Observability**: Monitor training live, track loss and GPU usage and customize graphs.\n* **Data Recipes**: [Auto-create datasets](https://unsloth.ai/docs/new/studio/data-recipe) from **PDF, CSV, DOCX** etc. Edit data in a visual-node workflow.\n* **Reinforcement Learning**: The most efficient [RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) library, using **80% less VRAM** for GRPO, [FP8](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) etc.\n* [Multi-GPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) training is supported, with major improvements coming soon.\n\n## ⚡ Quickstart\nUnsloth can be used in two ways: through **[Unsloth Studio](https://unsloth.ai/docs/new/studio/)**, the web UI, or through **Unsloth Core**, the code-based version. 
Each has different requirements.\n\n### Unsloth Studio (web UI)\nUnsloth Studio (Beta) works on **Windows, Linux, WSL** and **macOS**.\n\n* **CPU:** Supported for Chat and Data Recipes currently\n* **NVIDIA:** Training works on RTX 30/40/50, Blackwell, DGX Spark, Station and more\n* **macOS:** Currently supports chat and Data Recipes. **MLX training** is coming very soon\n* **AMD:** Chat works. Train with [Unsloth Core](#unsloth-core-code-based). Studio support is coming soon.\n* **Coming soon:** Training support for Apple MLX, AMD, and Intel.\n* **Multi-GPU:** Available now, with a major upgrade on the way\n\n#### MacOS, Linux, WSL Setup:\n```bash\ncurl -fsSL https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh\n```\nIf you don't have `curl`, use `wget`. Then to launch after setup:\n```bash\nsource unsloth_studio/bin/activate\nunsloth studio -H 0.0.0.0 -p 8888\n```\n\n#### Windows PowerShell Setup:\n```powershell\nirm https://raw.githubusercontent.com/unslothai/unsloth/main/install.ps1 | iex\n```\nThen to launch after setup:\n```powershell\n& .\\unsloth_studio\\Scripts\\unsloth.exe studio -H 0.0.0.0 -p 8888\n```\n\n#### MacOS, Linux, WSL developer installs:\n```bash\ncurl -LsSf https://astral.sh/uv/install.sh | sh\nuv venv unsloth_studio --python 3.13\nsource unsloth_studio/bin/activate\nuv pip install unsloth --torch-backend=auto\nunsloth studio setup\nunsloth studio -H 0.0.0.0 -p 8888\n```\n\n#### Windows PowerShell developer installs:\n```powershell\nwinget install -e --id Python.Python.3.13\nwinget install --id=astral-sh.uv  -e\nuv venv unsloth_studio --python 3.13\n.\\unsloth_studio\\Scripts\\activate\nuv pip install unsloth --torch-backend=auto\nunsloth studio setup\nunsloth studio -H 0.0.0.0 -p 8888\n```\n\n#### Docker\nUse our [Docker image](https://hub.docker.com/r/unsloth/unsloth) ```unsloth/unsloth``` container. Run:\n```bash\ndocker run -d -e JUPYTER_PASSWORD=\"mypassword\" \\\n  -p 8888:8888 -p 8000:8000 -p 2222:22 \\\n  -v $(pwd)/work:/workspace/work \\\n  --gpus all \\\n  unsloth/unsloth\n  ```\n\n#### Nightly Install - MacOS, Linux, WSL:\n```bash\ncurl -LsSf https://astral.sh/uv/install.sh | sh\ngit clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio\ncd unsloth_studio\nuv venv --python 3.13\nsource .venv/bin/activate\nuv pip install -e . --torch-backend=auto\nunsloth studio setup\nunsloth studio -H 0.0.0.0 -p 8888\n```\nThen to launch every time:\n```bash\ncd unsloth_studio\nsource .venv/bin/activate\nunsloth studio -H 0.0.0.0 -p 8888\n```\n\n#### Nightly Install - Windows:\nRun in Windows Powershell:\n```bash\nwinget install -e --id Python.Python.3.13\nwinget install --id=astral-sh.uv  -e\ngit clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio\ncd unsloth_studio\nuv venv --python 3.13\n.\\.venv\\Scripts\\activate\nuv pip install -e . 
--torch-backend=auto\nunsloth studio setup\nunsloth studio -H 0.0.0.0 -p 8888\n```\nThen to launch every time:\n```bash\ncd unsloth_studio\n.\\.venv\\Scripts\\activate\nunsloth studio -H 0.0.0.0 -p 8888\n```\n\n### Unsloth Core (code-based)\n#### Linux, WSL\n```bash\ncurl -LsSf https://astral.sh/uv/install.sh | sh\nuv venv unsloth_env --python 3.13\nsource unsloth_env/bin/activate\nuv pip install unsloth --torch-backend=auto\n```\n#### Windows Powershell\n```bash\nwinget install -e --id Python.Python.3.13\nwinget install --id=astral-sh.uv  -e\nuv venv unsloth_env --python 3.13\n.\\unsloth_env\\Scripts\\activate\nuv pip install unsloth --torch-backend=auto\n```\nFor Windows, `pip install unsloth` works only if you have Pytorch installed. Read our [Windows Guide](https://unsloth.ai/docs/get-started/install/windows-installation).\nYou can use the same Docker image as Unsloth Studio.\n\n#### AMD, Intel\nFor RTX 50x, B200, 6000 GPUs: `uv pip install unsloth --torch-backend=auto`. Read our guides for: [Blackwell](https://unsloth.ai/docs/blog/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark](https://unsloth.ai/docs/blog/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth). <br>\nTo install Unsloth on **AMD** and **Intel** GPUs, follow our [AMD Guide](https://unsloth.ai/docs/get-started/install/amd) and [Intel Guide](https://unsloth.ai/docs/get-started/install/intel).\n\n## ✨ Free Notebooks\n\nTrain for free with our notebooks. Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model.\n\n| Model | Free Notebooks | Performance | Memory use |\n|-----------|---------|--------|----------|\n| **Qwen3.5 (4B)**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision.ipynb)               | 1.5x faster | 60% less |\n| **gpt-oss (20B)**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-Fine-tuning.ipynb)               | 2x faster | 70% less |\n| **gpt-oss (20B): GRPO**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb)               | 2x faster | 80% less |\n| **Qwen3: Advanced GRPO**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb)               | 2x faster | 50% less |\n| **Gemma 3 (4B) Vision** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb)               | 1.7x faster | 60% less |\n| **embeddinggemma (300M)**    | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/EmbeddingGemma_(300M).ipynb)               | 2x faster | 20% less |\n| **Mistral Ministral 3 (3B)**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Ministral_3_VL_(3B)_Vision.ipynb)               | 1.5x faster | 60% less |\n| **Llama 3.1 (8B) Alpaca**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb)               | 2x faster | 70% less |\n| **Llama 3.2 Conversational**      | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)               | 2x faster | 70% less |\n| **Orpheus-TTS (3B)**     | [▶️ Start for 
free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Orpheus_(3B)-TTS.ipynb)               | 1.5x faster | 50% less |\n\n- See all our notebooks for: [Kaggle](https://github.com/unslothai/notebooks?tab=readme-ov-file#-kaggle-notebooks), [GRPO](https://unsloth.ai/docs/get-started/unsloth-notebooks#grpo-reasoning-rl-notebooks), [TTS](https://unsloth.ai/docs/get-started/unsloth-notebooks#text-to-speech-tts-notebooks), [embedding](https://unsloth.ai/docs/new/embedding-finetuning) & [Vision](https://unsloth.ai/docs/get-started/unsloth-notebooks#vision-multimodal-notebooks)\n- See [all our models](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [all our notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks)\n- See detailed documentation for Unsloth [here](https://unsloth.ai/docs)\n\n## 🦥 Unsloth News\n- **Introducing Unsloth Studio**: our new web UI for running and training LLMs. [Blog](https://unsloth.ai/docs/new/studio)\n- **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35-A3B, 112B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune)\n- Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe)\n- **Embedding models**: Unsloth now supports ~1.8-3.3x faster embedding fine-tuning. [Blog](https://unsloth.ai/docs/new/embedding-finetuning) • [Notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks#embedding-models)\n- New **7x longer context RL** vs. all other setups, via our new batching algorithms. [Blog](https://unsloth.ai/docs/new/grpo-long-context)\n- New RoPE & MLP **Triton Kernels** & **Padding Free + Packing**: 3x faster training & 30% less VRAM. [Blog](https://unsloth.ai/docs/new/3x-faster-training-packing)\n- **500K Context**: Training a 20B model with >500K context is now possible on an 80GB GPU. [Blog](https://unsloth.ai/docs/blog/500k-context-length-fine-tuning)\n- **FP8 & Vision RL**: You can now do FP8 & VLM GRPO on consumer GPUs. 
[FP8 Blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/vision-reinforcement-learning-vlm-rl)\n- **gpt-oss** by OpenAI: Read our [RL blog](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/gpt-oss-reinforcement-learning), [Flex Attention](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training) blog and [Guide](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune).\n\n## 🔗 Links and Resources\n| Type                                                                                                                                      | Links                                                                          |\n| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ |\n| <img width=\"15\" src=\"https://redditinc.com/hs-fs/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" />  **r/unsloth Reddit**                       | [Join Reddit community](https://reddit.com/r/unsloth)                          |\n| 📚 **Documentation & Wiki**                                                                                                               | [Read Our Docs](https://unsloth.ai/docs)                                       |\n| <img width=\"13\" src=\"https://upload.wikimedia.org/wikipedia/commons/0/09/X_(formerly_Twitter)_logo_late_2025.svg\" />  **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)                                |\n| 💾 **Installation**                                                                                                                       | [Pip & Docker Install](https://unsloth.ai/docs/get-started/install) |\n| 🔮 **Our Models**                                                                                                                         | [Unsloth Catalog](https://unsloth.ai/docs/get-started/unsloth-model-catalog)   |\n| ✍️ **Blog**                                                                                                                               | [Read our Blogs](https://unsloth.ai/blog)                                      |\n\n### Citation\n\nYou can cite the Unsloth repo as follows:\n```bibtex\n@software{unsloth,\n  author = {Daniel Han, Michael Han and Unsloth team},\n  title = {Unsloth},\n  url = {https://github.com/unslothai/unsloth},\n  year = {2023}\n}\n```\nIf you trained a model with 🦥Unsloth, you can use this cool sticker!   <img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png\" width=\"200\" align=\"center\" />\n\n### License\nUnsloth uses a dual-licensing model of Apache 2.0 and AGPL-3.0. 
The core Unsloth package remains licensed under **[Apache 2.0](https://github.com/unslothai/unsloth?tab=Apache-2.0-1-ov-file)**, while certain optional components, such as the Unsloth Studio UI, are licensed under the open-source **[AGPL-3.0](https://github.com/unslothai/unsloth?tab=AGPL-3.0-2-ov-file)** license.\n\nThis structure helps support ongoing Unsloth development while keeping the project open source and enabling the broader ecosystem to continue growing.\n\n### Thank You to\n- The [llama.cpp library](https://github.com/ggml-org/llama.cpp) that lets users run and save models with Unsloth\n- The Hugging Face team and their libraries: [transformers](https://github.com/huggingface/transformers) and [TRL](https://github.com/huggingface/trl)\n- The PyTorch and [Torch AO](https://github.com/unslothai/unsloth/pull/3391) teams for their contributions\n- And of course, every single person who has contributed to or used Unsloth!\n"
  },
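The README's "Unsloth Core (code-based)" section above stops at installation. For orientation, here is a minimal, hedged sketch of what the code-based path typically looks like once `unsloth` is installed; the model name, LoRA settings, and the hand-off to a TRL trainer are illustrative assumptions, not values taken from this repository, and keyword arguments can vary between Unsloth releases.

```python
# Minimal Unsloth Core sketch (illustrative only -- model name and
# hyperparameters are placeholders, not settings from this repo).
from unsloth import FastLanguageModel

# Load a base model with Unsloth's patched loader; 4-bit loading mirrors the
# low-VRAM training advertised in the Features section.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct",  # placeholder model choice
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters so only a small fraction of the weights is trained.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# From here, training usually continues with a TRL trainer (e.g. SFTTrainer)
# and your own dataset, as covered in the free notebooks and docs linked above.
```

The free notebooks listed in the README contain complete, runnable versions of this flow.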
  {
    "path": "build.sh",
    "content": "#!/usr/bin/env bash\n\nset -euo pipefail\n\n# 1. Build frontend (Vite outputs to dist/)\ncd studio/frontend\n\n# Clean stale dist to force a full rebuild\nrm -rf dist\n\n# Tailwind v4's oxide scanner respects .gitignore in parent directories.\n# Python venvs create a .gitignore with \"*\" (ignore everything), which\n# prevents Tailwind from scanning .tsx source files for class names.\n# Temporarily hide any such .gitignore during the build, then restore it.\n_HIDDEN_GITIGNORES=()\n_dir=\"$(pwd)\"\nwhile [ \"$_dir\" != \"/\" ]; do\n    _dir=\"$(dirname \"$_dir\")\"\n    if [ -f \"$_dir/.gitignore\" ] && grep -qx '\\*' \"$_dir/.gitignore\" 2>/dev/null; then\n        mv \"$_dir/.gitignore\" \"$_dir/.gitignore._twbuild\"\n        _HIDDEN_GITIGNORES+=(\"$_dir/.gitignore\")\n    fi\ndone\n\n_restore_gitignores() {\n    for _gi in \"${_HIDDEN_GITIGNORES[@]+\"${_HIDDEN_GITIGNORES[@]}\"}\"; do\n        mv \"${_gi}._twbuild\" \"$_gi\" 2>/dev/null || true\n    done\n}\ntrap _restore_gitignores EXIT\n\nnpm install\nnpm run build       # outputs to studio/frontend/dist/\n\n_restore_gitignores\ntrap - EXIT\n\n# Validate CSS output -- catch truncated Tailwind builds before packaging\nMAX_CSS_SIZE=$(find dist/assets -name '*.css' -exec wc -c {} + 2>/dev/null | sort -n | tail -1 | awk '{print $1}')\nif [ -z \"$MAX_CSS_SIZE\" ]; then\n    echo \"❌ ERROR: No CSS files were emitted into dist/assets.\"\n    echo \"   The frontend build may have failed silently.\"\n    exit 1\nfi\nif [ \"$MAX_CSS_SIZE\" -lt 100000 ]; then\n    echo \"❌ ERROR: Largest CSS file is only $((MAX_CSS_SIZE / 1024))KB (expected >100KB).\"\n    echo \"   Tailwind may not have scanned all source files.\"\n    echo \"   Check for .gitignore files blocking the Tailwind oxide scanner.\"\n    exit 1\nfi\necho \"✅ Frontend CSS validated (${MAX_CSS_SIZE} bytes)\"\n\ncd ../..\n\n# 2. Clean old artifacts\nrm -rf build dist *.egg-info\n\n# 3. Build wheel\npython -m build\n\n# 4. Optionally publish\nif [ \"${1:-}\" = \"publish\" ]; then\n    python -m twine upload dist/*\nfi\n"
  },
  {
    "path": "cli.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom unsloth_cli import app\n\nif __name__ == \"__main__\":\n    app()\n"
  },
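`cli.py` above only re-exports `app` from the `unsloth_cli` package, and `pyproject.toml` further down maps the `unsloth` console script to the same `unsloth_cli:app` callable. Since `typer` is a core dependency, `app` is presumably a Typer application; the stand-in below is a hypothetical sketch of that wiring, not the real `unsloth_cli` code.

```python
# Hypothetical stand-in for the `app` object exported by unsloth_cli -- NOT the
# real package. It only shows how a Typer app exposed as the `unsloth` console
# script could back a command such as `unsloth studio -H 0.0.0.0 -p 8888`.
import typer

app = typer.Typer(help="Illustrative skeleton of the `unsloth` command.")

@app.command()
def studio(
    host: str = typer.Option("0.0.0.0", "-H", "--host"),
    port: int = typer.Option(8888, "-p", "--port"),
):
    """Placeholder for the real Studio launcher."""
    typer.echo(f"Would start Unsloth Studio on http://{host}:{port}")

if __name__ == "__main__":
    app()
```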
  {
    "path": "install.ps1",
    "content": "# Unsloth Studio Installer for Windows PowerShell\n# Usage:  irm https://raw.githubusercontent.com/unslothai/unsloth/main/install.ps1 | iex\n# Local:  Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass; .\\install.ps1\n\nfunction Install-UnslothStudio {\n    $ErrorActionPreference = \"Stop\"\n\n    $VenvName = \"unsloth_studio\"\n    $PythonVersion = \"3.13\"\n\n    Write-Host \"\"\n    Write-Host \"=========================================\"\n    Write-Host \"   Unsloth Studio Installer (Windows)\"\n    Write-Host \"=========================================\"\n    Write-Host \"\"\n\n    # ── Helper: refresh PATH from registry (preserving current session entries) ──\n    function Refresh-SessionPath {\n        $machine = [System.Environment]::GetEnvironmentVariable(\"Path\", \"Machine\")\n        $user    = [System.Environment]::GetEnvironmentVariable(\"Path\", \"User\")\n        $env:Path = \"$machine;$user;$env:Path\"\n    }\n\n    # ── Check winget ──\n    if (-not (Get-Command winget -ErrorAction SilentlyContinue)) {\n        Write-Host \"Error: winget is not available.\" -ForegroundColor Red\n        Write-Host \"       Install it from https://aka.ms/getwinget\" -ForegroundColor Yellow\n        Write-Host \"       or install Python $PythonVersion and uv manually, then re-run.\" -ForegroundColor Yellow\n        return\n    }\n\n    # ── Install Python if no compatible version (3.11-3.13) found ──\n    $DetectedPythonVersion = \"\"\n    if (Get-Command python -ErrorAction SilentlyContinue) {\n        $pyVer = python --version 2>&1\n        if ($pyVer -match \"Python (3\\.1[1-3])\\.\\d+\") {\n            Write-Host \"==> Python already installed: $pyVer\"\n            $DetectedPythonVersion = $Matches[1]\n        }\n    }\n    if (-not $DetectedPythonVersion) {\n        Write-Host \"==> Installing Python ${PythonVersion}...\"\n        winget install -e --id Python.Python.3.13 --accept-package-agreements --accept-source-agreements\n        Refresh-SessionPath\n        if ($LASTEXITCODE -ne 0) {\n            # winget returns non-zero for \"already installed\" -- only fail if python is truly missing\n            if (-not (Get-Command python -ErrorAction SilentlyContinue)) {\n                Write-Host \"[ERROR] Python installation failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n                return\n            }\n        }\n        $DetectedPythonVersion = $PythonVersion\n    }\n\n    # ── Install uv if not present ──\n    if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {\n        Write-Host \"==> Installing uv package manager...\"\n        winget install --id=astral-sh.uv -e --accept-package-agreements --accept-source-agreements\n        Refresh-SessionPath\n        # Fallback: if winget didn't put uv on PATH, try the PowerShell installer\n        if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {\n            Write-Host \"    Trying alternative uv installer...\"\n            powershell -ExecutionPolicy ByPass -c \"irm https://astral.sh/uv/install.ps1 | iex\"\n            Refresh-SessionPath\n        }\n    }\n\n    if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {\n        Write-Host \"Error: uv could not be installed.\" -ForegroundColor Red\n        Write-Host \"       Install it from https://docs.astral.sh/uv/\" -ForegroundColor Yellow\n        return\n    }\n\n    # ── Create venv (skip if it already exists and has a valid interpreter) ──\n    $VenvPython = Join-Path $VenvName \"Scripts\\python.exe\"\n    if (-not (Test-Path 
$VenvPython)) {\n        if (Test-Path $VenvName) { Remove-Item -Recurse -Force $VenvName }\n        Write-Host \"==> Creating Python ${DetectedPythonVersion} virtual environment (${VenvName})...\"\n        uv venv $VenvName --python $DetectedPythonVersion\n        if ($LASTEXITCODE -ne 0) {\n            Write-Host \"[ERROR] Failed to create virtual environment (exit code $LASTEXITCODE)\" -ForegroundColor Red\n            return\n        }\n    } else {\n        Write-Host \"==> Virtual environment ${VenvName} already exists, skipping creation.\"\n    }\n\n    # ── Install unsloth directly into the venv (no activation needed) ──\n    Write-Host \"==> Installing unsloth (this may take a few minutes)...\"\n    uv pip install --python $VenvPython unsloth --torch-backend=auto\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[ERROR] Failed to install unsloth (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        return\n    }\n\n    # ── Run studio setup ──\n    # setup.ps1 will handle installing Git, CMake, Visual Studio Build Tools,\n    # CUDA Toolkit, Node.js, and other dependencies automatically via winget.\n    Write-Host \"==> Running unsloth studio setup...\"\n    $UnslothExe = Join-Path $VenvName \"Scripts\\unsloth.exe\"\n    & $UnslothExe studio setup\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[ERROR] unsloth studio setup failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        return\n    }\n\n    Write-Host \"\"\n    Write-Host \"=========================================\"\n    Write-Host \"   Unsloth Studio installed!\"\n    Write-Host \"=========================================\"\n    Write-Host \"\"\n    Write-Host \"  To launch, run:\"\n    Write-Host \"\"\n    Write-Host \"    .\\${VenvName}\\Scripts\\activate\"\n    Write-Host \"    unsloth studio -H 0.0.0.0 -p 8888\"\n    Write-Host \"\"\n}\n\nInstall-UnslothStudio\n"
  },
  {
    "path": "install.sh",
    "content": "#!/bin/sh\n# Unsloth Studio Installer\n# Usage (curl): curl -fsSL https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh\n# Usage (wget): wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh\nset -e\n\nVENV_NAME=\"unsloth_studio\"\nPYTHON_VERSION=\"3.13\"\n\n# ── Helper: download a URL to a file (supports curl and wget) ──\ndownload() {\n    if command -v curl >/dev/null 2>&1; then\n        curl -LsSf \"$1\" -o \"$2\"\n    elif command -v wget >/dev/null 2>&1; then\n        wget -qO \"$2\" \"$1\"\n    else\n        echo \"Error: neither curl nor wget found. Install one and re-run.\"\n        exit 1\n    fi\n}\n\n# ── Helper: check if a single package is available on the system ──\n_is_pkg_installed() {\n    case \"$1\" in\n        build-essential) command -v gcc >/dev/null 2>&1 ;;\n        libcurl4-openssl-dev)\n            command -v dpkg >/dev/null 2>&1 && dpkg -s \"$1\" >/dev/null 2>&1 ;;\n        pciutils)\n            command -v lspci >/dev/null 2>&1 ;;\n        *) command -v \"$1\" >/dev/null 2>&1 ;;\n    esac\n}\n\n# ── Helper: install packages via apt, escalating to sudo only if needed ──\n# Usage: _smart_apt_install pkg1 pkg2 pkg3 ...\n_smart_apt_install() {\n    _PKGS=\"$*\"\n\n    # Step 1: Try installing without sudo (works when already root)\n    apt-get update -y </dev/null >/dev/null 2>&1 || true\n    apt-get install -y $_PKGS </dev/null >/dev/null 2>&1 || true\n\n    # Step 2: Check which packages are still missing\n    _STILL_MISSING=\"\"\n    for _pkg in $_PKGS; do\n        if ! _is_pkg_installed \"$_pkg\"; then\n            _STILL_MISSING=\"$_STILL_MISSING $_pkg\"\n        fi\n    done\n    _STILL_MISSING=$(echo \"$_STILL_MISSING\" | sed 's/^ *//')\n\n    if [ -z \"$_STILL_MISSING\" ]; then\n        return 0\n    fi\n\n    # Step 3: Escalate -- need elevated permissions for remaining packages\n    if command -v sudo >/dev/null 2>&1; then\n        echo \"\"\n        echo \"    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\"\n        echo \"    WARNING: We require sudo elevated permissions to install:\"\n        echo \"    $_STILL_MISSING\"\n        echo \"    If you accept, we'll run sudo now, and it'll prompt your password.\"\n        echo \"    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\"\n        echo \"\"\n        printf \"    Accept? 
[Y/n] \"\n        if [ -r /dev/tty ]; then\n            read -r REPLY </dev/tty || REPLY=\"y\"\n        else\n            REPLY=\"y\"\n        fi\n        case \"$REPLY\" in\n            [nN]*)\n                echo \"\"\n                echo \"    Please install these packages first, then re-run Unsloth Studio setup:\"\n                echo \"    sudo apt-get update -y && sudo apt-get install -y $_STILL_MISSING\"\n                exit 1\n                ;;\n            *)\n                sudo apt-get update -y </dev/null\n                sudo apt-get install -y $_STILL_MISSING </dev/null\n                ;;\n        esac\n    else\n        echo \"\"\n        echo \"    sudo is not available on this system.\"\n        echo \"    Please install these packages as root, then re-run Unsloth Studio setup:\"\n        echo \"    apt-get update -y && apt-get install -y $_STILL_MISSING\"\n        exit 1\n    fi\n}\n\necho \"\"\necho \"=========================================\"\necho \"   Unsloth Studio Installer\"\necho \"=========================================\"\necho \"\"\n\n# ── Detect platform ──\nOS=\"linux\"\nif [ \"$(uname)\" = \"Darwin\" ]; then\n    OS=\"macos\"\nelif grep -qi microsoft /proc/version 2>/dev/null; then\n    OS=\"wsl\"\nfi\necho \"==> Platform: $OS\"\n\n# ── Check system dependencies ──\n# cmake and git are needed by unsloth studio setup to build the GGUF inference\n# engine (llama.cpp). build-essential and libcurl-dev are also needed on Linux.\nMISSING=\"\"\n\ncommand -v cmake >/dev/null 2>&1 || MISSING=\"$MISSING cmake\"\ncommand -v git   >/dev/null 2>&1 || MISSING=\"$MISSING git\"\n\ncase \"$OS\" in\n    macos)\n        # Xcode Command Line Tools provide the C/C++ compiler\n        if ! xcode-select -p >/dev/null 2>&1; then\n            echo \"\"\n            echo \"==> Xcode Command Line Tools are required.\"\n            echo \"    Installing (a system dialog will appear)...\"\n            xcode-select --install </dev/null 2>/dev/null || true\n            echo \"    After the installation completes, please re-run this script.\"\n            exit 1\n        fi\n        ;;\n    linux|wsl)\n        # curl or wget is needed for downloads; check both\n        if ! command -v curl >/dev/null 2>&1 && ! command -v wget >/dev/null 2>&1; then\n            MISSING=\"$MISSING curl\"\n        fi\n        command -v gcc  >/dev/null 2>&1 || MISSING=\"$MISSING build-essential\"\n        # libcurl dev headers for llama.cpp HTTPS support\n        if command -v dpkg >/dev/null 2>&1; then\n            dpkg -s libcurl4-openssl-dev >/dev/null 2>&1 || MISSING=\"$MISSING libcurl4-openssl-dev\"\n        fi\n        ;;\nesac\n\nMISSING=$(echo \"$MISSING\" | sed 's/^ *//')\n\nif [ -n \"$MISSING\" ]; then\n    echo \"\"\n    echo \"==> Unsloth Studio needs these packages: $MISSING\"\n    echo \"    These are needed to build the GGUF inference engine.\"\n\n    case \"$OS\" in\n        macos)\n            if ! command -v brew >/dev/null 2>&1; then\n                echo \"\"\n                echo \"    Homebrew is required to install them.\"\n                echo \"    Install Homebrew from https://brew.sh then re-run this script.\"\n                exit 1\n            fi\n            brew install $MISSING </dev/null\n            ;;\n        linux|wsl)\n            if command -v apt-get >/dev/null 2>&1; then\n                _smart_apt_install $MISSING\n            else\n                echo \"    apt-get is not available. 
Please install with your package manager:\"\n                echo \"    $MISSING\"\n                echo \"    Then re-run Unsloth Studio setup.\"\n                exit 1\n            fi\n            ;;\n    esac\n    echo \"\"\nelse\n    echo \"==> All system dependencies found.\"\nfi\n\n# ── Install uv ──\nif ! command -v uv >/dev/null 2>&1; then\n    echo \"==> Installing uv package manager...\"\n    _uv_tmp=$(mktemp)\n    download \"https://astral.sh/uv/install.sh\" \"$_uv_tmp\"\n    sh \"$_uv_tmp\" </dev/null\n    rm -f \"$_uv_tmp\"\n    if [ -f \"$HOME/.local/bin/env\" ]; then\n        . \"$HOME/.local/bin/env\"\n    fi\n    export PATH=\"$HOME/.local/bin:$PATH\"\nfi\n\n# ── Create venv (skip if it already exists and has a valid interpreter) ──\nif [ ! -x \"$VENV_NAME/bin/python\" ]; then\n    [ -e \"$VENV_NAME\" ] && rm -rf \"$VENV_NAME\"\n    echo \"==> Creating Python ${PYTHON_VERSION} virtual environment (${VENV_NAME})...\"\n    uv venv \"$VENV_NAME\" --python \"$PYTHON_VERSION\"\nelse\n    echo \"==> Virtual environment ${VENV_NAME} already exists, skipping creation.\"\nfi\n\n# ── Install unsloth directly into the venv (no activation needed) ──\necho \"==> Installing unsloth (this may take a few minutes)...\"\nuv pip install --python \"$VENV_NAME/bin/python\" unsloth --torch-backend=auto\n\n# ── Run studio setup ──\n# Ensure the venv's Python is on PATH for setup.sh's Python discovery.\n# On macOS the system Python may be outside the 3.11-3.13 range that\n# setup.sh requires, but uv already installed a compatible interpreter\n# inside the venv.\nVENV_ABS_BIN=\"$(cd \"$VENV_NAME/bin\" && pwd)\"\nif [ -n \"$VENV_ABS_BIN\" ]; then\n    export PATH=\"$VENV_ABS_BIN:$PATH\"\nfi\n\necho \"==> Running unsloth studio setup...\"\n\"$VENV_NAME/bin/unsloth\" studio setup </dev/null\n\necho \"\"\necho \"=========================================\"\necho \"   Unsloth Studio installed!\"\necho \"=========================================\"\necho \"\"\necho \"  To launch, run:\"\necho \"\"\necho \"    source ${VENV_NAME}/bin/activate\"\necho \"    unsloth studio -H 0.0.0.0 -p 8888\"\necho \"\"\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools==80.9.0\", \"setuptools-scm==9.2.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"unsloth\"\ndynamic = [\"version\"]\ndescription = \"2-5X faster training, reinforcement learning & finetuning\"\nreadme = \"README.md\"\nrequires-python = \">=3.9,<3.15\"\nlicense = \"Apache-2.0\"\nkeywords = [\"ai\", \"llm\", \"reinforcement learning\", \"machine learning\", \"artificial intelligence\", \"pytorch\"]\nauthors = [\n    {email = \"info@unsloth.ai\"},\n    {name = \"Unsloth AI team\"},\n]\nmaintainers = [\n    {name = \"Daniel Han\", email = \"daniel@unsloth.ai\"},\n    {name = \"Michael Han\", email = \"info@unsloth.ai\"},\n]\nclassifiers = [\n    \"Programming Language :: Python\",\n    \"Environment :: GPU\",\n    \"Environment :: GPU :: NVIDIA CUDA\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n]\ndependencies = [\n    \"typer\",\n    \"pydantic\",\n    \"pyyaml\",\n    \"nest-asyncio\",\n]\n\n[project.scripts]\nunsloth = \"unsloth_cli:app\"\n\n[tool.setuptools.dynamic]\nversion = {attr = \"unsloth.models._utils.__version__\"}\n\n[tool.setuptools]\ninclude-package-data = true\n\n[tool.setuptools.package-data]\nstudio = [\n    \"*.sh\",\n    \"*.ps1\",\n    \"*.bat\",\n    \"frontend/dist/**/*\",\n    \"frontend/public/**/*\",\n    \"frontend/src/**/*\",\n    \"frontend/*.json\",\n    \"frontend/*.ts\",\n    \"frontend/*.js\",\n    \"frontend/*.lock\",\n    \"frontend/*.html\",\n    \"frontend/*.yaml\",\n    \"frontend/.git*\",\n    \"backend/requirements/**/*\",\n    \"backend/core/data_recipe/oxc-validator/*.json\",\n    \"backend/core/data_recipe/oxc-validator/*.mjs\",\n]\n\n[tool.setuptools.packages.find]\nexclude = [\"images*\", \"tests*\", \"kernels/moe*\"]\n\n[project.optional-dependencies]\ntriton = [\n    \"triton>=3.0.0 ; ('linux' in sys_platform)\",\n    \"triton-windows ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\n\nhuggingfacenotorch = [\n    \"wheel>=0.42.0\",\n    \"packaging\",\n    \"numpy\",\n    \"tqdm\",\n    \"psutil\",\n    \"tyro\",\n    \"protobuf\",\n    \"sentencepiece>=0.2.0\",\n    \"datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0\",\n    \"accelerate>=0.34.1\",\n    \"peft>=0.18.0,!=0.11.0\",\n    \"huggingface_hub>=0.34.0\",\n    \"hf_transfer\",\n    \"diffusers\",\n    \"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0\",\n    \"trl>=0.18.2,!=0.19.0,<=0.24.0\",\n    \"sentence-transformers\",\n]\nhuggingface = [\n    \"unsloth[huggingfacenotorch]\",\n    \"unsloth_zoo>=2026.3.4\",\n    \"torchvision\",\n    \"unsloth[triton]\",\n]\nwindows = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0 ; (sys_platform == 'win32')\",\n    \"xformers>=0.0.22.post7 ; (sys_platform == 'win32')\",\n]\nbase = [\n    \"unsloth[huggingface]\",\n]\ncu118only = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu121only = 
[\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu118onlytorch211 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu121onlytorch211 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu118onlytorch212 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu121onlytorch212 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu118onlytorch220 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu121onlytorch220 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ 
https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n]\ncu118onlytorch230 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu121onlytorch230 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu118onlytorch240 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu121onlytorch240 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu124onlytorch240 = [\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' 
and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu118onlytorch250 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu121onlytorch250 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu124onlytorch250 = [\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ 
https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu118onlytorch251 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu121onlytorch251 = [\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu124onlytorch251 = [\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu118onlytorch260 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ 
https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n]\ncu124onlytorch260 = [\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu126onlytorch260 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu118onlytorch270 = [\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; 
python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu126onlytorch270 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu128onlytorch270 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')\",\n]\ncu118onlytorch271 = [\n    \"xformers @ 
https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu126onlytorch271 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu128onlytorch271 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu118onlytorch280 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu126onlytorch280 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu128onlytorch280 = [\n    \"xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu130onlytorch280 = [\n]\ncu126onlytorch290 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu128onlytorch290 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu130onlytorch290 = [\n    \"xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu126onlytorch291 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu128onlytorch291 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu130onlytorch291 = [\n    \"xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ 
https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu126onlytorch2100 = [\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu128onlytorch2100 = [\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu130onlytorch2100 = [\n    \"xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)\",\n    \"xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')\",\n]\ncu118 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118only]\",\n]\ncu121 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121only]\",\n]\ncu118-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu118onlytorch211]\",\n]\ncu121-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu121onlytorch211]\",\n]\ncu118-torch212 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu118onlytorch212]\",\n]\ncu121-torch212 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu121onlytorch212]\",\n]\ncu118-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch220]\",\n]\ncu121-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch220]\",\n]\ncu118-torch230 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch230]\",\n]\ncu121-torch230 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch230]\",\n]\ncu118-torch240 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch240]\",\n]\ncu121-torch240 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch240]\",\n]\ncu124-torch240 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch240]\",\n]\ncu118-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch250]\",\n]\ncu121-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch250]\",\n]\ncu124-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch250]\",\n]\ncu118-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch251]\",\n]\ncu121-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch251]\",\n]\ncu124-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    
\"unsloth[cu124onlytorch251]\",\n]\ncu118-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch260]\",\n]\ncu124-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch260]\",\n]\ncu126-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch260]\",\n]\ncu118-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch270]\",\n]\ncu126-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch270]\",\n]\ncu128-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch270]\",\n]\ncu118-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch271]\",\n]\ncu126-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch271]\",\n]\ncu128-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch271]\",\n]\ncu118-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch280]\",\n]\ncu126-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch280]\",\n]\ncu128-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch280]\",\n]\ncu130-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch280]\",\n]\ncu126-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch290]\",\n]\ncu128-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch290]\",\n]\ncu130-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch290]\",\n]\ncu126-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch291]\",\n]\ncu128-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch291]\",\n]\ncu130-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch291]\",\n]\ncu126-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch2100]\",\n]\ncu128-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch2100]\",\n]\ncu130-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch2100]\",\n]\nkaggle = [\n    \"unsloth[huggingface]\",\n]\nkaggle-new = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n]\nconda = [\n    \"unsloth[huggingface]\",\n]\ncolab-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu121onlytorch211]\",\n]\ncolab-ampere-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    
\"unsloth[cu121onlytorch211]\",\n    \"packaging\",\n    \"ninja\",\n    \"flash-attn>=2.6.3 ; ('linux' in sys_platform)\",\n]\ncolab-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch220]\",\n]\ncolab-ampere-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch220]\",\n    \"packaging\",\n    \"ninja\",\n    \"flash-attn>=2.6.3 ; ('linux' in sys_platform)\",\n]\ncolab-new = [\n    \"unsloth_zoo>=2026.3.4\",\n    \"packaging\",\n    \"tyro\",\n    \"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0\",\n    \"datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0\",\n    \"sentencepiece>=0.2.0\",\n    \"tqdm\",\n    \"psutil\",\n    \"wheel>=0.42.0\",\n    \"numpy\",\n    \"protobuf\",\n    \"huggingface_hub>=0.34.0\",\n    \"hf_transfer\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[triton]\",\n    \"sentence-transformers\",\n]\ncolab-no-deps = [\n    \"accelerate>=0.34.1\",\n    \"trl>=0.18.2,!=0.19.0,<=0.24.0\",\n    \"peft>=0.18.0\",\n    \"xformers ; ('linux' in sys_platform or sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"protobuf\",\n]\ncolab = [\n    \"unsloth[cu121]\",\n]\nflashattention = [\n    \"packaging ; ('linux' in sys_platform)\",\n    \"ninja ; ('linux' in sys_platform)\",\n    \"flash-attn>=2.6.3 ; ('linux' in sys_platform)\",\n]\ncolab-ampere = [\n    \"unsloth[colab-ampere-torch220]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118only]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121only]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu118onlytorch211]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch211 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes==0.45.5\",\n    \"unsloth[cu121onlytorch211]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch220]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch220 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch220]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch230 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch230]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch230 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch230]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch240 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch240]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch240 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch240]\",\n    \"unsloth[flashattention]\",\n]\ncu124-ampere-torch240 = [\n    \"unsloth[huggingface]\",\n    
\"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch240]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch250]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch250]\",\n    \"unsloth[flashattention]\",\n]\ncu124-ampere-torch250 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch250]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch251]\",\n    \"unsloth[flashattention]\",\n]\ncu121-ampere-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu121onlytorch251]\",\n    \"unsloth[flashattention]\",\n]\ncu124-ampere-torch251 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch251]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch260]\",\n    \"unsloth[flashattention]\",\n]\ncu124-ampere-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu124onlytorch260]\",\n    \"unsloth[flashattention]\",\n]\ncu126-ampere-torch260 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch260]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch270]\",\n    \"unsloth[flashattention]\",\n]\ncu126-ampere-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch270]\",\n    \"unsloth[flashattention]\",\n]\ncu128-ampere-torch270 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch270]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch271]\",\n    \"unsloth[flashattention]\",\n]\ncu126-ampere-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch271]\",\n    \"unsloth[flashattention]\",\n]\ncu128-ampere-torch271 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch271]\",\n    \"unsloth[flashattention]\",\n]\ncu118-ampere-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu118onlytorch280]\",\n    \"unsloth[flashattention]\",\n]\ncu126-ampere-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch280]\",\n    \"unsloth[flashattention]\",\n]\ncu128-ampere-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch280]\",\n    \"unsloth[flashattention]\",\n]\ncu130-ampere-torch280 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch280]\",\n    
\"unsloth[flashattention]\",\n]\ncu126-ampere-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch290]\",\n]\ncu128-ampere-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch290]\",\n]\ncu130-ampere-torch290 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch290]\",\n]\ncu126-ampere-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch291]\",\n]\ncu128-ampere-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch291]\",\n]\ncu130-ampere-torch291 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch291]\",\n]\ncu126-ampere-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu126onlytorch2100]\",\n]\ncu128-ampere-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu128onlytorch2100]\",\n]\ncu130-ampere-torch2100 = [\n    \"unsloth[huggingface]\",\n    \"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0\",\n    \"unsloth[cu130onlytorch2100]\",\n]\nflashattentiontorch260abiFALSEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nflashattentiontorch260abiTRUEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    
\"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nflashattentiontorch250abiFALSEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nflashattentiontorch250abiTRUEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nflashattentiontorch240abiFALSEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nflashattentiontorch240abiTRUEcu12x = [\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'\",\n    \"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'\",\n]\nintelgputorch260 = [\n    \"unsloth_zoo[intelgpu]\",\n    \"unsloth[huggingfacenotorch]\",\n\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp39-cp39-linux_x86_64.whl#sha256=147607f190a7d7aa24ba454def5977fbbfec792fdae18e4ed278cfec29b69271 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp310-cp310-linux_x86_64.whl#sha256=23aa423fa1542afc34f67eb3ba8ef20060f6d1b3a4697eaeab22b11c92b30f2b ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp311-cp311-linux_x86_64.whl#sha256=bcfa995229bbfd9ffd8d6c8d9f6428d393e876fa6e23ee3c20e3c0d73ca75ca5 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp312-cp312-linux_x86_64.whl#sha256=bd340903d03470708df3442438acb8b7e08087ab9e61fbe349b2872bf9257ab0 ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp313-cp313-linux_x86_64.whl#sha256=814dccc8a07159e6eca74bed70091bc8fea2d9dd87b0d91845f9f38cde62f01c ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 
'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=6a8adf6dc4c089406e8b3a7e58ab57a463bddf9b07130d2576e76eced43e92af ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=ff4561cbf07c83bbccaa0f6e9bb0e6dcf721bacd53c9c43c4eb0e7331b4792f9 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=12005f66b810ddd3ab93f86c4522bcfdd412cbd27fc9d189b661ff7509bc5e8a ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c4c5c67625cdacf35765c2b94e61fe166e3c3f4a14521b1212a59ad1b3eb0f2e ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=e6864f7a60a5ecc43d5d38f59a16e5dd132384f73dfd3a697f74944026038f7b ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nintel-gpu-torch260 = [\n    \"unsloth[intelgputorch260]\"\n]\nintelgputorch270 = [\n    \"unsloth_zoo[intelgpu]\",\n    \"unsloth[huggingfacenotorch]\",\n\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=749a7098492c6a27b356c97149a4a62973b953eae60bc1b6259260974f344913 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=44362e80abd752471a08341093321955b066daa2cfb4810e73b8e3b240850f93 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=faa6b8c945a837a080f641bc8ccc77a98fa66980dcd7e62e715fd853737343fd ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=40f6fb65b345dc9a61813abe7ac9a585f2c9808f414d140cc2a5f11f53ee063c ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=b22b4c02ec71b4bfc862ae3cdfd2871dc0b05d2b1802f5db2196e0f897d581e9 ; 
('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-win_amd64.whl#sha256=d4b738d7fa5100c1bd766f91614962828a4810eb57b4df92cd5214a83505a752 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-win_amd64.whl#sha256=143fe8a64d807bcdb7d81bbc062816add325570aa160448454ab6ded4a0a17a1 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-win_amd64.whl#sha256=a8025459ff325d6e3532eb5cf72519db1b178155e7d60aff6c56beb5968fc758 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-win_amd64.whl#sha256=0dd07e6d5b872e42e48f5ee140e609d4554ca3cc509d5bf509ac232267cf358e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-win_amd64.whl#sha256=a936a18182d8e065a9933afc9a3ebbffadd38604969f87c493831214539fc027 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=f8ee75e50fcbb37ed5b498299ca2264da99ab278a93fae2358e921e4a6e28273 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=d6fdc342961d98fdcd9d03dfd491a3208bb5f7fbb435841f8f72ce9fdcd2d026 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=74d07f9357df5cf2bf223ad3c84de16346bfaa0504f988fdd5590d3e177e5e86 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c806d44aa2ca5d225629f6fbc6c994d5deaac2d2cde449195bc8e3522ddd219a ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ 
https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=25d8277b7f01d42e2e014ccbab57a2692b6ec4eff8dcf894eda1b297407cf97a ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=046e85125266ae69c1a0d083e6c092f947ab4b6b41532c16bafe40dbced845df ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=9ebaeffb82b0b3e39b6030927d3ebe0eb62a0e9045a3b2d7b0a9e7b15222c0db ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=356ba66cee127e7e2c942880bd50e03768306a4ea08d358a0f29c6eebfc4bc81 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=94739e665d9b4d5cd7af5f517cb6103f6f9fb421c095184609653a24524040f5 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=31df3cb674918e89bc8c532baa331dc84f4430e1f9c0ec379232db44cba78355 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nintel-gpu-torch270 = [\n    \"unsloth[intelgputorch270]\"\n]\nintelgputorch280 = [\n    \"unsloth_zoo[intelgpu]\",\n    \"unsloth[huggingfacenotorch]\",\n\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=ac4d8e33986b1c3c5e48151640539272b2187e83016985853111b46fb82c3c94 ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=999fef4c1f711092b9d3086525920545df490de476ecebe899ffc777019ae17f ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=57b09c8c492985ff6a27cd3a22b08e8f7b96b407bd8030967b6efbb9f63b80cf ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=df4bb3282bac9a3b90231700077110d8680b338416de03c2b7c6133c9b602649 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=60da63c99ca827bdcb0df28e0298bf7d066dc607454c6d6176783cb4e79d838b ; 'linux' in sys_platform and python_version == '3.13' and 
(platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-win_amd64.whl#sha256=64aea8de349f3e2e0ebf4c24b011a8122531fdffda5776edaef45829cc241cf8 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-win_amd64.whl#sha256=ae573d255b257fdbed319a3440dc9d0a721e31160ab7f6eba1b2226e6a409a1d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-win_amd64.whl#sha256=8e0ea4558e5776d8ddab0264310be9b26aee5641bcac0da023537556d4317b86 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-win_amd64.whl#sha256=4090dde07a4fffc34aaf855701a9db28e9fccb57b368ade520f1a0f8e811c878 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-win_amd64.whl#sha256=a33d0888f3c8df028a2d028842715837d0049524d6c06b9bb11869890a13601a ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=f2f401276892428e4875cf1d8717c5cbab704b16fc594ccf23795e7b16549a99 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=125c60cd59d51b39581a7e9afcd4679bc3a6b8c1f9440b1bb502a23fdd60571e ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=47f1a57258cd460e80b38b2ed6744e31587ab77a96b4215bf59546cb4bab5cc0 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' 
or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=0937d8943c145a83d9bafc6f80ef28971167817f9eda26066d33f72caf8a6646 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=e034aab1d71760dc80a731531be43673ffe15e99033b82d24e40d2e6d41bd8bf ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-manylinux_2_28_x86_64.whl#sha256=6e981c192045fc249c008441179ff237bb00174d818b875b0475730b63f0eaca ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=e5ba4805969277175ebfd59cc717093528cc6e3ada89ac2725fc7a3c1fee6169 ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=74c39c144104416bc4c5ad8c26ab0c169dc5cc6be58059e01bc3665dd0ef676f ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=0acec355b80c3899841184084f365df336c508602812e34a44007b8b60d53af4 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=e2109ae773dad27b98ca17681044b4f876563c37f2382b75de3a371399edcff8 ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=5f7904e7048d414379bc8c1167260f1e84204f105db2d0a2f9c89e87ce1cf205 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=005fca5e658ca8e37adb63c1a021c84f5e56dfa6cf0d601d89cfe40b9473f79f ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=c6d030f5361461550c0ff1339b5bca8585fc1e84fda2e64b6184e65a581e4f98 ; 
sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=91aafd61864cdce27461cbec13ddbf28c1bc6494265a1e4b80131c64a3b7d18f ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=71dc4a6421742ed1e7f585b04a100ad53615c341fbccfbc255aefb38ea9091da ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nintel-gpu-torch280 = [\n    \"unsloth[intelgputorch280]\"\n]\nintelgputorch290 = [\n    \"unsloth_zoo[intelgpu]\",\n    \"unsloth[huggingfacenotorch]\",\n\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=5afbe860ce991825a36b75706a523601087e414b77598ef0d9d3d565741c277d ; 
platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=607fe419c32d6e8e0556f745742e7cff1d0babce51f54be890e0c1422359c442 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=376bae584d89980b8e59934d248c38d5fa3b7d4687a4df1a19f4bc1d23dcc8c1 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=98d6a06dd7fb185874367b18bd609f05f16fdce4142a5980ca94461949965cd2 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=47cc68f631f65bd9c84924d052cd04dec7531023caa85e80345e9c94611c887d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=d56c44ab4818aba57e5c7b628f422d014e0d507427170a771c5be85e308b0bc6 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=18cad93aaff76a01ce73aef6935ece7cfc03344b905592ec731446c44d44592b ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=579929cdc10a76800ead41289cac191ea36d1b16f5f501d3fc25607d4375cd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=cbfae2b79b7549fd368c2462fc8e94f8f26cc450782ee72138e908077c09a519 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=044fa36ef4b6b43edcd490b75c853fa4b3eb033c2bded29f8fbcf27734713c67 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=4b91e4bec1d740a6211f02578a79888550b73f3a4e1383035f8f6d72f587212c ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ 
https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=88239e73ca37254bec84f29cd5887e10ff712de7edbbda3fbb3609cd6190d99e ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=19c7da8ca767d593e13a88a12bb08d06e34a673f6f26c2f9c191d60e81c02953 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=9bb0d1421c544ac8e2eca5b47daacaf54706dc9139c003aa5e77ee5f355c5931 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=6a5194bc736089606342d48a3f6822829b167617e9495d91d753dd1bd46fda18 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=da47a3ce2bb7f0301a31124668b5908f9b9e92d6241443de15a310ef9632fd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nintel-gpu-torch290 = [\n    \"unsloth[intelgputorch290]\"\n]\nintelgputorch210 = [\n    \"unsloth_zoo[intelgpu]\",\n    \"unsloth[huggingfacenotorch]\",\n\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ 
https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=abb1d1ec1ac672bac0ff35420c965f2df0c636ef9d94e2a830e34578489d0a57 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=71ad2f82da0f41eaec159f39fc85854e27c2391efa91b373e550648a6f4aaad3 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=b473571d478912f92881cc13f15fa18f8463fb0fb8a068c96ed47a7d45a4da0a ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=3bc64a746ff25a93de140902c60c9e819d7413f5cea1e88d80999c27a5901e9c ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=ce50691ab3fb6301d9b7bb8b3834cf5fa7152a2b5f91fd24c5efdc601a25b780 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=cb9d37f21cb9fb7df67d62863f021c3144e8d8832b9ea8e8523ac308bc620ea1 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=3ad605be4728b6d3a28a44d07dd794b1a9e45551b0057815bf25eb2a6d6a56a7 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=2b4b56dd6c792aef82006904fa888692e3782e4ae5da27526801bad4898f05a5 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n\n    \"torchvision @ 
https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=7e1e7b170fcf7161c8499b67156c5a05462243626dc0974010791a0bab4378d3 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=bd6add201bd7628af70437292e1447abb368e0b5f4ff9abd334ae435efd44792 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=6ad2543496bc29e59d3dd614a94d09aa9870318aedb66045344fffddfedd2cf8 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=80269f37865fcd8b57f20e4786efae2200bfa2b2727926c3c7acc82f0e7d3548 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=6b9485ba85dcba4d196d6134d9c3332fb228fb2556416bf0450a64e8a472fcba ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=36cbaedf10f6412af5c89afd9aeea474e6a56a0050348ada8fabe1ecaf6b879e ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=738357d97468d75fe3d510ac37e65130f2787f81d9bbc1518898f7396dc3403f ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n    \"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=1c4b44b36a557f7381e3076fb8843366742238648441d607c8d049c6da0f8886 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nintel-gpu-torch210 = [\n    \"unsloth[intelgputorch210]\"\n]\nintel = [\n    \"unsloth[intelgputorch280]\",\n]\namd = [\n    \"unsloth[huggingfacenotorch]\",\n    \"bitsandbytes>=0.49.1 ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64' or platform_machine == 'aarch64')\",\n    \"bitsandbytes>=0.49.1 ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')\",\n]\nrocm702-torch280 = [\n    \"unsloth[amd]\",\n\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torch @ 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n]\nrocm72-torch291 = [\n    \"unsloth[amd]\",\n\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torch-2.9.1%2Brocmsdk20260116-cp312-cp312-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.12'\",\n\n    \"torchvision @ 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torchvision-0.24.1%2Brocmsdk20260116-cp312-cp312-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.12'\",\n]\nrocm711-torch291 = [\n    \"unsloth[amd]\",\n\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n]\nrocm72-torch2100 = [\n    \"unsloth[amd]\",\n\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n]\nrocm711-torch2100 = [\n    \"unsloth[amd]\",\n\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and 
platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'\",\n    \"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'\",\n]\n\n[project.urls]\nhomepage = \"https://unsloth.ai\"\ndocumentation = \"https://unsloth.ai/docs\"\nrepository = \"https://github.com/unslothai/unsloth\"\n\n[tool.ruff]\ntarget-version = \"py311\"\nforce-exclude = true\nextend-exclude = [\n    \"*chat_templates.py\",\n    \"*ollama_template_mappers.py\",\n    \"*_auto_install.py\",\n    \"*mapper.py\",\n]\n\n[tool.ruff.lint]\nselect = [\"E9\", \"F63\", \"F7\", \"F82\"]\nignore = [\n    \"E402\",\n    \"E722\",\n    \"F403\",\n    \"F405\",\n    \"F811\",\n    \"F821\",\n    \"F841\",\n    \"F401\",\n    \"E731\",\n    \"E741\",\n    \"F601\",\n    \"E712\",\n]\n\n[tool.ruff.format]\n"
  },
  {
    "path": "scripts/enforce_kwargs_spacing.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Ensure keyword arguments use spaces around '=', prune redundant pass statements.\"\"\"\n\nfrom __future__ import annotations\n\nimport ast\nimport argparse\nimport io\nimport sys\nimport tokenize\nfrom collections import defaultdict\nfrom pathlib import Path\n\n\ndef enforce_spacing(text: str) -> tuple[str, bool]:\n    \"\"\"Return updated text with keyword '=' padded by spaces, plus change flag.\"\"\"\n    lines = text.splitlines(keepends=True)\n    if not lines:\n        return text, False\n\n    offsets: dict[int, int] = defaultdict(int)\n    changed = False\n\n    reader = io.StringIO(text).readline\n    for token in tokenize.generate_tokens(reader):\n        if token.type != tokenize.OP or token.string != \"=\":\n            continue\n\n        line_index = token.start[0] - 1\n        col = token.start[1] + offsets[line_index]\n\n        if line_index < 0 or line_index >= len(lines):\n            continue\n\n        line = lines[line_index]\n        if col >= len(line) or line[col] != \"=\":\n            continue\n\n        line_changed = False\n\n        # Insert a space before '=' when missing and not preceded by whitespace.\n        if col > 0 and line[col - 1] not in {\" \", \"\\t\"}:\n            line = f\"{line[:col]} {line[col:]}\"\n            offsets[line_index] += 1\n            col += 1\n            line_changed = True\n            changed = True\n\n        # Insert a space after '=' when missing and not followed by whitespace or newline.\n        next_index = col + 1\n        if next_index < len(line) and line[next_index] not in {\" \", \"\\t\", \"\\n\", \"\\r\"}:\n            line = f\"{line[:next_index]} {line[next_index:]}\"\n            offsets[line_index] += 1\n            line_changed = True\n            changed = True\n\n        if line_changed:\n            lines[line_index] = line\n\n    if not changed:\n        return text, False\n\n    return \"\".join(lines), True\n\n\ndef remove_redundant_passes(text: str) -> tuple[str, bool]:\n    \"\"\"Drop pass statements that share a block with other executable code.\"\"\"\n\n    try:\n        tree = ast.parse(text)\n    except SyntaxError:\n        return text, False\n\n    redundant: list[ast.Pass] = []\n\n    def visit(node: ast.AST) -> None:\n        for attr in (\"body\", \"orelse\", \"finalbody\"):\n            value = getattr(node, attr, None)\n            if not isinstance(value, list) or len(value) <= 1:\n                continue\n            for stmt in value:\n                if isinstance(stmt, ast.Pass):\n                    redundant.append(stmt)\n            for stmt in value:\n                if isinstance(stmt, ast.AST):\n                    visit(stmt)\n        handlers = getattr(node, \"handlers\", None)\n        if handlers:\n            for handler in handlers:\n                visit(handler)\n\n    visit(tree)\n\n    if not redundant:\n        return text, False\n\n    lines = text.splitlines(keepends=True)\n    changed = False\n\n    for node in sorted(\n        redundant, key=lambda item: (item.lineno, item.col_offset), reverse=True\n    ):\n        start = node.lineno - 1\n        end = (node.end_lineno or node.lineno) - 1\n        if start >= len(lines):\n            continue\n        changed = True\n        if start == end:\n            line = lines[start]\n            col_start = node.col_offset\n            col_end = node.end_col_offset or (col_start + 4)\n            segment = line[:col_start] + line[col_end:]\n            lines[start] = segment if 
segment.strip() else \"\"\n            continue\n\n        # Defensive fall-back for unexpected multi-line 'pass'.\n        prefix = lines[start][: node.col_offset]\n        lines[start] = prefix if prefix.strip() else \"\"\n        for idx in range(start + 1, end):\n            lines[idx] = \"\"\n        suffix = lines[end][(node.end_col_offset or 0) :]\n        lines[end] = suffix\n\n    # Normalise to ensure lines end with newlines except at EOF.\n    result_lines: list[str] = []\n    for index, line in enumerate(lines):\n        if not line:\n            continue\n        if index < len(lines) - 1 and not line.endswith(\"\\n\"):\n            result_lines.append(f\"{line}\\n\")\n        else:\n            result_lines.append(line)\n\n    return \"\".join(result_lines), changed\n\n\ndef process_file(path: Path) -> bool:\n    try:\n        with tokenize.open(path) as handle:\n            original = handle.read()\n            encoding = handle.encoding\n    except (OSError, SyntaxError) as exc:  # SyntaxError from tokenize on invalid python\n        print(f\"Failed to read {path}: {exc}\", file=sys.stderr)\n        return False\n\n    updated, changed = enforce_spacing(original)\n    updated, removed = remove_redundant_passes(updated)\n    if changed or removed:\n        path.write_text(updated, encoding=encoding)\n        return True\n    return False\n\n\ndef main(argv: list[str]) -> int:\n    parser = argparse.ArgumentParser(description=__doc__)\n    parser.add_argument(\"files\", nargs=\"+\", help=\"Python files to fix\")\n    args = parser.parse_args(argv)\n\n    touched: list[Path] = []\n    self_path = Path(__file__).resolve()\n\n    for entry in args.files:\n        path = Path(entry)\n        # Skip modifying this script to avoid self-edit loops.\n        if path.resolve() == self_path:\n            continue\n        if not path.exists() or path.is_dir():\n            continue\n        if process_file(path):\n            touched.append(path)\n\n    if touched:\n        for path in touched:\n            print(f\"Adjusted kwarg spacing in {path}\")\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main(sys.argv[1:]))\n"
  },
  {
    "path": "scripts/run_ruff_format.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Run `ruff format` followed by kwarg spacing enforcement.\"\"\"\n\nfrom __future__ import annotations\n\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nHERE = Path(__file__).resolve().parent\n\n\ndef main(argv: list[str]) -> int:\n    files = [arg for arg in argv if Path(arg).exists()]\n    if not files:\n        return 0\n\n    ruff_cmd = [sys.executable, \"-m\", \"ruff\", \"format\", *files]\n    ruff_proc = subprocess.run(ruff_cmd)\n    if ruff_proc.returncode != 0:\n        return ruff_proc.returncode\n\n    spacing_script = HERE / \"enforce_kwargs_spacing.py\"\n    spacing_cmd = [sys.executable, str(spacing_script), *files]\n    spacing_proc = subprocess.run(spacing_cmd)\n    return spacing_proc.returncode\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main(sys.argv[1:]))\n"
  },
  {
    "path": "studio/LICENSE.AGPL-3.0",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  
\"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. 
Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  
This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  
But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  
If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  
If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  
\"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  
This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published by\n    the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>.\n"
  },
  {
    "path": "studio/Unsloth_Studio_Colab.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6b87de59\",\n   \"metadata\": {},\n   \"source\": [\n    \"To run this, press \\\"*Runtime*\\\" and press \\\"*Run all*\\\" on a **free** Tesla T4 Google Colab instance!\\n\",\n    \"<div class=\\\"align-center\\\">\\n\",\n    \"<a href=\\\"https://unsloth.ai/\\\"><img src=\\\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\\\" width=\\\"115\\\"></a>\\n\",\n    \"<a href=\\\"https://discord.gg/unsloth\\\"><img src=\\\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\\\" width=\\\"145\\\"></a>\\n\",\n    \"<a href=\\\"https://unsloth.ai/docs/\\\"><img src=\\\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\\\" width=\\\"125\\\"></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\\\"https://github.com/unslothai/unsloth\\\">Github</a> </i> ⭐\\n\",\n    \"</div>\\n\",\n    \"\\n\",\n    \"To install Unsloth Studio on your local device, follow [our guide](https://unsloth.ai/docs/new/unsloth-studio/install). Unsloth Studio is licensed [AGPL-3.0](https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0).\\n\",\n    \"\\n\",\n    \"### Unsloth Studio\\n\",\n     \"\\n\",\n        \"Train and run open models with [**Unsloth Studio**](https://unsloth.ai/docs/new/unsloth-studio/start). Currently, installation may take 30+ mins so use a newer GPU.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"We are actively working on making Unsloth Studio install on Colab T4 GPUs faster.\\n\",\n        \"\\n\",\n        \"[Features](https://unsloth.ai/docs/new/unsloth-studio#features) • [Quickstart](https://unsloth.ai/docs/new/unsloth-studio/start) • [Data Recipes](https://unsloth.ai/docs/new/unsloth-studio/data-recipe) • [Studio Chat](https://unsloth.ai/docs/new/unsloth-studio/chat) • [Export](https://unsloth.ai/docs/new/unsloth-studio/export)\"\n      ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e4206349\",\n   \"metadata\": {},\n   \"source\": [\n    \"<p align=\\\"left\\\"><img src=\\\"https://github.com/unslothai/unsloth/raw/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png\\\" width=\\\"600\\\"></p>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"27da2957\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Setup: Clone repo and run setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"27e68f91\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\\n\",\n    \"%cd /content/unsloth\\n\",\n    \"\\n\",\n    \"# Run setup script\\n\",\n    \"!chmod +x studio/setup.sh\\n\",\n    \"!./studio/setup.sh\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3e1771a9\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Start Unsloth Studio\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"277e431e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.insert(0, '/content/unsloth/studio/backend')\\n\",\n    \"\\n\",\n    \"from colab import start\\n\",\n    \"start()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f2b0c6a1\",\n   \"metadata\": {},\n   \"source\": [\n    \"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! 
If you find any bugs, want to keep up to date with the latest LLM developments, need help, or want to join projects, feel free to join our Discord!\\n\",\n    \"\\n\",\n    \"Some other resources:\\n\",\n    \"1. Looking to use Unsloth locally? Read our [Installation Guide](https://unsloth.ai/docs/get-started/install) for details on installing Unsloth on Windows, Docker, AMD, Intel GPUs.\\n\",\n    \"2. Learn how to do Reinforcement Learning with our [RL Guide and notebooks](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide).\\n\",\n    \"3. Read our guides and notebooks for [Text-to-speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning) and [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) model support.\\n\",\n    \"4. Explore our [LLM Tutorials Directory](https://unsloth.ai/docs/models/tutorials-how-to-fine-tune-and-run-llms) to find dedicated guides for each model.\\n\",\n    \"5. Need help with Inference? Read our [Inference & Deployment page](https://unsloth.ai/docs/basics/inference-and-deployment) for details on using vLLM, llama.cpp, Ollama etc.\\n\",\n    \"\\n\",\n    \"<div class=\\\"align-center\\\">\\n\",\n    \"  <a href=\\\"https://unsloth.ai\\\"><img src=\\\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\\\" width=\\\"115\\\"></a>\\n\",\n    \"  <a href=\\\"https://discord.gg/unsloth\\\"><img src=\\\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\\\" width=\\\"145\\\"></a>\\n\",\n    \"  <a href=\\\"https://unsloth.ai/docs/\\\"><img src=\\\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\\\" width=\\\"125\\\"></a>\\n\",\n    \"\\n\",\n    \"  Join Discord if you need help + ⭐️ <i>Star us on <a href=\\\"https://github.com/unslothai/unsloth\\\">Github</a> </i> ⭐️\\n\",\n    \"\\n\",\n    \"  <b>This notebook is licensed <a href=\\\"https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0\\\">AGPL-3.0</a></b>\\n\",\n    \"</div>\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuType\": \"T4\",\n   \"include_colab_link\": true,\n   \"provenance\": []\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"name\": \"python\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "studio/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/assets/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/assets/configs/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/assets/configs/full_finetune.yaml",
    "content": "model: unsloth/Qwen2.5-0.5B\n\ndata:\n  dataset: tatsu-lab/alpaca\n  format_type: auto\n\ntraining:\n  training_type: full\n  max_seq_length: 2048\n  load_in_4bit: false\n  output_dir: outputs\n  num_epochs: 1\n  learning_rate: 0.0002\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 0\n  save_steps: 0\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: \"unsloth\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules: \"\"\n  vision_all_linear: false\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: unsloth-training\n  enable_tensorboard: false\n  tensorboard_dir: runs\n"
  },
  {
    "path": "studio/backend/assets/configs/inference_defaults.json",
    "content": "{\n  \"_comment\": \"Per-model-family inference parameter defaults. Sources: (1) Ollama params blobs, (2) Existing Unsloth Studio YAML configs. Patterns ordered longest-match-first.\",\n  \"families\": {\n    \"qwen3.5\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0,\n      \"presence_penalty\": 1.5\n    },\n    \"qwen3-coder\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen3-next\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen3-vl\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen3\": {\n      \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2.5-coder\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2.5-vl\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2.5-omni\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2.5-math\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2.5\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2-vl\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwen2\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.8,\n      \"top_k\": 20,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"qwq\": {\n      \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": 40,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"gemma-3n\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 64,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"gemma-3\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 64,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"medgemma\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 64,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"gemma-2\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 64,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"llama-4\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.9,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"llama-3.3\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"llama-3.2\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"llama-3.1\": {\n      \"temperature\": 1.5,\n      
\"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"llama-3\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"phi-4\": {\n      \"temperature\": 0.8,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.0,\n      \"repetition_penalty\": 1.0\n    },\n    \"phi-3\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.9,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"mistral-nemo\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"mistral-small\": {\n      \"temperature\": 0.15,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"mistral-large\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"magistral\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"ministral\": {\n      \"temperature\": 0.15,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"devstral\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"pixtral\": {\n      \"temperature\": 1.5,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.1,\n      \"repetition_penalty\": 1.0\n    },\n    \"deepseek-r1\": {\n      \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"deepseek-v3\": {\n      \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"deepseek-ocr\": {\n      \"temperature\": 0.0,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"glm-5\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"glm-4\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"nemotron\": {\n      \"temperature\": 1.0,\n      \"top_p\": 1.0,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"minimax-m2.5\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 40,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"minimax\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": 40,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"gpt-oss\": {\n      \"temperature\": 1.0,\n      \"top_p\": 1.0,\n      \"top_k\": 0,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"granite-4\": {\n      \"temperature\": 0.0,\n      \"top_p\": 1.0,\n      \"top_k\": 0,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"kimi-k2\": {\n      \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"kimi\": {\n  
    \"temperature\": 0.6,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"lfm2\": {\n      \"temperature\": 0.1,\n      \"top_p\": 0.1,\n      \"top_k\": 50,\n      \"min_p\": 0.15,\n      \"repetition_penalty\": 1.05\n    },\n    \"smollm\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"olmo\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"falcon\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"ernie\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"seed\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"grok\": {\n      \"temperature\": 1.0,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    },\n    \"mimo\": {\n      \"temperature\": 0.7,\n      \"top_p\": 0.95,\n      \"top_k\": -1,\n      \"min_p\": 0.01,\n      \"repetition_penalty\": 1.0\n    }\n  },\n  \"patterns\": [\n    \"qwen3.5\",\n    \"qwen3-coder\", \"qwen3-next\", \"qwen3-vl\", \"qwen3\",\n    \"qwen2.5-coder\", \"qwen2.5-vl\", \"qwen2.5-omni\", \"qwen2.5-math\", \"qwen2.5\",\n    \"qwen2-vl\", \"qwen2\",\n    \"qwq\",\n    \"gemma-3n\", \"gemma-3\", \"medgemma\", \"gemma-2\",\n    \"llama-4\", \"llama-3.3\", \"llama-3.2\", \"llama-3.1\", \"llama-3\",\n    \"phi-4\", \"phi-3\",\n    \"mistral-nemo\", \"mistral-small\", \"mistral-large\", \"magistral\", \"ministral\",\n    \"devstral\", \"pixtral\",\n    \"deepseek-r1\", \"deepseek-v3\", \"deepseek-ocr\",\n    \"glm-5\", \"glm-4\",\n    \"nemotron\",\n    \"minimax-m2.5\", \"minimax\",\n    \"gpt-oss\", \"granite-4\",\n    \"kimi-k2\", \"kimi\",\n    \"lfm2\", \"smollm\", \"olmo\", \"falcon\", \"ernie\", \"seed\", \"grok\", \"mimo\"\n  ]\n}\n"
  },
  {
    "path": "studio/backend/assets/configs/lora_text.yaml",
    "content": "model: unsloth/Qwen2.5-0.5B\n\ndata:\n  dataset: tatsu-lab/alpaca\n  format_type: auto\n\ntraining:\n  training_type: lora\n  max_seq_length: 2048\n  load_in_4bit: true\n  output_dir: outputs\n  num_epochs: 1\n  learning_rate: 0.0002\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 0\n  save_steps: 0\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: \"unsloth\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules: \"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj\"\n  vision_all_linear: false\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: unsloth-training\n  enable_tensorboard: false\n  tensorboard_dir: runs\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/default.yaml",
    "content": "# Default model training parameters\n# Used for models without specific configurations\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 5e-5\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_ratio: 0.1\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.7\n  top_p: 0.95\n  top_k: -1\n  min_p: 0.01\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/embedding/unsloth_Qwen3-Embedding-0.6B.yaml",
    "content": "# Model defaults for unsloth/Qwen3-Embedding-0.6B\n# Based on Qwen3_Embedding_(0_6B).py embedding notebook\n# Also applies to: unsloth/Qwen3-Embedding-4B\n\ntraining:\n  max_seq_length: 512\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 3e-5\n  batch_size: 256\n  gradient_accumulation_steps: 1\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: false\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"constant_with_warmup\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"embedding-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 50\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/embedding/unsloth_all-MiniLM-L6-v2.yaml",
    "content": "# Model defaults for unsloth/all-MiniLM-L6-v2\n# Based on All_MiniLM_L6_v2.py embedding notebook\n\ntraining:\n  max_seq_length: 512\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 256\n  gradient_accumulation_steps: 1\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: false\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"value\"\n    - \"key\"\n    - \"dense\"\n    - \"query\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"embedding-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 50\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/embedding/unsloth_bge-m3.yaml",
    "content": "# Model defaults for unsloth/bge-m3\n# Based on BGE_M3.py embedding notebook\n\ntraining:\n  max_seq_length: 512\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 3e-5\n  batch_size: 256\n  gradient_accumulation_steps: 1\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: false\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"constant_with_warmup\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"key\"\n    - \"query\"\n    - \"dense\"\n    - \"value\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"embedding-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 50\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/embedding/unsloth_embeddinggemma-300m.yaml",
    "content": "# Model defaults for unsloth/embeddinggemma-300m\n# Based on EmbeddingGemma_(300M).py embedding notebook\n\ntraining:\n  max_seq_length: 1024\n  # num_epochs: 1\n  num_epochs: 0\n  learning_rate: 2e-5\n  batch_size: 64\n  gradient_accumulation_steps: 2\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"embedding-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 5\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/embedding/unsloth_gte-modernbert-base.yaml",
    "content": "# Model defaults for unsloth/gte-modernbert-base\n# Based on ModernBert.py embedding notebook\n\ntraining:\n  max_seq_length: 512\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 3e-5\n  batch_size: 256\n  gradient_accumulation_steps: 1\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"constant_with_warmup\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"Wi\"\n    - \"Wo\"\n    - \"Wqkv\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"embedding-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 50\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/ernie/unsloth_ERNIE-4.5-21B-A3B-PT.yaml",
    "content": "# Model defaults for unsloth/ERNIE-4.5-21B-A3B-PT\n# Based on ERNIE_4_5_21B_A3B_PT-Conversational.ipynb\n# Also applies to: unsloth/ERNIE-4.5-21B-A3B-PT\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/ernie/unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml",
    "content": "# Model defaults for unsloth/ERNIE-4.5-VL-28B-A3B-PT\n# Based on ERNIE_4_5_VL_28B_A3B_PT_Vision.ipynb\n# Also applies to: unsloth/ERNIE-4.5-VL-28B-A3B-PT\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: true\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: true\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/falcon/tiiuae_Falcon-H1-0.5B-Instruct.yaml",
    "content": "# Model defaults for tiiuae/Falcon-H1-0.5B-Instruct\n# Based on Falcon_H1_(0.5B)-Alpaca.ipynb\n# Also applies to: tiiuae/Falcon-H1-0.5B-Instruct, unsloth/Falcon-H1-0.5B-Instruct\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 8\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: false\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.1\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_codegemma-7b-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/codegemma-7b-bnb-4bit\n# Based on CodeGemma_(7B)-Conversational.ipynb\n# Also applies to: unsloth/codegemma-7b, google/codegemma-7b\n# added inference parameters from Ollama\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 4096\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0\n  top_p: 0.9\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_functiongemma-270m-it.yaml",
    "content": "# Model defaults for unsloth/functiongemma-270m-it\n# Based on FunctionGemma_(270M).ipynb\n# Also applies to: unsloth/functiongemma-270m-it-unsloth-bnb-4bit, google/functiongemma-270m-it, unsloth/functiongemma-270m-it-unsloth-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 4096\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 10\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 128\n  lora_alpha: 256\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-2-27b-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/gemma-2-27b-bnb-4bit\n# Based on Gemma2_(9B)-Alpaca.ipynb (same defaults for larger models)\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-2-2b.yaml",
    "content": "# Model defaults for unsloth/gemma-2-2b\n# Based on Gemma2_(2B)-Alpaca.ipynb\n# Also applies to: unsloth/gemma-2-2b-bnb-4bit, google/gemma-2-2b\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-270m-it.yaml",
    "content": "# Model defaults for unsloth/gemma-3-270m-it\n# Based on Gemma3_(270M).ipynb\n# Also applies to: unsloth/gemma-3-270m-it-unsloth-bnb-4bit, google/gemma-3-270m-it, unsloth/gemma-3-270m-it-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 5e-5\n  batch_size: 4\n  gradient_accumulation_steps: 1\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 128\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-27b-it.yaml",
    "content": "# Model defaults for unsloth/gemma-3-27b-it\n# Based on Gemma3_(27B)_A100-Conversational.ipynb\n# Also applies to: unsloth/gemma-3-27b-it-unsloth-bnb-4bit, google/gemma-3-27b-it, unsloth/gemma-3-27b-it-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 8\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-4b-it.yaml",
    "content": "# Model defaults for unsloth/gemma-3-4b-it\n# Based on Gemma3_(4B).ipynb\n# Also applies to: unsloth/gemma-3-4b-it-unsloth-bnb-4bit, google/gemma-3-4b-it, unsloth/gemma-3-4b-it-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 8\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-4b-pt.yaml",
    "content": "# Model defaults for unsloth/gemma-3-4b-pt\n# Based on Gemma3_(4B)-Vision.ipynb\n# Also applies to: unsloth/gemma-3-4b-pt-unsloth-bnb-4bit, google/gemma-3-4b-pt, unsloth/gemma-3-4b-pt-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: true\n  optim: \"adamw_torch_fused\"\n  lr_scheduler_type: \"cosine\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3n-E4B-it.yaml",
    "content": "# Model defaults for unsloth/gemma-3n-E4B-it\n# Based on Gemma3N_(4B)-Conversational.ipynb\n# Also applies to: unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit, google/gemma-3n-E4B-it, unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 1024\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 8\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\naudio_input: true\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3n-E4B.yaml",
    "content": "# Model defaults for unsloth/gemma-3n-E4B\n# Based on Gemma3N_(4B)-Vision.ipynb\n# Also applies to: unsloth/gemma-3n-E4B-unsloth-bnb-4bit, google/gemma-3n-E4B\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 2\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_ratio: 0.03\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: true\n  optim: \"adamw_torch_fused\"\n  lr_scheduler_type: \"cosine\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\naudio_input: true\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_k: 64\n  top_p: 0.95\n  min_p: 0.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gpt-oss/unsloth_gpt-oss-120b.yaml",
    "content": "# Model defaults for unsloth/gpt-oss-120b\n# Based on gpt-oss-(120B)_A100-Fine-tuning.ipynb\n# Also applies to: openai/gpt-oss-120b, unsloth/gpt-oss-120b-unsloth-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 4096\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 1\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_p: 1.0\n  top_k: 0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/gpt-oss/unsloth_gpt-oss-20b.yaml",
    "content": "# Model defaults for unsloth/gpt-oss-20b\n# Based on gpt-oss-(20B)-Fine-tuning.ipynb\n# Also applies to: openai/gpt-oss-20b, unsloth/gpt-oss-20b-unsloth-bnb-4bit, unsloth/gpt-oss-20b-BF16\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 1024\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.0\n  top_p: 1.0\n  top_k: 0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/granite/unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/granite-4.0-350m\n# Based on Granite4.0_350M.ipynb\n# Also applies to: ibm-granite/granite-4.0-350m, unsloth/granite-4.0-350m-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"shared_mlp.input_linear\"\n    - \"shared_mlp.output_linear\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.0\n  top_p: 1.0\n  top_k: 0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/granite/unsloth_granite-4.0-h-micro.yaml",
    "content": "# Model defaults for unsloth/granite-4.0-h-micro\n# Based on Granite4.0.ipynb\n# Also applies to: ibm-granite/granite-4.0-h-micro, unsloth/granite-4.0-h-micro-bnb-4bit, unsloth/granite-4.0-h-micro-unsloth-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"shared_mlp.input_linear\"\n    - \"shared_mlp.output_linear\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.0\n  top_p: 1.0\n  top_k: 0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-11B-Vision-Instruct.yaml",
    "content": "# Model defaults for unsloth/Llama-3.2-11B-Vision-Instruct\n# Based on Llama3.2_(11B)-Vision.ipynb\n# Also applies to: unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-11B-Vision-Instruct, unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-1B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Llama-3.2-1B-Instruct\n# Based on Llama3.2_(1B)-RAFT.ipynb\n# Also applies to: unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-1B-Instruct, unsloth/Llama-3.2-1B-Instruct-bnb-4bit, RedHatAI/Llama-3.2-1B-Instruct-FP8, unsloth/Llama-3.2-1B-Instruct-FP8-Block, unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 5\n  num_epochs: 0\n  learning_rate: 2e-5\n  batch_size: 1\n  gradient_accumulation_steps: 8\n  warmup_steps: 0\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: true\n  optim: \"adamw_torch\"\n  lr_scheduler_type: \"cosine\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-3B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Llama-3.2-3B-Instruct\n# Based on Llama3.2_(1B_and_3B)-Conversational.ipynb\n# Also applies to: unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-3B-Instruct, unsloth/Llama-3.2-3B-Instruct-bnb-4bit, RedHatAI/Llama-3.2-3B-Instruct-FP8, unsloth/Llama-3.2-3B-Instruct-FP8-Block, unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.3-70B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Llama-3.3-70B-Instruct\n# Based on Llama3.3_(70B)_A100-Conversational.ipynb\n# Also applies to: unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.3-70B-Instruct, unsloth/Llama-3.3-70B-Instruct-bnb-4bit, RedHatAI/Llama-3.3-70B-Instruct-FP8, unsloth/Llama-3.3-70B-Instruct-FP8-Block, unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Meta-Llama-3.1-70B-bnb-4bit\n# Based on Llama3.1_(8B)-Alpaca.ipynb\n# Also applies to: unsloth/Meta-Llama-3.1-8B-bnb-4bit, unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit, meta-llama/Meta-Llama-3.1-8B, unsloth/Meta-Llama-3.1-8B, unsloth/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B, unsloth/Meta-Llama-3.1-405B-bnb-4bit, meta-llama/Meta-Llama-3.1-405B\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\n# Based on Llama3.1_(8B)-Inference.ipynb\n# Also applies to: \"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit\", \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"unsloth/Meta-Llama-3.1-8B-Instruct\",\"RedHatAI/Llama-3.1-8B-Instruct-FP8\",\"unsloth/Llama-3.1-8B-Instruct-FP8-Block\",\"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic\"\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 8192\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_llama-3-8b-Instruct-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/llama-3-8b-Instruct-bnb-4bit\n# Based on Llama3_(8B)-Conversational.ipynb\n# Also applies to: unsloth/llama-3-8b-Instruct, meta-llama/Meta-Llama-3-8B-Instruct\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llama/unsloth_llama-3-8b-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/llama-3-8b-bnb-4bit\n# Based on Llama3_(8B)-Alpaca.ipynb\n# Also applies to: unsloth/llama-3-8b, meta-llama/Meta-Llama-3-8B\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/llasa/unsloth_Llasa-3B.yaml",
    "content": "# Model defaults for unsloth/Llasa-3B\n# Based on Llasa_TTS_(3B).ipynb and Llasa_TTS_(1B).ipynb\n# Also applies to: HKUSTAudio/Llasa-1B\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 5e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 128\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"v_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.2\n  top_p: 1.2\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Magistral-Small-2509\n# Based on Magistral_(24B)-Reasoning-Conversational.ipynb\n# Also applies to: mistralai/Magistral-Small-2509, unsloth/Magistral-Small-2509-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.7\n  min_p: 0.01\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_Ministral-3-3B-Instruct-2512.yaml",
    "content": "# Model defaults for unsloth/Ministral-3-3B-Instruct-2512\n# Based on Ministral_3_VL_(3B)_Vision.ipynb\n# Also applies to: unsloth/Ministral-3-3B-Instruct-2512\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.15\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Mistral-Nemo-Base-2407-bnb-4bit\n# Based on Mistral_Nemo_(12B)-Alpaca.ipynb\n# Also applies to:  \"unsloth/Mistral-Nemo-Base-2407\",  \"mistralai/Mistral-Nemo-Base-2407\", \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\", \"unsloth/Mistral-Nemo-Instruct-2407\", \"mistralai/Mistral-Nemo-Instruct-2407\",\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_Mistral-Small-Instruct-2409.yaml",
    "content": "# Model defaults for unsloth/Mistral-Small-Instruct-2409\n# Based on Mistral_Small_(22B)-Alpaca.ipynb \n# Also applies to: unsloth/Mistral-Small-Instruct-2409-bnb-4bit, mistralai/Mistral-Small-Instruct-2409\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_Pixtral-12B-2409.yaml",
    "content": "# Model defaults for unsloth/Pixtral-12B-2409\n# Based on Pixtral_(12B)-Vision.ipynb\n# Also applies to: unsloth/Pixtral-12B-2409-unsloth-bnb-4bit, mistralai/Pixtral-12B-2409, unsloth/Pixtral-12B-2409-bnb-4bit\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"paged_adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 8\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: false\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/mistral-7b-instruct-v0.3-bnb-4bit\n# Based on Mistral_v0.3_(7B)-Conversational.ipynb\n# Also applies to: unsloth/mistral-7b-instruct-v0.3, mistralai/Mistral-7B-Instruct-v0.3\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/mistral/unsloth_mistral-7b-v0.3-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/mistral-7b-v0.3-bnb-4bit\n# Based on Mistral_v0.3_(7B)-Alpaca.ipynb\n# Also applies to: \"unsloth/mistral-7b-v0.3\", \"mistralai/Mistral-7B-v0.3\",\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/OuteAI_Llama-OuteTTS-1.0-1B.yaml",
    "content": "# Model defaults for OuteAI/Llama-OuteTTS-1.0-1B\n# Based on Oute_TTS_(1B).ipynb\n# Also applies to: OuteAI/Llama-OuteTTS-1.0-1B\n# added inference parameters from unsloth notebook\n\naudio_type: dac\n\ntraining:\n  trust_remote_code: false\n  eval_steps: 0\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 128\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"v_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.4\n  top_k: 40\n  top_p: 0.9\n  min_p: 0.05\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/Spark-TTS-0.5B_LLM.yaml",
    "content": "# Model defaults for Spark-TTS-0.5B/LLM\n# Based on Spark_TTS_(0_5B).ipynb\n# Also applies to: Spark-TTS-0.5B/LLM\n# added inference parameters from unsloth notebook\n\naudio_type: bicodec\n\ntraining:\n  trust_remote_code: false\n  eval_steps: 0\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 128\n  lora_alpha: 128\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.8\n  top_k: 50\n  top_p: 1.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/sesame_csm-1b.yaml",
    "content": "# Model defaults for sesame/csm-1b\n# Based on Sesame_CSM_(1B)-TTS.ipynb\n# Also applies to: sesame/csm-1b\n\naudio_type: csm\n\ntraining:\n  trust_remote_code: false\n  eval_steps: 0\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_GLM-4.7-Flash.yaml",
    "content": "# Model defaults for unsloth/GLM-4.7-Flash\n# Based on GLM_Flash_A100(80GB).py\n# Also applies to: unsloth/GLM-4.7-Flash-unsloth-bnb-4bit, unsloth/GLM-4.7-Flash-bnb-4bit, THUDM/GLM-4.7-Flash\n\ntraining:\n  trust_remote_code: true\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 60\n  save_steps: 60\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"out_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: true\n  temperature: 0.7\n  top_p: 0.8\n  top_k: 20\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_LFM2-1.2B.yaml",
    "content": "# Model defaults for unsloth/LFM2-1.2B\n# Based on Liquid_LFM2_(1.2B)-Conversational.ipynb\n# Also applies to: unsloth/LFM2-1.2B\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.3\n  min_p: 0.15\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_Nemotron-3-Nano-30B-A3B.yaml",
    "content": "# Model defaults for unsloth/Nemotron-3-Nano-30B-A3B\n# Based on Nemotron-3-Nano-30B-A3B_A100.ipynb\n# Also applies to: unsloth/Nemotron-3-Nano-30B-A3B\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: true\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 8\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"in_proj\"\n    - \"out_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: true\n  temperature: 1.0\n  top_p: 1.0\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml",
    "content": "# Model defaults for unsloth/PaddleOCR-VL\n# Based on Paddle_OCR_(1B)_Vision.ipynb\n# Also applies to: unsloth/PaddleOCR-VL\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: true\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 5e-5\n  batch_size: 4\n  gradient_accumulation_steps: 2\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: true\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_answerdotai_ModernBERT-large.yaml",
    "content": "# Model defaults for answerdotai/ModernBERT-large\n# Based on bert_classification.ipynb\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 1\n  num_epochs: 0\n  learning_rate: 5e-5\n  batch_size: 32\n  gradient_accumulation_steps: 1\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_orpheus-3b-0.1-ft.yaml",
    "content": "# Model defaults for unsloth/orpheus-3b-0.1-ft\n# Based on Orpheus_(3B)-TTS.ipynb\n# Also applies to: unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit, canopylabs/orpheus-3b-0.1-ft, unsloth/orpheus-3b-0.1-ft-bnb-4bit\n# added inference parameters from unsloth notebook\n\naudio_type: snac\n\ntraining:\n  trust_remote_code: false\n  eval_steps: 0\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_tinyllama-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/tinyllama\n# Based on TinyLlama_(1.1B)-Alpaca.ipynb\n# Also applies to: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 4096\n  # num_epochs: 1\n  num_epochs: 0\n  learning_rate: 2e-5\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_ratio: 0.1\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.1\n  random_seed: 3407\n  packing: true\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/other/unsloth_whisper-large-v3.yaml",
    "content": "# Model defaults for unsloth/whisper-large-v3\n# Based on Whisper.ipynb\n# Also applies to: unsloth/whisper-large-v3, openai/whisper-large-v3\n\naudio_type: whisper\naudio_input: true\n\ntraining:\n  trust_remote_code: false\n  eval_steps: 5\n  max_seq_length: 448\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 1e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"v_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-3-medium-4k-instruct.yaml",
    "content": "# Model defaults for unsloth/Phi-3-medium-4k-instruct\n# Based on Phi_3_Medium-Conversational.ipynb\n# Also applies to: \"unsloth/Phi-3-medium-4k-instruct-bnb-4bit\", \"microsoft/Phi-3-medium-4k-instruct\",\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-3.5-mini-instruct.yaml",
    "content": "# Model defaults for unsloth/Phi-3.5-mini-instruct\n# Based on Phi_3.5_Mini-Conversational.ipynb\n# Also applies to: \"unsloth/Phi-3.5-mini-instruct-bnb-4bit\", \"microsoft/Phi-3.5-mini-instruct\"\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-4.yaml",
    "content": "# Model defaults for unsloth/Phi-4\n# Based on Phi_4-Conversational.ipynb\n# Also applies to: unsloth/phi-4-unsloth-bnb-4bit, microsoft/phi-4, unsloth/phi-4-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.8\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/imdatta0_tiny_qwen3_moe_2.8B_0.7B.yaml",
    "content": "# Model defaults for imdatta0/tiny_qwen3_moe_2.8B_0.7B\n# Based on TinyQwen3_MoE.py\n# Dummy model of qwen3moe architecture created to fit in T4\n# MoE model - includes gate_up_proj for MoE layers\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 1\n  warmup_steps: 5\n  max_steps: 50\n  save_steps: 50\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"gate_up_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2-7B.yaml",
    "content": "# Model defaults for unsloth/Qwen2-7B\n# Based on Qwen2_(7B)-Alpaca.ipynb\n# Also applies to: unsloth/Qwen2-7B-bnb-4bit, Qwen/Qwen2-7B\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2-VL-7B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Qwen2-VL-7B-Instruct\n# Based on Qwen2_VL_(7B)-Vision.ipynb \n# Also applies to: unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit, Qwen/Qwen2-VL-7B-Instruct, unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-1.5B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-1.5B-Instruct\n# Based on nemo_gym_sudoku.ipynb\n# Also applies to: unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit, Qwen/Qwen2.5-1.5B-Instruct, unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 4096\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 1e-5\n  batch_size: 1\n  gradient_accumulation_steps: 64\n  warmup_ratio: 0.1\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 42\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 4\n  lora_alpha: 8\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-7B.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-7B\n# Based on Qwen2.5_(7B)-Alpaca.ipynb\n# Also applies to: unsloth/Qwen2.5-7B-unsloth-bnb-4bit, Qwen/Qwen2.5-7B, unsloth/Qwen2.5-7B-bnb-4bit\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-Coder-1.5B-Instruct\n# Based on Qwen2.5_Coder_(1.5B)-Tool_Calling.ipynb\n# Also applies to: unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit, Qwen/Qwen2.5-Coder-1.5B-Instruct\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-14B-Instruct.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-Coder-14B-Instruct\n# Based on Qwen2.5_Coder_(14B)-Conversational.ipynb\n# Also applies to: unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit, Qwen/Qwen2.5-Coder-14B-Instruct\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"paged_adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit\n# Based on CodeForces-cot-Finetune_for_Reasoning_on_CodeForces.ipynb\n# Also applies to: unsloth/Qwen2.5-Coder-7B-Instruct, Qwen/Qwen2.5-Coder-7B-Instruct\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 32768\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit\n# Based on Qwen2.5_VL_(7B)-Vision.ipynb\n# Also applies to: unsloth/Qwen2.5-VL-7B-Instruct, Qwen/Qwen2.5-VL-7B-Instruct, unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit\n# added inference parameters from unsloth notebook\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 1.5\n  min_p: 0.1\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-0.6B.yaml",
    "content": "# Model defaults for unsloth/Qwen3-0.6B\n# Based on Qwen3_(0_6B)-Phone_Deployment.ipynb\n# Also applies to: unsloth/Qwen3-0.6B-unsloth-bnb-4bit, Qwen/Qwen3-0.6B, unsloth/Qwen3-0.6B-bnb-4bit, Qwen/Qwen3-0.6B-FP8, unsloth/Qwen3-0.6B-FP8\n# added inference parameters from Ollama\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 1024\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 5e-5\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Qwen3-14B-Base\n# Based on Qwen3_(14B)-Alpaca.ipynb\n# Also applies to: unsloth/Qwen3-14B-Base, Qwen/Qwen3-14B-Base, unsloth/Qwen3-14B-Base-bnb-4bit\n# added inference parameters from Ollama\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-14B.yaml",
    "content": "# Model defaults for unsloth/Qwen3-14B\n# Based on Qwen3_(14B).ipynb\n# Also applies to: unsloth/Qwen3-14B-unsloth-bnb-4bit, Qwen/Qwen3-14B, unsloth/Qwen3-14B-bnb-4bit, Qwen/Qwen3-14B-FP8, unsloth/Qwen3-14B-FP8\n# added inference parameters from Ollama\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-30B-A3B-Instruct-2507.yaml",
    "content": "# Model defaults for unsloth/Qwen3-30B-A3B-Instruct-2507\n# Based on Qwen3_MoE.py\n# Also applies to: Qwen/Qwen3-30B-A3B-Instruct-2507, unsloth/Qwen3-30B-A3B-Instruct-2507-bnb-4bit\n# MoE model - includes gate_up_proj for MoE layers\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 1\n  gradient_accumulation_steps: 1\n  warmup_steps: 5\n  max_steps: 50\n  save_steps: 50\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 64\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n    - \"gate_up_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-32B.yaml",
    "content": "# Model defaults for unsloth/Qwen3-32B\n# Based on Qwen3_(32B)_A100-Reasoning-Conversational.ipynb\n# Also applies to: unsloth/Qwen3-32B-unsloth-bnb-4bit, Qwen/Qwen3-32B, unsloth/Qwen3-32B-bnb-4bit, Qwen/Qwen3-32B-FP8, unsloth/Qwen3-32B-FP8\n# added inference parameters from Ollama\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_k: 20\n  top_p: 0.95\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-4B-Instruct-2507.yaml",
    "content": "# Model defaults for unsloth/Qwen3-4B-Instruct-2507\n# Based on Qwen3_(4B)-Instruct.ipynb\n# Also applies to: unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit, Qwen/Qwen3-4B-Instruct-2507, unsloth/Qwen3-4B-Instruct-2507-bnb-4bit, Qwen/Qwen3-4B-Instruct-2507-FP8, unsloth/Qwen3-4B-Instruct-2507-FP8\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.7\n  top_p: 0.80\n  top_k: 20\n  min_p: 0.00\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-4B-Thinking-2507.yaml",
    "content": "# Model defaults for unsloth/Qwen3-4B-Thinking-2507\n# Based on Qwen3_(4B)-Thinking.ipynb\n# Also applies to: unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit, Qwen/Qwen3-4B-Thinking-2507, unsloth/Qwen3-4B-Thinking-2507-bnb-4bit, Qwen/Qwen3-4B-Thinking-2507-FP8, unsloth/Qwen3-4B-Thinking-2507-FP8\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 32\n  lora_alpha: 32\n  lora_dropout: 0.0\n  target_modules:\n    - \"q_proj\"\n    - \"k_proj\"\n    - \"v_proj\"\n    - \"o_proj\"\n    - \"gate_proj\"\n    - \"up_proj\"\n    - \"down_proj\"\n  use_rslora: false\n  use_loftq: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.6\n  top_p: 0.95\n  top_k: 20\n  min_p: 0.00\n\n"
  },
  {
    "path": "studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml",
    "content": "# Model defaults for unsloth/Qwen3-VL-8B-Instruct\n# Based on Qwen3_VL_(8B)-Vision.ipynb\n# Also applies to: Qwen/Qwen3-VL-8B-Instruct-FP8, unsloth/Qwen3-VL-8B-Instruct-FP8, unsloth/Qwen3-VL-8B-Instruct, Qwen/Qwen3-VL-8B-Instruct, unsloth/Qwen3-VL-8B-Instruct-bnb-4bit\n# added inference parameters from unsloth guides\n\ntraining:\n  trust_remote_code: false\n  max_seq_length: 2048\n  # num_epochs: 4\n  num_epochs: 0\n  learning_rate: 2e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 30\n  save_steps: 30\n  weight_decay: 0.001\n  random_seed: 3407\n  packing: false\n  train_on_completions: true\n  gradient_checkpointing: \"unsloth\"\n  optim: \"adamw_8bit\"\n  lr_scheduler_type: \"linear\"\n\nlora:\n  lora_r: 16\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules:\n    - \"all-linear\"\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: \"llm-finetuning\"\n  enable_tensorboard: false\n  tensorboard_dir: \"runs\"\n  log_frequency: 10\n\ninference:\n  trust_remote_code: false\n  temperature: 0.7\n  top_p: 0.8\n  top_k: 20\n\n"
  },
  {
    "path": "studio/backend/assets/configs/vision_lora.yaml",
    "content": "model: unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\n\ndata:\n  dataset: philschmid/amazon-product-descriptions-vlm\n  format_type: auto\n\ntraining:\n  training_type: lora\n  max_seq_length: 2048\n  load_in_4bit: true\n  output_dir: outputs\n  num_epochs: 1\n  learning_rate: 0.0002\n  batch_size: 1\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  max_steps: 0\n  save_steps: 0\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  train_on_completions: false\n  gradient_checkpointing: \"unsloth\"\n\nlora:\n  lora_r: 64\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules: \"\"   # vision uses vision_all_linear by default\n  vision_all_linear: true\n  use_rslora: false\n  use_loftq: false\n  finetune_vision_layers: true\n  finetune_language_layers: true\n  finetune_attention_modules: true\n  finetune_mlp_modules: true\n\nlogging:\n  enable_wandb: false\n  wandb_project: unsloth-training\n  enable_tensorboard: false\n  tensorboard_dir: runs\n"
  },
  {
    "path": "studio/backend/auth/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/auth/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nAuthentication module for JWT-based auth with SQLite storage.\n\"\"\"\n\nfrom .authentication import (\n    create_access_token,\n    create_refresh_token,\n    refresh_access_token,\n    get_current_subject,\n    get_current_subject_allow_password_change,\n    reload_secret,\n)\nfrom .storage import (\n    DEFAULT_ADMIN_USERNAME,\n    clear_bootstrap_password,\n    generate_bootstrap_password,\n    get_bootstrap_password,\n    is_initialized,\n    create_initial_user,\n    ensure_default_admin,\n    get_jwt_secret,\n    get_user_and_secret,\n    load_jwt_secret,\n    requires_password_change,\n    save_refresh_token,\n    update_password,\n    verify_refresh_token,\n    revoke_user_refresh_tokens,\n)\nfrom .hashing import hash_password, verify_password\n\n__all__ = [\n    \"create_access_token\",\n    \"create_refresh_token\",\n    \"refresh_access_token\",\n    \"get_current_subject\",\n    \"get_current_subject_allow_password_change\",\n    \"reload_secret\",\n    \"DEFAULT_ADMIN_USERNAME\",\n    \"clear_bootstrap_password\",\n    \"generate_bootstrap_password\",\n    \"get_bootstrap_password\",\n    \"is_initialized\",\n    \"create_initial_user\",\n    \"ensure_default_admin\",\n    \"get_jwt_secret\",\n    \"get_user_and_secret\",\n    \"load_jwt_secret\",\n    \"requires_password_change\",\n    \"save_refresh_token\",\n    \"update_password\",\n    \"verify_refresh_token\",\n    \"revoke_user_refresh_tokens\",\n    \"hash_password\",\n    \"verify_password\",\n]\n"
  },
  {
    "path": "studio/backend/auth/authentication.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport secrets\nfrom datetime import datetime, timedelta, timezone\nfrom typing import Optional, Tuple\n\nfrom fastapi import Depends, HTTPException, status\nfrom fastapi.security import HTTPAuthorizationCredentials, HTTPBearer\nimport jwt\n\nfrom .storage import (\n    get_jwt_secret,\n    get_user_and_secret,\n    load_jwt_secret,\n    save_refresh_token,\n    verify_refresh_token,\n)\n\nALGORITHM = \"HS256\"\nACCESS_TOKEN_EXPIRE_MINUTES = 60\nREFRESH_TOKEN_EXPIRE_DAYS = 7\n\nsecurity = HTTPBearer()  # Reads Authorization: Bearer <token>\n\n\ndef _get_secret_for_subject(subject: str) -> str:\n    secret = get_jwt_secret(subject)\n    if secret is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Invalid or expired token\",\n        )\n    return secret\n\n\ndef _decode_subject_without_verification(token: str) -> Optional[str]:\n    try:\n        payload = jwt.decode(\n            token,\n            options = {\"verify_signature\": False, \"verify_exp\": False},\n        )\n    except jwt.InvalidTokenError:\n        return None\n\n    subject = payload.get(\"sub\")\n    return subject if isinstance(subject, str) else None\n\n\ndef create_access_token(\n    subject: str,\n    expires_delta: Optional[timedelta] = None,\n) -> str:\n    \"\"\"\n    Create a signed JWT for the given subject (e.g. username).\n\n    Tokens are valid across restarts because the signing secret is stored in SQLite.\n    \"\"\"\n    to_encode = {\"sub\": subject}\n    expire = datetime.now(timezone.utc) + (\n        expires_delta or timedelta(minutes = ACCESS_TOKEN_EXPIRE_MINUTES)\n    )\n    to_encode.update({\"exp\": expire})\n    return jwt.encode(\n        to_encode,\n        _get_secret_for_subject(subject),\n        algorithm = ALGORITHM,\n    )\n\n\ndef create_refresh_token(subject: str) -> str:\n    \"\"\"\n    Create a random refresh token, store its hash in SQLite, and return it.\n\n    Refresh tokens are opaque (not JWTs) and expire after REFRESH_TOKEN_EXPIRE_DAYS.\n    \"\"\"\n    token = secrets.token_urlsafe(48)\n    expires_at = datetime.now(timezone.utc) + timedelta(days = REFRESH_TOKEN_EXPIRE_DAYS)\n    save_refresh_token(token, subject, expires_at.isoformat())\n    return token\n\n\ndef refresh_access_token(refresh_token: str) -> Tuple[Optional[str], Optional[str]]:\n    \"\"\"\n    Validate a refresh token and issue a new access token.\n\n    The refresh token itself is NOT consumed — it stays valid until expiry.\n    Returns a new access_token or None if the refresh token is invalid/expired.\n    \"\"\"\n    username = verify_refresh_token(refresh_token)\n    if username is None:\n        return None, None\n    return create_access_token(subject = username), username\n\n\ndef reload_secret() -> None:\n    \"\"\"\n    Keep legacy API compatibility for callers expecting auth storage init.\n\n    Auth now resolves the current signing secret directly from SQLite.\n    \"\"\"\n    load_jwt_secret()\n\n\nasync def get_current_subject(\n    credentials: HTTPAuthorizationCredentials = Depends(security),\n) -> str:\n    \"\"\"Validate JWT and require the password-change flow to be completed.\"\"\"\n    return await _get_current_subject(\n        credentials,\n        allow_password_change = False,\n    )\n\n\nasync def get_current_subject_allow_password_change(\n    credentials: 
HTTPAuthorizationCredentials = Depends(security),\n) -> str:\n    \"\"\"Validate JWT but allow access to the password-change endpoint.\"\"\"\n    return await _get_current_subject(\n        credentials,\n        allow_password_change = True,\n    )\n\n\nasync def _get_current_subject(\n    credentials: HTTPAuthorizationCredentials,\n    *,\n    allow_password_change: bool,\n) -> str:\n    \"\"\"\n    FastAPI dependency to validate the JWT and return the subject.\n\n    Use this as a dependency on routes that should be protected, e.g.:\n\n        @router.get(\"/secure\")\n        async def secure_endpoint(current_subject: str = Depends(get_current_subject)):\n            ...\n    \"\"\"\n    token = credentials.credentials\n    subject = _decode_subject_without_verification(token)\n    if subject is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Invalid token payload\",\n        )\n\n    record = get_user_and_secret(subject)\n    if record is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Invalid or expired token\",\n        )\n\n    _salt, _pwd_hash, jwt_secret, must_change_password = record\n    try:\n        payload = jwt.decode(token, jwt_secret, algorithms = [ALGORITHM])\n        if payload.get(\"sub\") != subject:\n            raise HTTPException(\n                status_code = status.HTTP_401_UNAUTHORIZED,\n                detail = \"Invalid token payload\",\n            )\n        if must_change_password and not allow_password_change:\n            raise HTTPException(\n                status_code = status.HTTP_403_FORBIDDEN,\n                detail = \"Password change required\",\n            )\n        return subject\n    except jwt.InvalidTokenError:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Invalid or expired token\",\n        )\n"
  },
  {
    "path": "studio/backend/auth/hashing.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPassword hashing utilities using PBKDF2.\n\"\"\"\n\nimport hashlib\nimport hmac\nimport secrets\nfrom typing import Tuple\n\n\ndef hash_password(password: str, salt: str | None = None) -> Tuple[str, str]:\n    \"\"\"\n    Hash a password using PBKDF2-HMAC-SHA256.\n\n    Returns (salt, hex_hash) tuple.\n    \"\"\"\n    if salt is None:\n        salt = secrets.token_hex(16)\n    dk = hashlib.pbkdf2_hmac(\n        \"sha256\",\n        password.encode(\"utf-8\"),\n        salt.encode(\"utf-8\"),\n        100_000,  # 100k iterations\n    )\n    return salt, dk.hex()\n\n\ndef verify_password(password: str, salt: str, hashed: str) -> bool:\n    \"\"\"\n    Verify a password against a stored salt and hash.\n\n    Uses constant-time comparison to prevent timing attacks.\n    \"\"\"\n    dk = hashlib.pbkdf2_hmac(\n        \"sha256\",\n        password.encode(\"utf-8\"),\n        salt.encode(\"utf-8\"),\n        100_000,\n    )\n    return hmac.compare_digest(dk.hex(), hashed)\n"
  },
  {
    "path": "studio/backend/auth/storage.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nSQLite storage for authentication data (user credentials + JWT secret).\n\"\"\"\n\nimport hashlib\nimport secrets\nimport sqlite3\nfrom datetime import datetime, timezone\nfrom typing import Optional, Tuple\n\nfrom utils.paths import auth_db_path, ensure_dir\n\nDB_PATH = auth_db_path()\nDEFAULT_ADMIN_USERNAME = \"unsloth\"\n\n# Plaintext bootstrap password file — lives beside auth.db, deleted on\n# first password change so the credential never lingers on disk.\n_BOOTSTRAP_PW_PATH = DB_PATH.parent / \".bootstrap_password\"\n\n# In-process cache so we don't re-read the file on every HTML serve.\n_bootstrap_password: Optional[str] = None\n\n\ndef generate_bootstrap_password() -> str:\n    \"\"\"Generate a 4-word diceware passphrase and persist it to disk.\n\n    The passphrase is written to ``_BOOTSTRAP_PW_PATH`` so that it\n    survives server restarts (the DB only stores the *hash*).  On\n    subsequent calls / restarts, the persisted value is returned.\n    \"\"\"\n    global _bootstrap_password\n\n    # 1. Already cached in this process?\n    if _bootstrap_password is not None:\n        return _bootstrap_password\n\n    # 2. Already persisted from a previous run?\n    if _BOOTSTRAP_PW_PATH.is_file():\n        _bootstrap_password = _BOOTSTRAP_PW_PATH.read_text().strip()\n        if _bootstrap_password:\n            return _bootstrap_password\n\n    # 3. First-ever startup — generate a fresh passphrase.\n    import diceware\n\n    _bootstrap_password = diceware.get_passphrase(\n        options = diceware.handle_options(args = [\"-n\", \"4\", \"-d\", \"\", \"-c\"])\n    )\n\n    # Persist so the *same* passphrase is used if the server restarts\n    # before the user changes the password.\n    ensure_dir(_BOOTSTRAP_PW_PATH.parent)\n    _BOOTSTRAP_PW_PATH.write_text(_bootstrap_password)\n\n    return _bootstrap_password\n\n\ndef get_bootstrap_password() -> Optional[str]:\n    \"\"\"Return the cached bootstrap password, or None if not yet generated.\"\"\"\n    return _bootstrap_password\n\n\ndef clear_bootstrap_password() -> None:\n    \"\"\"Delete the persisted bootstrap password file (called after password change).\"\"\"\n    global _bootstrap_password\n    _bootstrap_password = None\n    if _BOOTSTRAP_PW_PATH.is_file():\n        _BOOTSTRAP_PW_PATH.unlink(missing_ok = True)\n\n\ndef _hash_token(token: str) -> str:\n    \"\"\"SHA-256 hash helper used for refresh token storage.\"\"\"\n    return hashlib.sha256(token.encode(\"utf-8\")).hexdigest()\n\n\ndef get_connection() -> sqlite3.Connection:\n    \"\"\"Get a connection to the auth database, creating tables if needed.\"\"\"\n    ensure_dir(DB_PATH.parent)\n    conn = sqlite3.connect(DB_PATH)\n    conn.row_factory = sqlite3.Row\n    conn.execute(\n        \"\"\"\n        CREATE TABLE IF NOT EXISTS auth_user (\n            id INTEGER PRIMARY KEY,\n            username TEXT UNIQUE NOT NULL,\n            password_salt TEXT NOT NULL,\n            password_hash TEXT NOT NULL,\n            jwt_secret TEXT NOT NULL,\n            must_change_password INTEGER NOT NULL DEFAULT 0\n        );\n        \"\"\"\n    )\n    conn.execute(\n        \"\"\"\n        CREATE TABLE IF NOT EXISTS refresh_tokens (\n            id INTEGER PRIMARY KEY,\n            token_hash TEXT NOT NULL,\n            username TEXT NOT NULL,\n            expires_at TEXT NOT NULL\n        );\n        \"\"\"\n    )\n    
columns = {row[\"name\"] for row in conn.execute(\"PRAGMA table_info(auth_user)\")}\n    if \"must_change_password\" not in columns:\n        conn.execute(\n            \"ALTER TABLE auth_user ADD COLUMN must_change_password INTEGER NOT NULL DEFAULT 0\"\n        )\n    conn.commit()\n    return conn\n\n\ndef is_initialized() -> bool:\n    \"\"\"Check if auth is ready for login (at least one user exists in DB).\"\"\"\n    conn = get_connection()\n    cur = conn.execute(\"SELECT COUNT(*) AS c FROM auth_user\")\n    row = cur.fetchone()\n    conn.close()\n    return bool(row[\"c\"])\n\n\ndef create_initial_user(\n    username: str,\n    password: str,\n    jwt_secret: str,\n    *,\n    must_change_password: bool = False,\n) -> None:\n    \"\"\"\n    Create the initial admin user in the database.\n\n    Raises sqlite3.IntegrityError if username already exists.\n    \"\"\"\n    from .hashing import hash_password\n\n    salt, pwd_hash = hash_password(password)\n    conn = get_connection()\n    try:\n        conn.execute(\n            \"\"\"\n            INSERT INTO auth_user (\n                username,\n                password_salt,\n                password_hash,\n                jwt_secret,\n                must_change_password\n            )\n            VALUES (?, ?, ?, ?, ?)\n            \"\"\",\n            (username, salt, pwd_hash, jwt_secret, int(must_change_password)),\n        )\n        conn.commit()\n    finally:\n        conn.close()\n\n\ndef delete_user(username: str) -> None:\n    \"\"\"\n    Delete a user from the database.\n\n    Used for rollback when user creation fails partway through bootstrap.\n    \"\"\"\n    conn = get_connection()\n    try:\n        conn.execute(\"DELETE FROM auth_user WHERE username = ?\", (username,))\n        conn.commit()\n    finally:\n        conn.close()\n\n\ndef get_user_and_secret(username: str) -> Optional[Tuple[str, str, str, bool]]:\n    \"\"\"\n    Get user's password salt, hash, and JWT secret.\n\n    Returns (password_salt, password_hash, jwt_secret, must_change_password)\n    or None if user not found.\n    \"\"\"\n    conn = get_connection()\n    try:\n        cur = conn.execute(\n            \"\"\"\n            SELECT password_salt, password_hash, jwt_secret, must_change_password\n            FROM auth_user\n            WHERE username = ?\n            \"\"\",\n            (username,),\n        )\n        row = cur.fetchone()\n        if not row:\n            return None\n        return (\n            row[\"password_salt\"],\n            row[\"password_hash\"],\n            row[\"jwt_secret\"],\n            bool(row[\"must_change_password\"]),\n        )\n    finally:\n        conn.close()\n\n\ndef get_jwt_secret(username: str) -> Optional[str]:\n    \"\"\"Return the current JWT signing secret for a user.\"\"\"\n    conn = get_connection()\n    try:\n        cur = conn.execute(\n            \"SELECT jwt_secret FROM auth_user WHERE username = ?\",\n            (username,),\n        )\n        row = cur.fetchone()\n        return row[\"jwt_secret\"] if row else None\n    finally:\n        conn.close()\n\n\ndef requires_password_change(username: str) -> bool:\n    \"\"\"Return whether the user must change the seeded default password.\"\"\"\n    conn = get_connection()\n    try:\n        cur = conn.execute(\n            \"SELECT must_change_password FROM auth_user WHERE username = ?\",\n            (username,),\n        )\n        row = cur.fetchone()\n        return bool(row and row[\"must_change_password\"])\n    finally:\n        
conn.close()\n\n\ndef load_jwt_secret() -> str:\n    \"\"\"\n    Load the JWT secret from the database.\n\n    Raises RuntimeError if no auth user has been created yet.\n    \"\"\"\n    conn = get_connection()\n    try:\n        cur = conn.execute(\"SELECT jwt_secret FROM auth_user LIMIT 1\")\n        row = cur.fetchone()\n        if not row:\n            raise RuntimeError(\n                \"Auth is not initialized. Wait for the seeded admin bootstrap to complete.\"\n            )\n        return row[\"jwt_secret\"]\n    finally:\n        conn.close()\n\n\ndef ensure_default_admin() -> bool:\n    \"\"\"Seed the default admin account on first startup.\n\n    Uses a randomly generated diceware passphrase as the bootstrap password.\n    Returns True when the default admin was created in this call.\n    \"\"\"\n    bootstrap_pw = generate_bootstrap_password()\n    try:\n        create_initial_user(\n            username = DEFAULT_ADMIN_USERNAME,\n            password = bootstrap_pw,\n            jwt_secret = secrets.token_urlsafe(64),\n            must_change_password = True,\n        )\n        return True\n    except sqlite3.IntegrityError:\n        return False\n\n\ndef update_password(username: str, new_password: str) -> bool:\n    \"\"\"Update password, clear first-login requirement, rotate JWT secret.\"\"\"\n    from .hashing import hash_password\n\n    salt, pwd_hash = hash_password(new_password)\n    jwt_secret = secrets.token_urlsafe(64)\n    conn = get_connection()\n    try:\n        cursor = conn.execute(\n            \"\"\"\n            UPDATE auth_user\n            SET password_salt = ?, password_hash = ?, jwt_secret = ?, must_change_password = 0\n            WHERE username = ?\n            \"\"\",\n            (salt, pwd_hash, jwt_secret, username),\n        )\n        conn.commit()\n        if cursor.rowcount > 0:\n            clear_bootstrap_password()\n        return cursor.rowcount > 0\n    finally:\n        conn.close()\n\n\ndef save_refresh_token(token: str, username: str, expires_at: str) -> None:\n    \"\"\"\n    Store a hashed refresh token with its associated username and expiry.\n    \"\"\"\n    token_hash = _hash_token(token)\n    conn = get_connection()\n    try:\n        conn.execute(\n            \"\"\"\n            INSERT INTO refresh_tokens (token_hash, username, expires_at)\n            VALUES (?, ?, ?)\n            \"\"\",\n            (token_hash, username, expires_at),\n        )\n        conn.commit()\n    finally:\n        conn.close()\n\n\ndef verify_refresh_token(token: str) -> Optional[str]:\n    \"\"\"\n    Verify a refresh token and return the username.\n\n    Returns the username if valid and not expired, None otherwise.\n    The token is NOT consumed — it stays valid until it expires.\n    \"\"\"\n    token_hash = _hash_token(token)\n    conn = get_connection()\n    try:\n        # Clean up any expired tokens while we're here\n        conn.execute(\n            \"DELETE FROM refresh_tokens WHERE expires_at < ?\",\n            (datetime.now(timezone.utc).isoformat(),),\n        )\n        conn.commit()\n\n        cur = conn.execute(\n            \"\"\"\n            SELECT id, username, expires_at FROM refresh_tokens\n            WHERE token_hash = ?\n            \"\"\",\n            (token_hash,),\n        )\n        row = cur.fetchone()\n        if row is None:\n            return None\n\n        # Check expiry\n        expires_at = datetime.fromisoformat(row[\"expires_at\"])\n        if datetime.now(timezone.utc) > expires_at:\n            
conn.execute(\"DELETE FROM refresh_tokens WHERE id = ?\", (row[\"id\"],))\n            conn.commit()\n            return None\n\n        return row[\"username\"]\n    finally:\n        conn.close()\n\n\ndef revoke_user_refresh_tokens(username: str) -> None:\n    \"\"\"Revoke all refresh tokens for a user (e.g. on logout).\"\"\"\n    conn = get_connection()\n    try:\n        conn.execute(\"DELETE FROM refresh_tokens WHERE username = ?\", (username,))\n        conn.commit()\n    finally:\n        conn.close()\n"
  },
  {
    "path": "studio/backend/colab.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nColab-specific helpers for running Unsloth Studio.\nUses Colab's built-in proxy - no external tunneling needed!\n\"\"\"\n\nfrom pathlib import Path\nimport sys\n\n\ndef _bootstrap_studio_venv() -> None:\n    \"\"\"Expose the Studio venv's site-packages to the current interpreter.\n\n    On Colab, notebook cells run outside the venv subshell. Instead of\n    installing the full stack into system Python, we prepend the venv's\n    site-packages so that packages like structlog, fastapi, etc. are\n    importable from notebook cells and take priority over system copies.\n    \"\"\"\n    venv_lib = Path.home() / \".unsloth\" / \"studio\" / \".venv\" / \"lib\"\n    if not venv_lib.exists():\n        import warnings\n\n        warnings.warn(\n            f\"Studio venv not found at {venv_lib.parent} -- run 'unsloth studio setup' first\",\n            stacklevel = 2,\n        )\n        return\n    for sp in venv_lib.glob(\"python*/site-packages\"):\n        sp_str = str(sp)\n        if sp_str not in sys.path:\n            sys.path.insert(0, sp_str)\n\n\n_bootstrap_studio_venv()\n\n# Add backend to path early so local modules like loggers can be imported\nbackend_path = str(Path(__file__).parent)\nif backend_path not in sys.path:\n    sys.path.insert(0, backend_path)\n\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef get_colab_url(port: int = 8888) -> str:\n    \"\"\"\n    Get the actual Colab proxy URL for a port.\n    \"\"\"\n    try:\n        from google.colab.output import eval_js\n\n        # Use Colab's proxy mechanism\n        url = eval_js(f\"google.colab.kernel.proxyPort({port})\", timeout_sec = 5)\n        return url if url else f\"http://localhost:{port}\"\n    except Exception as e:\n        logger.info(f\"Note: Could not get Colab URL ({e})\")\n        return f\"http://localhost:{port}\"\n\n\ndef show_link(port: int = 8888):\n    \"\"\"Display a styled clickable link to the UI.\"\"\"\n    from IPython.display import display, HTML\n\n    # Get real Colab proxy URL\n    url = get_colab_url(port)\n\n    short_url = (\n        url[: url.index(\"-\", url.index(f\"{port}-\") + len(str(port)) + 1) + 1] + \"...\"\n        if f\"{port}-\" in url\n        else url\n    )\n    html = f\"\"\"\n    <div style=\"display: inline-block; padding: 20px; background: #ffffff; border: 2px solid #000000;\n                border-radius: 12px; margin: 10px 0; font-family: system-ui, -apple-system, sans-serif;\">\n        <h2 style=\"color: #000000; margin: 0 0 12px 0; font-size: 26px; font-weight: 800;\n                   display: flex; align-items: center; gap: 12px;\">\n            <img src=\"https://github.com/unslothai/unsloth/raw/main/studio/frontend/public/unsloth-gem.png\"\n                 height=\"48\" style=\"display:block;\">\n            Unsloth Studio is Ready!\n        </h2>\n        <a href=\"{url}\" target=\"_blank\"\n           style=\"display: inline-flex; align-items: center; gap: 10px; padding: 14px 28px;\n                  background: #000000; color: white; text-decoration: none; border-radius: 8px;\n                  font-weight: 800; font-size: 16px;\">\n            <svg xmlns=\"http://www.w3.org/2000/svg\" width=\"18\" height=\"18\" viewBox=\"0 0 24 24\" fill=\"white\"><polygon points=\"5,3 19,12 5,21\"/></svg>\n            Open Unsloth Studio\n        </a>\n        <p style=\"color: #333333; 
margin: 16px 0 0 0; font-size: 13px; font-family: monospace;\">\n            {short_url}\n        </p>\n    </div>\n    \"\"\"\n    display(HTML(html))\n\n\ndef start(port: int = 8888):\n    \"\"\"\n    Start Unsloth Studio server in Colab and display the URL.\n\n    Usage:\n        from colab import start\n        start()\n    \"\"\"\n    import sys\n\n    logger.info(\"🦥 Starting Unsloth Studio...\")\n\n    logger.info(\"   Loading backend...\")\n    from run import run_server\n\n    # Auto-detect frontend path\n    repo_root = Path(__file__).parent.parent\n    frontend_path = repo_root / \"frontend\" / \"dist\"\n\n    if not frontend_path.exists():\n        logger.info(\"❌ Frontend not built! Please run the setup cell first.\")\n        return\n\n    logger.info(\"   Starting server...\")\n    # Start server silently\n    run_server(host = \"0.0.0.0\", port = port, frontend_path = frontend_path, silent = True)\n\n    logger.info(\"   Server started!\")\n\n    # Show the clickable link with real URL\n    show_link(port)\n\n\nif __name__ == \"__main__\":\n    start()\n"
  },
  {
    "path": "studio/backend/core/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nUnified core module for Unsloth backend\n\nImports are LAZY (via __getattr__) so that training subprocesses can\nimport core.training.worker without pulling in heavy ML dependencies\nlike unsloth, transformers, or torch before the version activation\ncode has a chance to run.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\n\n# Ensure the backend directory is on sys.path so that bare \"from utils.*\"\n# imports used throughout the backend work when core is imported as a package\n# (e.g. from the CLI: \"from studio.backend.core import ModelConfig\").\n_backend_dir = str(Path(__file__).resolve().parent.parent)\nif _backend_dir not in sys.path:\n    sys.path.insert(0, _backend_dir)\n\n__all__ = [\n    # Inference\n    \"InferenceBackend\",\n    \"get_inference_backend\",\n    # Training\n    \"get_training_backend\",\n    \"TrainingBackend\",\n    \"TrainingProgress\",\n    # Config\n    \"ModelConfig\",\n    \"is_vision_model\",\n    \"scan_trained_loras\",\n    \"load_model_defaults\",\n    \"get_base_model_from_lora\",\n    # Utils\n    \"format_and_template_dataset\",\n    \"normalize_path\",\n    \"is_local_path\",\n    \"is_model_cached\",\n    \"without_hf_auth\",\n    \"format_error_message\",\n    \"get_gpu_memory_info\",\n    \"log_gpu_memory\",\n    \"get_device\",\n    \"is_apple_silicon\",\n    \"clear_gpu_cache\",\n    \"DeviceType\",\n]\n\n\ndef __getattr__(name):\n    # Inference\n    if name in (\"InferenceBackend\", \"get_inference_backend\"):\n        from .inference import InferenceBackend, get_inference_backend\n\n        globals()[\"InferenceBackend\"] = InferenceBackend\n        globals()[\"get_inference_backend\"] = get_inference_backend\n        return globals()[name]\n\n    # Training\n    if name in (\"TrainingBackend\", \"get_training_backend\", \"TrainingProgress\"):\n        from .training import TrainingBackend, get_training_backend, TrainingProgress\n\n        globals()[\"TrainingBackend\"] = TrainingBackend\n        globals()[\"get_training_backend\"] = get_training_backend\n        globals()[\"TrainingProgress\"] = TrainingProgress\n        return globals()[name]\n\n    # Config (from utils.models)\n    if name in (\n        \"is_vision_model\",\n        \"ModelConfig\",\n        \"scan_trained_loras\",\n        \"load_model_defaults\",\n        \"get_base_model_from_lora\",\n    ):\n        from utils.models import (\n            is_vision_model,\n            ModelConfig,\n            scan_trained_loras,\n            load_model_defaults,\n            get_base_model_from_lora,\n        )\n\n        globals()[\"is_vision_model\"] = is_vision_model\n        globals()[\"ModelConfig\"] = ModelConfig\n        globals()[\"scan_trained_loras\"] = scan_trained_loras\n        globals()[\"load_model_defaults\"] = load_model_defaults\n        globals()[\"get_base_model_from_lora\"] = get_base_model_from_lora\n        return globals()[name]\n\n    # Paths\n    if name in (\"normalize_path\", \"is_local_path\", \"is_model_cached\"):\n        from utils.paths import normalize_path, is_local_path, is_model_cached\n\n        globals()[\"normalize_path\"] = normalize_path\n        globals()[\"is_local_path\"] = is_local_path\n        globals()[\"is_model_cached\"] = is_model_cached\n        return globals()[name]\n\n    # Utils\n    if name in (\"without_hf_auth\", \"format_error_message\"):\n        from 
utils.utils import without_hf_auth, format_error_message\n\n        globals()[\"without_hf_auth\"] = without_hf_auth\n        globals()[\"format_error_message\"] = format_error_message\n        return globals()[name]\n\n    # Hardware\n    if name in (\n        \"get_device\",\n        \"is_apple_silicon\",\n        \"clear_gpu_cache\",\n        \"get_gpu_memory_info\",\n        \"log_gpu_memory\",\n        \"DeviceType\",\n    ):\n        from utils.hardware import (\n            get_device,\n            is_apple_silicon,\n            clear_gpu_cache,\n            get_gpu_memory_info,\n            log_gpu_memory,\n            DeviceType,\n        )\n\n        globals()[\"get_device\"] = get_device\n        globals()[\"is_apple_silicon\"] = is_apple_silicon\n        globals()[\"clear_gpu_cache\"] = clear_gpu_cache\n        globals()[\"get_gpu_memory_info\"] = get_gpu_memory_info\n        globals()[\"log_gpu_memory\"] = log_gpu_memory\n        globals()[\"DeviceType\"] = DeviceType\n        return globals()[name]\n\n    # Datasets\n    if name == \"format_and_template_dataset\":\n        from utils.datasets import format_and_template_dataset\n\n        globals()[\"format_and_template_dataset\"] = format_and_template_dataset\n        return format_and_template_dataset\n\n    raise AttributeError(f\"module 'core' has no attribute {name!r}\")\n"
  },
  {
    "path": "studio/backend/core/data_recipe/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nData Recipe core (DataDesigner wrapper + job runner).\n\"\"\"\n\nfrom .jobs import JobManager, get_job_manager\n\n__all__ = [\"JobManager\", \"get_job_manager\"]\n"
  },
  {
    "path": "studio/backend/core/data_recipe/huggingface.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nfrom utils.paths import recipe_datasets_root, resolve_dataset_path\n\n_DATA_DESIGNER_FOOTER = (\n    '<sub style=\"white-space: nowrap;\">Made with ❤️ using 🎨 '\n    '<a href=\"https://github.com/NVIDIA-NeMo/DataDesigner\">NeMo Data Designer</a></sub>'\n)\n_UNSLOTH_STUDIO_FOOTER = (\n    '<sub style=\"white-space: nowrap;\">Made with ❤️ using 🦥 ' \"Unsloth Studio</sub>\"\n)\n\n\nclass RecipeDatasetPublishError(ValueError):\n    \"\"\"Raised when a recipe dataset cannot be published to Hugging Face.\"\"\"\n\n\ndef _resolve_recipe_artifact_path(artifact_path: str) -> Path:\n    root = recipe_datasets_root().expanduser().resolve()\n    candidate = resolve_dataset_path(artifact_path).expanduser()\n    resolved = candidate.resolve(strict = False)\n\n    try:\n        resolved.relative_to(root)\n    except ValueError as exc:\n        raise RecipeDatasetPublishError(\n            \"This execution artifact is outside the Recipe Studio dataset storage.\"\n        ) from exc\n\n    if not resolved.exists():\n        raise RecipeDatasetPublishError(\"Execution artifacts are no longer available.\")\n    if not resolved.is_dir():\n        raise RecipeDatasetPublishError(\n            \"Execution artifact path is not a dataset folder.\"\n        )\n\n    return resolved\n\n\ndef publish_recipe_dataset(\n    *,\n    artifact_path: str,\n    repo_id: str,\n    description: str,\n    hf_token: str | None = None,\n    private: bool = False,\n) -> str:\n    dataset_path = _resolve_recipe_artifact_path(artifact_path)\n\n    try:\n        from data_designer.engine.storage.artifact_storage import (\n            FINAL_DATASET_FOLDER_NAME,\n            METADATA_FILENAME,\n            PROCESSORS_OUTPUTS_FOLDER_NAME,\n            SDG_CONFIG_FILENAME,\n        )\n        from data_designer.integrations.huggingface.client import (\n            HuggingFaceHubClient,\n            HuggingFaceHubClientUploadError,\n        )\n        from data_designer.integrations.huggingface.dataset_card import (\n            DataDesignerDatasetCard,\n        )\n    except ImportError as exc:\n        raise RecipeDatasetPublishError(\n            \"NeMo Data Designer Hugging Face integration is not installed.\"\n        ) from exc\n\n    try:\n        client = HuggingFaceHubClient(token = hf_token)\n        client._validate_repo_id(repo_id = repo_id)\n        client._validate_dataset_path(base_dataset_path = dataset_path)\n        client._create_or_get_repo(repo_id = repo_id, private = private)\n\n        metadata_path = dataset_path / METADATA_FILENAME\n        builder_config_path = dataset_path / SDG_CONFIG_FILENAME\n\n        with metadata_path.open(encoding = \"utf-8\") as fh:\n            metadata = json.load(fh)\n\n        builder_config = None\n        if builder_config_path.exists():\n            with builder_config_path.open(encoding = \"utf-8\") as fh:\n                builder_config = json.load(fh)\n\n        card = DataDesignerDatasetCard.from_metadata(\n            metadata = metadata,\n            builder_config = builder_config,\n            repo_id = repo_id,\n            description = description,\n            tags = None,\n        )\n        card.text = card.text.replace(_DATA_DESIGNER_FOOTER, _UNSLOTH_STUDIO_FOOTER)\n        # Data Designer currently drops the explicit token when 
pushing the\n        # dataset card. Push it ourselves so auth stays request-local.\n        card.push_to_hub(repo_id, token = hf_token, repo_type = \"dataset\")\n\n        client._upload_main_dataset_files(\n            repo_id = repo_id,\n            parquet_folder = dataset_path / FINAL_DATASET_FOLDER_NAME,\n        )\n        client._upload_images_folder(\n            repo_id = repo_id,\n            images_folder = dataset_path / \"images\",\n        )\n        client._upload_processor_files(\n            repo_id = repo_id,\n            processors_folder = dataset_path / PROCESSORS_OUTPUTS_FOLDER_NAME,\n        )\n        client._upload_config_files(\n            repo_id = repo_id,\n            metadata_path = metadata_path,\n            builder_config_path = builder_config_path,\n        )\n\n        return f\"https://huggingface.co/datasets/{repo_id}\"\n    except HuggingFaceHubClientUploadError as exc:\n        raise RecipeDatasetPublishError(str(exc)) from exc\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom .manager import JobManager, get_job_manager\n\n__all__ = [\"JobManager\", \"get_job_manager\"]\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/constants.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\n# stages parsed from data-designer logs\nSTAGE_CREATE = \"create\"\nSTAGE_PREVIEW = \"preview\"\nSTAGE_DAG = \"dag\"\nSTAGE_HEALTHCHECK = \"healthcheck\"\nSTAGE_SAMPLING = \"sampling\"\nSTAGE_COLUMN_CONFIG = \"column_config\"\nSTAGE_GENERATING = \"generating\"\nSTAGE_BATCH = \"batch\"\nSTAGE_PROFILING = \"profiling\"\n\nUSAGE_RESET_STAGES = {\n    STAGE_CREATE,\n    STAGE_PREVIEW,\n    STAGE_DAG,\n    STAGE_HEALTHCHECK,\n    STAGE_SAMPLING,\n    STAGE_GENERATING,\n    STAGE_PROFILING,\n}\n\n# job event types emitted by worker/manager\nEVENT_JOB_ENQUEUED = \"job.enqueued\"\nEVENT_JOB_STARTED = \"job.started\"\nEVENT_JOB_CANCELLING = \"job.cancelling\"\nEVENT_JOB_CANCELLED = \"job.cancelled\"\nEVENT_JOB_COMPLETED = \"job.completed\"\nEVENT_JOB_ERROR = \"job.error\"\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/manager.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nimport queue\nimport threading\nimport time\nimport uuid\nfrom pathlib import Path\nfrom collections import deque\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport multiprocessing as mp\n\nfrom ..jsonable import to_preview_jsonable\nfrom .constants import (\n    EVENT_JOB_CANCELLING,\n    EVENT_JOB_CANCELLED,\n    EVENT_JOB_COMPLETED,\n    EVENT_JOB_ENQUEUED,\n    EVENT_JOB_ERROR,\n    EVENT_JOB_STARTED,\n)\nfrom .parse import apply_update, coerce_event, parse_log_message\nfrom .types import Job\nfrom .worker import run_job_process\n\n\n_CTX = mp.get_context(\"spawn\")\n\n\n@dataclass\nclass Subscription:\n    replay: list[dict]\n    _q: queue.Queue\n    _next_id: int = 0\n\n    async def next_event(self, *, timeout_sec: float) -> dict | None:\n        \"\"\"Wait for next event (SSE), w/ timeout so we can check disconnects.\"\"\"\n        try:\n            return await asyncio.to_thread(self._q.get, True, timeout_sec)\n        except queue.Empty:\n            return None\n\n    def format_sse(self, event: dict) -> bytes:\n        \"\"\"Turn event dict into SSE bytes (id/event/data).\"\"\"\n        event_id = event.get(\"seq\")\n        if event_id is None:\n            self._next_id += 1\n            event_id = self._next_id\n        body = json.dumps(event, separators = (\",\", \":\"), ensure_ascii = False)\n        event_type = event.get(\"type\") or \"message\"\n        return (\n            f\"id: {event_id}\\n\" f\"event: {event_type}\\n\" f\"data: {body}\\n\\n\"\n        ).encode(\"utf-8\")\n\n\nclass JobManager:\n    def __init__(self) -> None:\n        \"\"\"Single-job runner (in-mem). 
Simple on purpose, not a whole platform.\"\"\"\n        self._lock = threading.Lock()\n        self._job: Job | None = None\n        self._proc: mp.Process | None = None\n        self._mp_q: Any | None = None\n        self._events: deque[dict] = deque(maxlen = 5000)\n        self._subs: list[queue.Queue] = []\n        self._pump_thread: threading.Thread | None = None\n        self._seq: int = 0\n\n    def start(self, *, recipe: dict, run: dict) -> str:\n        \"\"\"Spawn the job subprocess (one at a time, no cap).\"\"\"\n        llm_columns = recipe.get(\"columns\") or []\n        llm_column_count = 0\n        if isinstance(llm_columns, list):\n            for column in llm_columns:\n                if not isinstance(column, dict):\n                    continue\n                column_type = str(column.get(\"column_type\") or \"\").strip().lower()\n                if column_type.startswith(\"llm\"):\n                    llm_column_count += 1\n        if llm_column_count <= 0:\n            llm_column_count = 1\n\n        with self._lock:\n            if self._proc is not None and self._proc.is_alive():\n                raise RuntimeError(\"job already running\")\n\n            job_id = uuid.uuid4().hex\n            self._job = Job(job_id = job_id, status = \"pending\", started_at = time.time())\n            self._job.progress_columns_total = llm_column_count\n            self._events.clear()\n            self._seq = 0\n\n            run_payload = dict(run)\n            run_payload[\"_job_id\"] = job_id\n            mp_q = _CTX.Queue()\n            proc = _CTX.Process(\n                target = run_job_process,\n                kwargs = {\"event_queue\": mp_q, \"recipe\": recipe, \"run\": run_payload},\n                daemon = True,\n            )\n            proc.start()\n\n            self._mp_q = mp_q\n            self._proc = proc\n            self._pump_thread = threading.Thread(target = self._pump_loop, daemon = True)\n            self._pump_thread.start()\n\n            self._emit(\n                {\"type\": EVENT_JOB_ENQUEUED, \"ts\": time.time(), \"job_id\": job_id}\n            )\n            return job_id\n\n    def cancel(self, job_id: str) -> bool:\n        \"\"\"Hard stop. We terminate the subprocess. Quick + reliable.\"\"\"\n        with self._lock:\n            if self._job is None or self._job.job_id != job_id:\n                return False\n            if self._proc is None or not self._proc.is_alive():\n                return True\n            self._job.status = \"cancelling\"\n            self._emit(\n                {\"type\": EVENT_JOB_CANCELLING, \"ts\": time.time(), \"job_id\": job_id}\n            )\n            try:\n                self._proc.terminate()\n            except (AttributeError, OSError):\n                pass\n            return True\n\n    def get_status(self, job_id: str) -> dict | None:\n        \"\"\"UI friendly snapshot that we need. 
A structured alternative to the SSE stream.\"\"\"\n        with self._lock:\n            if self._job is None or self._job.job_id != job_id:\n                return None\n            job = self._job\n            return {\n                \"job_id\": job.job_id,\n                \"status\": job.status,\n                \"stage\": job.stage,\n                \"current_column\": job.current_column,\n                \"completed_columns\": list(job.completed_columns),\n                \"batch\": {\"idx\": job.batch.idx, \"total\": job.batch.total},\n                \"progress\": {\n                    \"done\": job.progress.done,\n                    \"total\": job.progress.total,\n                    \"percent\": job.progress.percent,\n                    \"eta_sec\": job.progress.eta_sec,\n                    \"rate\": job.progress.rate,\n                    \"ok\": job.progress.ok,\n                    \"failed\": job.progress.failed,\n                },\n                \"column_progress\": {\n                    \"done\": job.column_progress.done,\n                    \"total\": job.column_progress.total,\n                    \"percent\": job.column_progress.percent,\n                    \"eta_sec\": job.column_progress.eta_sec,\n                    \"rate\": job.column_progress.rate,\n                    \"ok\": job.column_progress.ok,\n                    \"failed\": job.column_progress.failed,\n                },\n                \"model_usage\": {\n                    name: {\n                        \"model\": usage.model,\n                        \"tokens\": {\n                            \"input\": usage.input_tokens,\n                            \"output\": usage.output_tokens,\n                            \"total\": usage.total_tokens,\n                            \"tps\": usage.tps,\n                        },\n                        \"requests\": {\n                            \"success\": usage.requests_success,\n                            \"failed\": usage.requests_failed,\n                            \"total\": usage.requests_total,\n                            \"rpm\": usage.rpm,\n                        },\n                    }\n                    for name, usage in job.model_usage.items()\n                },\n                \"rows\": job.rows,\n                \"cols\": job.cols,\n                \"error\": job.error,\n                \"has_analysis\": job.analysis is not None,\n                \"dataset_rows\": None if job.dataset is None else len(job.dataset),\n                \"artifact_path\": job.artifact_path,\n                \"execution_type\": job.execution_type,\n                \"started_at\": job.started_at,\n                \"finished_at\": job.finished_at,\n            }\n\n    def get_current_status(self) -> dict | None:\n        \"\"\"Single-job convenience (last/current).\"\"\"\n        job_id = self.get_current_job_id()\n        if job_id is None:\n            return None\n        return self.get_status(job_id)\n\n    def get_current_job_id(self) -> str | None:\n        \"\"\"Return current job_id (or None).\"\"\"\n        with self._lock:\n            return None if self._job is None else self._job.job_id\n\n    def get_analysis(self, job_id: str) -> dict | None:\n        \"\"\"Final profiling output (only after job completes).\"\"\"\n        with self._lock:\n            if self._job is None or self._job.job_id != job_id:\n                return None\n            return self._job.analysis\n\n    def get_dataset(\n        self,\n        job_id: 
str,\n        *,\n        limit: int,\n        offset: int = 0,\n    ) -> dict[str, Any] | None:\n        \"\"\"Load dataset page (offset + limit) and include total rows.\"\"\"\n        with self._lock:\n            if self._job is None or self._job.job_id != job_id:\n                return None\n            in_memory_dataset = self._job.dataset\n            artifact_path = self._job.artifact_path\n            job_status = self._job.status\n\n        if in_memory_dataset is not None:\n            total = len(in_memory_dataset)\n            rows = in_memory_dataset[offset : offset + limit]\n            return {\"dataset\": rows, \"total\": total}\n        if not artifact_path:\n            if job_status in {\"completed\", \"error\", \"cancelled\"}:\n                return {\"error\": \"artifact path missing\"}\n            return None\n\n        try:\n            base_dataset_path = Path(artifact_path)\n            parquet_dir = base_dataset_path / \"parquet-files\"\n            if not parquet_dir.exists():\n                return {\"error\": f\"dataset path missing: {parquet_dir}\"}\n\n            return self._load_dataset_page(\n                parquet_dir = parquet_dir, limit = limit, offset = offset\n            )\n        except Exception as exc:\n            return {\"error\": f\"dataset load failed: {exc}\"}\n\n    @staticmethod\n    def _load_dataset_page(\n        *,\n        parquet_dir: Path,\n        limit: int,\n        offset: int,\n    ) -> dict[str, Any]:\n        dataset_page = JobManager._load_dataset_page_with_duckdb(\n            parquet_dir = parquet_dir,\n            limit = limit,\n            offset = offset,\n        )\n        if dataset_page is not None:\n            return dataset_page\n        return JobManager._load_dataset_page_with_data_designer(\n            parquet_dir = parquet_dir,\n            limit = limit,\n            offset = offset,\n        )\n\n    @staticmethod\n    def _load_dataset_page_with_duckdb(\n        *,\n        parquet_dir: Path,\n        limit: int,\n        offset: int,\n    ) -> dict[str, Any] | None:\n        parquet_glob = str((parquet_dir / \"*.parquet\").resolve())\n        try:\n            import duckdb  # type: ignore\n        except Exception:\n            return None\n\n        try:\n            conn = duckdb.connect(\":memory:\")\n            try:\n                total_row = conn.execute(\n                    \"SELECT COUNT(*) FROM read_parquet(?)\",\n                    [parquet_glob],\n                ).fetchone()\n                total = int(total_row[0] if total_row else 0)\n                dataframe = conn.execute(\n                    (\n                        \"SELECT *, row_number() OVER (PARTITION BY filename) AS __row_num__ \"\n                        \"FROM read_parquet(?, filename=true) \"\n                        \"ORDER BY filename, __row_num__ \"\n                        \"LIMIT ? 
OFFSET ?\"\n                    ),\n                    [parquet_glob, int(limit), int(offset)],\n                ).fetchdf()\n            finally:\n                conn.close()\n        except (RuntimeError, ValueError, duckdb.Error):\n            return None\n\n        for helper_col in (\"filename\", \"__row_num__\"):\n            if helper_col in dataframe.columns:\n                dataframe = dataframe.drop(columns = [helper_col])\n\n        rows = dataframe.to_dict(orient = \"records\")\n        return {\"dataset\": to_preview_jsonable(rows), \"total\": total}\n\n    @staticmethod\n    def _load_dataset_page_with_data_designer(\n        *,\n        parquet_dir: Path,\n        limit: int,\n        offset: int,\n    ) -> dict[str, Any]:\n        from data_designer.config.utils.io_helpers import read_parquet_dataset\n\n        dataframe = read_parquet_dataset(parquet_dir)\n        total = int(len(dataframe.index))\n        rows = dataframe.iloc[offset : offset + limit].to_dict(orient = \"records\")\n        return {\"dataset\": to_preview_jsonable(rows), \"total\": total}\n\n    def subscribe(\n        self, job_id: str, *, after_seq: int | None = None\n    ) -> Subscription | None:\n        \"\"\"SSE subscribe: get replay buffer + live events stream.\"\"\"\n        with self._lock:\n            if self._job is None or self._job.job_id != job_id:\n                return None\n            q: queue.Queue = queue.Queue(maxsize = 2000)\n            self._subs.append(q)\n            if after_seq is None:\n                replay = list(self._events)\n            else:\n                replay = [e for e in self._events if int(e.get(\"seq\") or 0) > after_seq]\n            return Subscription(replay = replay, _q = q)\n\n    def unsubscribe(self, sub: Subscription) -> None:\n        \"\"\"Drop SSE subscriber (client disconnected).\"\"\"\n        with self._lock:\n            self._subs = [q for q in self._subs if q is not sub._q]\n\n    def _emit(self, event: dict) -> None:\n        \"\"\"Broadcast event to replay buffer + all subscribers.\"\"\"\n        self._seq += 1\n        event[\"seq\"] = self._seq\n        self._events.append(event)\n        stale: list[queue.Queue] = []\n        for q in self._subs:\n            try:\n                q.put_nowait(event)\n            except queue.Full:\n                stale.append(q)\n        if stale:\n            self._subs = [q for q in self._subs if q not in stale]\n\n    def _snapshot(self) -> tuple[Job, mp.Process, Any] | None:\n        \"\"\"Grab pointers for the pump loop (avoid holding lock too long).\"\"\"\n        with self._lock:\n            if self._job is None or self._proc is None or self._mp_q is None:\n                return None\n            return self._job, self._proc, self._mp_q\n\n    @staticmethod\n    def _read_queue_with_timeout(q: Any, *, timeout_sec: float) -> dict | None:\n        \"\"\"Try read 1 event from mp queue. 
Timeout = pump stays responsive.\"\"\"\n        try:\n            return coerce_event(q.get(timeout = timeout_sec))\n        except queue.Empty:\n            return None\n        except (EOFError, OSError, ValueError):\n            return None\n\n    @staticmethod\n    def _drain_queue(q: Any) -> list[dict]:\n        \"\"\"Drain mp queue fast (used on process exit).\"\"\"\n        events: list[dict] = []\n        while True:\n            try:\n                events.append(coerce_event(q.get_nowait()))\n            except queue.Empty:\n                return events\n            except (EOFError, OSError, ValueError):\n                return events\n\n    def _pump_loop(self) -> None:\n        \"\"\"Background thread: consumes worker events + updates job snapshot.\"\"\"\n        while True:\n            snap = self._snapshot()\n            if snap is None:\n                return\n            job, proc, mp_q = snap\n\n            event = self._read_queue_with_timeout(mp_q, timeout_sec = 0.25)\n            if event is not None:\n                self._handle_event(job, event)\n                continue\n\n            if proc.is_alive():\n                continue\n\n            for e in self._drain_queue(mp_q):\n                self._handle_event(job, e)\n\n            with self._lock:\n                if self._job and self._job.status in {\n                    \"pending\",\n                    \"active\",\n                    \"cancelling\",\n                }:\n                    if self._job.status == \"cancelling\":\n                        self._job.status = \"cancelled\"\n                    else:\n                        self._job.status = \"error\"\n                        self._job.error = self._job.error or \"process exited\"\n                    self._job.finished_at = time.time()\n                    event_type = (\n                        EVENT_JOB_CANCELLED\n                        if self._job.status == \"cancelled\"\n                        else EVENT_JOB_ERROR\n                    )\n                    self._emit(\n                        {\n                            \"type\": event_type,\n                            \"ts\": time.time(),\n                            \"job_id\": self._job.job_id,\n                        }\n                    )\n            return\n\n    def _handle_event(self, job: Job, event: dict) -> None:\n        \"\"\"Apply event -> job state + forward to SSE.\"\"\"\n        et = event.get(\"type\")\n        msg = event.get(\"message\") if et == \"log\" else None\n\n        with self._lock:\n            if self._job is None or self._job.job_id != job.job_id:\n                return\n            if et == EVENT_JOB_STARTED:\n                self._job.status = \"active\"\n            if et == EVENT_JOB_COMPLETED:\n                self._job.status = \"completed\"\n                self._job.finished_at = time.time()\n                self._job.analysis = event.get(\"analysis\")\n                self._job.artifact_path = event.get(\"artifact_path\")\n                self._job.execution_type = event.get(\"execution_type\")\n                self._job.dataset = event.get(\"dataset\")\n                self._job.processor_artifacts = event.get(\"processor_artifacts\")\n                if self._job.progress.total and self._job.progress.total > 0:\n                    self._job.progress.done = self._job.progress.total\n                    self._job.progress.percent = 100.0\n            if et == EVENT_JOB_ERROR:\n                self._job.status = \"error\"\n        
        self._job.finished_at = time.time()\n                self._job.error = event.get(\"error\") or \"error\"\n\n            if msg:\n                upd = parse_log_message(msg)\n                if upd:\n                    apply_update(self._job, upd)\n\n        self._emit(event)\n\n\n_JOB_MANAGER: JobManager | None = None\n\n\ndef get_job_manager() -> JobManager:\n    \"\"\"Singleton JobManager (we only run 1 job anyway).\"\"\"\n    global _JOB_MANAGER\n    if _JOB_MANAGER is None:\n        _JOB_MANAGER = JobManager()\n    return _JOB_MANAGER\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/parse.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any\n\nfrom .constants import (\n    STAGE_BATCH,\n    STAGE_COLUMN_CONFIG,\n    STAGE_CREATE,\n    STAGE_DAG,\n    STAGE_GENERATING,\n    STAGE_HEALTHCHECK,\n    STAGE_PREVIEW,\n    STAGE_PROFILING,\n    STAGE_SAMPLING,\n    USAGE_RESET_STAGES,\n)\nfrom .types import Job, ModelUsage, Progress\n\n\n@dataclass(frozen = True)\nclass ParsedUpdate:\n    stage: str | None = None\n    current_column: str | None = None\n    progress: Progress | None = None\n    rows: int | None = None\n    cols: int | None = None\n    batch_idx: int | None = None\n    batch_total: int | None = None\n    usage_model: str | None = None\n    usage_input_tokens: int | None = None\n    usage_output_tokens: int | None = None\n    usage_total_tokens: int | None = None\n    usage_tps: float | None = None\n    usage_requests_success: int | None = None\n    usage_requests_failed: int | None = None\n    usage_requests_total: int | None = None\n    usage_rpm: float | None = None\n    usage_section_start: bool | None = None\n\n\n# kinda of a bummber but currently only option, Best effort parser from data-designer logs -> structured status for UI.\n_RE_SAMPLERS = re.compile(\n    r\"Preparing samplers to generate (?P<rows>\\d+) records across (?P<cols>\\d+) columns\"\n)\n_RE_COLCFG = re.compile(r\"model config for column '(?P<col>[^']+)'\")\n_RE_PROCESSING_COL = re.compile(r\"Processing .* column '(?P<col>[^']+)'\")\n_RE_PROGRESS = re.compile(\n    r\"progress: (?P<done>\\d+)/(?P<total>\\d+) \\((?P<pct>\\d+)%\\) complete, \"\n    r\"(?P<ok>\\d+) ok, (?P<failed>\\d+) failed, (?P<rate>[0-9.]+) rec/s, eta (?P<eta>[0-9.]+)s\"\n)\n_RE_BATCH = re.compile(r\"Processing batch (?P<idx>\\d+) of (?P<total>\\d+)\")\n_RE_USAGE_MODEL = re.compile(r\"model:\\s*(?P<model>.+)$\")\n_RE_USAGE_TOKENS = re.compile(\n    r\"tokens:\\s*input=(?P<input>\\d+),\\s*output=(?P<output>\\d+),\\s*total=(?P<total>\\d+),\\s*tps=(?P<tps>[0-9.]+)\"\n)\n_RE_USAGE_REQUESTS = re.compile(\n    r\"requests:\\s*success=(?P<success>\\d+),\\s*failed=(?P<failed>\\d+),\\s*total=(?P<total>\\d+),\\s*rpm=(?P<rpm>[0-9.]+)\"\n)\n\n\ndef parse_log_message(msg: str) -> ParsedUpdate | None:\n    m = _RE_SAMPLERS.search(msg)\n    if m:\n        return ParsedUpdate(\n            stage = STAGE_SAMPLING,\n            rows = int(m.group(\"rows\")),\n            cols = int(m.group(\"cols\")),\n        )\n\n    if \"Sorting column configs into a Directed Acyclic Graph\" in msg:\n        return ParsedUpdate(stage = STAGE_DAG)\n    if \"Running health checks for models\" in msg:\n        return ParsedUpdate(stage = STAGE_HEALTHCHECK)\n    if \"Preview generation in progress\" in msg:\n        return ParsedUpdate(stage = STAGE_PREVIEW)\n    if \"Creating Data Designer dataset\" in msg:\n        return ParsedUpdate(stage = STAGE_CREATE)\n    if \"Measuring dataset column statistics\" in msg:\n        return ParsedUpdate(stage = STAGE_PROFILING)\n\n    m = _RE_COLCFG.search(msg)\n    if m:\n        col = m.group(\"col\")\n        return ParsedUpdate(stage = STAGE_COLUMN_CONFIG, current_column = col)\n\n    m = _RE_PROCESSING_COL.search(msg)\n    if m:\n        col = m.group(\"col\")\n        return ParsedUpdate(stage = STAGE_GENERATING, current_column = col)\n\n    m = _RE_PROGRESS.search(msg)\n    if m:\n        p = Progress(\n   
         done = int(m.group(\"done\")),\n            total = int(m.group(\"total\")),\n            percent = float(m.group(\"pct\")),\n            ok = int(m.group(\"ok\")),\n            failed = int(m.group(\"failed\")),\n            rate = float(m.group(\"rate\")),\n            eta_sec = float(m.group(\"eta\")),\n        )\n        return ParsedUpdate(stage = STAGE_GENERATING, progress = p)\n\n    m = _RE_BATCH.search(msg)\n    if m:\n        return ParsedUpdate(\n            stage = STAGE_BATCH,\n            batch_idx = int(m.group(\"idx\")),\n            batch_total = int(m.group(\"total\")),\n        )\n\n    if \"Model usage summary\" in msg:\n        return ParsedUpdate(usage_section_start = True)\n\n    m = _RE_USAGE_MODEL.search(msg)\n    if m and \"|-- model:\" in msg:\n        return ParsedUpdate(usage_model = str(m.group(\"model\")).strip())\n\n    m = _RE_USAGE_TOKENS.search(msg)\n    if m:\n        return ParsedUpdate(\n            usage_input_tokens = int(m.group(\"input\")),\n            usage_output_tokens = int(m.group(\"output\")),\n            usage_total_tokens = int(m.group(\"total\")),\n            usage_tps = float(m.group(\"tps\")),\n        )\n\n    m = _RE_USAGE_REQUESTS.search(msg)\n    if m:\n        return ParsedUpdate(\n            usage_requests_success = int(m.group(\"success\")),\n            usage_requests_failed = int(m.group(\"failed\")),\n            usage_requests_total = int(m.group(\"total\")),\n            usage_rpm = float(m.group(\"rpm\")),\n        )\n\n    return None\n\n\ndef apply_update(job: Job, update: ParsedUpdate) -> None:\n    if update.stage is not None:\n        job.stage = update.stage\n    if update.current_column is not None:\n        job.current_column = update.current_column\n        if (\n            update.stage == STAGE_GENERATING\n            and update.current_column not in job._seen_generation_columns\n        ):\n            job._seen_generation_columns.append(update.current_column)\n    if update.rows is not None:\n        job.rows = update.rows\n    if update.cols is not None:\n        job.cols = update.cols\n    if update.progress is not None:\n        job.column_progress = update.progress\n        if (\n            job.current_column\n            and update.progress.done is not None\n            and update.progress.total is not None\n            and update.progress.total > 0\n            and update.progress.done >= update.progress.total\n            and job.current_column not in job.completed_columns\n        ):\n            job.completed_columns.append(job.current_column)\n        job.progress = _compute_overall_progress(job, update.progress)\n    if update.batch_idx is not None:\n        job.batch.idx = update.batch_idx\n    if update.batch_total is not None:\n        job.batch.total = update.batch_total\n\n    if update.stage in USAGE_RESET_STAGES:\n        # usage summary is a short block so we reset once we move into the next stage.\n        job._in_usage_summary = False\n\n    if update.usage_section_start is not None:\n        job._in_usage_summary = update.usage_section_start\n        if update.usage_section_start:\n            job._current_usage_model = None\n\n    if not job._in_usage_summary:\n        return\n\n    if update.usage_model is not None:\n        name = update.usage_model.strip().strip(\"'\").strip('\"')\n        job._current_usage_model = name\n        if name not in job.model_usage:\n            job.model_usage[name] = ModelUsage(model = name)\n\n    if job._current_usage_model is None:\n       
 return\n\n    usage = job.model_usage.get(job._current_usage_model)\n    if usage is None:\n        return\n\n    if update.usage_input_tokens is not None:\n        usage.input_tokens = update.usage_input_tokens\n    if update.usage_output_tokens is not None:\n        usage.output_tokens = update.usage_output_tokens\n    if update.usage_total_tokens is not None:\n        usage.total_tokens = update.usage_total_tokens\n    if update.usage_tps is not None:\n        usage.tps = update.usage_tps\n    if update.usage_requests_success is not None:\n        usage.requests_success = update.usage_requests_success\n    if update.usage_requests_failed is not None:\n        usage.requests_failed = update.usage_requests_failed\n    if update.usage_requests_total is not None:\n        usage.requests_total = update.usage_requests_total\n    if update.usage_rpm is not None:\n        usage.rpm = update.usage_rpm\n\n\ndef _compute_overall_progress(job: Job, column_progress: Progress) -> Progress:\n    if not job.rows:\n        return column_progress\n\n    total_rows = max(1, int(job.rows))\n    current_done = 0 if column_progress.done is None else int(column_progress.done)\n    current_done = max(0, min(current_done, total_rows))\n    total_columns = max(1, int(job.progress_columns_total or 1))\n\n    if job.current_column:\n        job._column_done[job.current_column] = current_done\n\n    if len(job._column_done) == 0:\n        done = current_done\n    else:\n        sum_done = sum(\n            max(0, min(value, total_rows)) for value in job._column_done.values()\n        )\n        done = int(sum_done / total_columns)\n\n    prev_done = int(job.progress.done or 0)\n    if done < prev_done:\n        done = prev_done\n    if done > total_rows:\n        done = total_rows\n    percent = (done / total_rows) * 100 if total_rows > 0 else 100.0\n    prev_percent = float(job.progress.percent or 0.0)\n    if percent < prev_percent:\n        percent = prev_percent\n\n    return Progress(\n        done = done,\n        total = total_rows,\n        percent = percent,\n        eta_sec = column_progress.eta_sec,\n        rate = column_progress.rate,\n        ok = column_progress.ok,\n        failed = column_progress.failed,\n    )\n\n\ndef coerce_event(obj: Any) -> dict:\n    \"\"\"Normalize worker payload into event dict.\"\"\"\n    return obj if isinstance(obj, dict) else {\"type\": \"log\", \"message\": str(obj)}\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/types.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Literal\n\n\nJobStatus = Literal[\n    \"created\",\n    \"pending\",\n    \"active\",\n    \"cancelling\",\n    \"cancelled\",\n    \"error\",\n    \"completed\",\n]\n\n\n@dataclass\nclass Progress:\n    done: int | None = None\n    total: int | None = None\n    percent: float | None = None\n    eta_sec: float | None = None\n    rate: float | None = None\n    ok: int | None = None\n    failed: int | None = None\n\n\n@dataclass\nclass BatchProgress:\n    idx: int | None = None\n    total: int | None = None\n\n\n@dataclass\nclass ModelUsage:\n    model: str\n    input_tokens: int | None = None\n    output_tokens: int | None = None\n    total_tokens: int | None = None\n    tps: float | None = None\n    requests_success: int | None = None\n    requests_failed: int | None = None\n    requests_total: int | None = None\n    rpm: float | None = None\n\n\n@dataclass\nclass Job:\n    job_id: str\n    status: JobStatus = \"created\"\n    stage: str | None = None\n    current_column: str | None = None\n    progress: Progress = field(default_factory = Progress)\n    column_progress: Progress = field(default_factory = Progress)\n    batch: BatchProgress = field(default_factory = BatchProgress)\n    rows: int | None = None\n    cols: int | None = None\n    error: str | None = None\n    started_at: float | None = None\n    finished_at: float | None = None\n\n    analysis: dict[str, Any] | None = None\n    artifact_path: str | None = None\n    execution_type: str | None = None\n    dataset: list[dict[str, Any]] | None = None\n    processor_artifacts: dict[str, Any] | None = None\n    model_usage: dict[str, ModelUsage] = field(default_factory = dict)\n    progress_columns_total: int | None = None\n    completed_columns: list[str] = field(default_factory = list)\n    _current_usage_model: str | None = None\n    _in_usage_summary: bool = False\n    _seen_generation_columns: list[str] = field(default_factory = list)\n    _column_done: dict[str, int] = field(default_factory = dict)\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jobs/worker.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport json\nimport structlog\nimport loggers\nimport logging\nimport re\nimport shutil\nimport time\nimport traceback\nimport unicodedata\nfrom pathlib import Path\nfrom typing import Any\n\nfrom ..jsonable import to_jsonable, to_preview_jsonable\nfrom .constants import EVENT_JOB_COMPLETED, EVENT_JOB_ERROR, EVENT_JOB_STARTED\nfrom ..service import build_config_builder, create_data_designer\nfrom utils.paths import ensure_dir, recipe_datasets_root\n\n_ARTIFACT_ROOT = recipe_datasets_root()\n\n\nclass _QueueLogHandler(logging.Handler):\n    def __init__(self, event_queue):\n        super().__init__()\n        self._q = event_queue\n\n    def emit(self, record: logging.LogRecord) -> None:\n        try:\n            event = {\n                \"type\": \"log\",\n                \"ts\": record.created,\n                \"level\": record.levelname,\n                \"logger\": record.name,\n                \"message\": record.getMessage(),\n            }\n            self._q.put(event)\n        except (OSError, RuntimeError, ValueError):\n            pass\n\n\ndef _slugify_run_name(value: str) -> str:\n    normalized = unicodedata.normalize(\"NFKD\", value)\n    ascii_only = normalized.encode(\"ascii\", \"ignore\").decode(\"ascii\")\n    slug = re.sub(r\"[^a-zA-Z0-9]+\", \"-\", ascii_only).strip(\"-\").lower()\n    if not slug:\n        return \"\"\n    return slug[:80].strip(\"-\")\n\n\ndef _build_dataset_name(\n    *, run_name: str | None, job_id: str, artifact_root: Path\n) -> str:\n    fallback = f\"recipe_{job_id}\"\n    slug = _slugify_run_name(run_name or \"\")\n    base_name = f\"recipe_{slug}\" if slug else fallback\n    candidate = base_name\n    suffix = 2\n    while (artifact_root / candidate).exists():\n        candidate = f\"{base_name}_{suffix}\"\n        suffix += 1\n    return candidate\n\n\ndef run_job_process(\n    *,\n    event_queue,\n    recipe: dict[str, Any],\n    run: dict[str, Any],\n) -> None:\n    \"\"\"\n    Subprocess entrypoint.\n    Sends events to `event_queue`.\n    \"\"\"\n    import os\n\n    os.environ[\"PYTHONWARNINGS\"] = (\n        \"ignore\"  # Suppress warnings at C-level before imports\n    )\n\n    import warnings\n    from loggers.config import LogConfig\n\n    if os.getenv(\"ENVIRONMENT_TYPE\", \"production\") == \"production\":\n        warnings.filterwarnings(\"ignore\")\n\n    LogConfig.setup_logging(\n        service_name = \"unsloth-studio-data-worker\",\n        env = os.getenv(\"ENVIRONMENT_TYPE\", \"production\"),\n    )\n\n    event_queue.put({\"type\": EVENT_JOB_STARTED, \"ts\": time.time()})\n\n    try:\n        from data_designer.config.run_config import RunConfig\n\n        rows = int(run.get(\"rows\") or 1000)\n        job_id = str(run.get(\"_job_id\") or \"\").strip()\n        if not job_id:\n            job_id = f\"{int(time.time())}\"\n        run_name_raw = run.get(\"run_name\")\n        run_name = run_name_raw if isinstance(run_name_raw, str) else None\n        dataset_name = _build_dataset_name(\n            run_name = run_name,\n            job_id = job_id,\n            artifact_root = _ARTIFACT_ROOT,\n        )\n        merge_batches = bool(run.get(\"merge_batches\"))\n        ensure_dir(_ARTIFACT_ROOT)\n        run_config_raw = run.get(\"run_config\") or {}\n\n        builder = build_config_builder(recipe)\n        designer = 
create_data_designer(recipe, artifact_path = str(_ARTIFACT_ROOT))\n\n        # DataDesigner configures root logging in DataDesigner.__init__.\n        # Attach queue logger directly to `data_designer` so parser events survive root resets.\n        handler = _QueueLogHandler(event_queue)\n        handler.setLevel(logging.INFO)\n        data_designer_logger = logging.getLogger(\"data_designer\")\n        data_designer_logger.addHandler(handler)\n        data_designer_logger.setLevel(logging.INFO)\n        data_designer_logger.propagate = True\n\n        if run_config_raw:\n            designer.set_run_config(RunConfig.model_validate(run_config_raw))\n\n        execution_type = str(run.get(\"execution_type\") or \"full\").strip().lower()\n        if execution_type == \"preview\":\n            results = designer.preview(builder, num_records = rows)\n            analysis = (\n                None\n                if results.analysis is None\n                else to_jsonable(results.analysis.model_dump(mode = \"json\"))\n            )\n            dataset = (\n                []\n                if results.dataset is None\n                else to_preview_jsonable(results.dataset.to_dict(orient = \"records\"))\n            )\n            processor_artifacts = (\n                None\n                if results.processor_artifacts is None\n                else to_jsonable(results.processor_artifacts)\n            )\n            event_queue.put(\n                {\n                    \"type\": EVENT_JOB_COMPLETED,\n                    \"ts\": time.time(),\n                    \"analysis\": analysis,\n                    \"dataset\": dataset,\n                    \"processor_artifacts\": processor_artifacts,\n                    \"artifact_path\": None,\n                    \"execution_type\": execution_type,\n                }\n            )\n        else:\n            results = designer.create(\n                builder, num_records = rows, dataset_name = dataset_name\n            )\n            analysis = to_jsonable(results.load_analysis().model_dump(mode = \"json\"))\n            if merge_batches:\n                _merge_batches_to_single_parquet(\n                    results.artifact_storage.base_dataset_path\n                )\n            artifact_path = str(results.artifact_storage.base_dataset_path)\n            event_queue.put(\n                {\n                    \"type\": EVENT_JOB_COMPLETED,\n                    \"ts\": time.time(),\n                    \"analysis\": analysis,\n                    \"artifact_path\": artifact_path,\n                    \"execution_type\": execution_type,\n                }\n            )\n    except Exception as exc:\n        event_queue.put(\n            {\n                \"type\": EVENT_JOB_ERROR,\n                \"ts\": time.time(),\n                \"error\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n            }\n        )\n\n\ndef _merge_batches_to_single_parquet(base_dataset_path: Path) -> None:\n    parquet_dir = base_dataset_path / \"parquet-files\"\n    parquet_files = sorted(parquet_dir.glob(\"*.parquet\"))\n    if len(parquet_files) <= 1:\n        return\n\n    try:\n        from data_designer.config.utils.io_helpers import read_parquet_dataset\n    except ImportError:\n        return\n\n    dataframe = read_parquet_dataset(parquet_dir)\n    shutil.rmtree(parquet_dir)\n    parquet_dir.mkdir(parents = True, exist_ok = True)\n    merged_file = parquet_dir / \"batch_00000.parquet\"\n    
dataframe.to_parquet(merged_file, index = False)\n    _rewrite_merged_metadata(\n        base_dataset_path = base_dataset_path,\n        parquet_file = merged_file,\n    )\n\n\ndef _rewrite_merged_metadata(*, base_dataset_path: Path, parquet_file: Path) -> None:\n    metadata_path = base_dataset_path / \"metadata.json\"\n    if not metadata_path.exists():\n        return\n\n    try:\n        metadata = json.loads(metadata_path.read_text(encoding = \"utf-8\"))\n    except (OSError, TypeError, ValueError):\n        return\n\n    if not isinstance(metadata, dict):\n        return\n\n    relative_parquet_path = str(parquet_file.relative_to(base_dataset_path))\n    file_paths = metadata.get(\"file_paths\")\n    if not isinstance(file_paths, dict):\n        file_paths = {}\n    file_paths[\"parquet-files\"] = [relative_parquet_path]\n    metadata[\"file_paths\"] = file_paths\n    metadata[\"total_num_batches\"] = 1\n    metadata[\"num_completed_batches\"] = 1\n\n    try:\n        metadata_path.write_text(\n            json.dumps(metadata, indent = 2, sort_keys = True),\n            encoding = \"utf-8\",\n        )\n    except OSError:\n        return\n"
  },
  {
    "path": "studio/backend/core/data_recipe/jsonable.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport base64\nimport io\nfrom pathlib import Path\nfrom typing import Any\n\n\ndef _pil_to_preview_payload(image: Any) -> dict[str, Any]:\n    buffer = io.BytesIO()\n    image.convert(\"RGB\").save(buffer, format = \"JPEG\", quality = 85)\n    return {\n        \"type\": \"image\",\n        \"mime\": \"image/jpeg\",\n        \"width\": image.width,\n        \"height\": image.height,\n        \"data\": base64.b64encode(buffer.getvalue()).decode(\"ascii\"),\n    }\n\n\ndef _open_pil_image_from_bytes(raw_bytes: bytes):\n    from PIL import Image  # type: ignore\n\n    with Image.open(io.BytesIO(raw_bytes)) as image:\n        return image.copy()\n\n\ndef _to_pil_from_hf_image_dict(value: Any) -> Any | None:\n    if not isinstance(value, dict):\n        return None\n\n    raw_bytes = value.get(\"bytes\")\n    if isinstance(raw_bytes, (bytes, bytearray)) and len(raw_bytes) > 0:\n        try:\n            return _open_pil_image_from_bytes(bytes(raw_bytes))\n        except (OSError, ValueError):\n            pass\n    if (\n        isinstance(raw_bytes, list)\n        and len(raw_bytes) > 0\n        and all(isinstance(item, int) and 0 <= item <= 255 for item in raw_bytes)\n    ):\n        try:\n            return _open_pil_image_from_bytes(bytes(raw_bytes))\n        except (OSError, ValueError):\n            pass\n\n    path_value = value.get(\"path\")\n    if isinstance(path_value, str) and path_value.strip():\n        try:\n            from PIL import Image  # type: ignore\n\n            with Image.open(Path(path_value)) as image:\n                return image.copy()\n        except (OSError, ValueError, TypeError):\n            return None\n\n    return None\n\n\ndef to_jsonable(value: Any) -> Any:\n    \"\"\"Convert numpy/pandas-ish values into plain JSON-safe values.\"\"\"\n    try:\n        import numpy as np  # type: ignore\n    except ImportError:  # pragma: no cover\n        np = None  # type: ignore\n\n    if np is not None:\n        if isinstance(value, np.ndarray):\n            return value.tolist()\n        if isinstance(value, np.generic):\n            return value.item()\n\n    if isinstance(value, dict):\n        return {str(k): to_jsonable(v) for k, v in value.items()}\n    if isinstance(value, (list, tuple, set)):\n        return [to_jsonable(v) for v in value]\n\n    if hasattr(value, \"isoformat\") and callable(value.isoformat):\n        try:\n            return value.isoformat()\n        except (TypeError, ValueError):\n            return value\n\n    return value\n\n\ndef _to_preview_image_payload(value: Any) -> dict[str, Any] | None:\n    try:\n        from PIL.Image import Image as PILImage  # type: ignore\n    except ImportError:  # pragma: no cover\n        return None\n\n    if not isinstance(value, PILImage):\n        hf_image = _to_pil_from_hf_image_dict(value)\n        if hf_image is None:\n            return None\n        value = hf_image\n\n    return _pil_to_preview_payload(value)\n\n\ndef to_preview_jsonable(value: Any) -> Any:\n    \"\"\"Convert values into JSON-safe preview values, including PIL images.\"\"\"\n    image_payload = _to_preview_image_payload(value)\n    if image_payload is not None:\n        return image_payload\n\n    converted = to_jsonable(value)\n    if converted is None or isinstance(converted, (str, int, float, bool)):\n        return 
converted\n    if isinstance(converted, dict):\n        return {str(k): to_preview_jsonable(v) for k, v in converted.items()}\n    if isinstance(converted, (list, tuple, set)):\n        return [to_preview_jsonable(v) for v in converted]\n    if isinstance(converted, (bytes, bytearray)):\n        return base64.b64encode(bytes(converted)).decode(\"ascii\")\n    return str(converted)\n"
  },
  {
    "path": "studio/backend/core/data_recipe/local_callable_validators.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport structlog\nimport subprocess\nfrom copy import deepcopy\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom typing import Any\n\nfrom loggers import get_logger\nfrom utils.paths import ensure_dir, oxc_validator_tmp_root\n\nlogger = get_logger(__name__)\n\nOXC_VALIDATION_FN_MARKER = \"unsloth_oxc_validator\"\n\n_OXC_LANG_TO_NODE_LANG = {\n    \"javascript\": \"js\",\n    \"typescript\": \"ts\",\n    \"jsx\": \"jsx\",\n    \"tsx\": \"tsx\",\n}\n_OXC_VALIDATION_MODES = {\"syntax\", \"lint\", \"syntax+lint\"}\n_OXC_CODE_SHAPES = {\"auto\", \"module\", \"snippet\"}\n\n_OXC_TOOL_DIR = Path(__file__).resolve().parent / \"oxc-validator\"\n_OXC_RUNNER_PATH = _OXC_TOOL_DIR / \"validate.mjs\"\n\n\n@dataclass(frozen = True)\nclass OxcLocalCallableValidatorSpec:\n    name: str\n    drop: bool\n    target_columns: list[str]\n    batch_size: int\n    code_lang: str\n    validation_mode: str\n    code_shape: str\n\n\ndef split_oxc_local_callable_validators(\n    recipe_core: dict[str, Any],\n) -> tuple[dict[str, Any], list[OxcLocalCallableValidatorSpec]]:\n    columns = recipe_core.get(\"columns\")\n    if not isinstance(columns, list):\n        return recipe_core, []\n\n    sanitized = deepcopy(recipe_core)\n    sanitized_columns = sanitized.get(\"columns\")\n    if not isinstance(sanitized_columns, list):\n        return sanitized, []\n\n    kept_columns: list[Any] = []\n    oxc_specs: list[OxcLocalCallableValidatorSpec] = []\n\n    for column in sanitized_columns:\n        if not isinstance(column, dict):\n            kept_columns.append(column)\n            continue\n\n        maybe_spec = _parse_oxc_spec(column = column)\n        if maybe_spec is None:\n            kept_columns.append(column)\n            continue\n        oxc_specs.append(maybe_spec)\n\n    sanitized[\"columns\"] = kept_columns\n    return sanitized, oxc_specs\n\n\ndef register_oxc_local_callable_validators(\n    *,\n    builder,\n    specs: list[OxcLocalCallableValidatorSpec],\n) -> None:\n    if not specs:\n        return\n\n    from data_designer.config.column_configs import ValidationColumnConfig\n    from data_designer.config.validator_params import (\n        LocalCallableValidatorParams,\n        ValidatorType,\n    )\n\n    for spec in specs:\n        validation_function = _build_oxc_validation_function(\n            spec.code_lang,\n            spec.validation_mode,\n            spec.code_shape,\n        )\n        builder.add_column(\n            ValidationColumnConfig(\n                name = spec.name,\n                drop = spec.drop,\n                target_columns = spec.target_columns,\n                validator_type = ValidatorType.LOCAL_CALLABLE,\n                validator_params = LocalCallableValidatorParams(\n                    validation_function = validation_function,\n                ),\n                batch_size = spec.batch_size,\n            )\n        )\n\n\ndef _parse_oxc_spec(\n    *,\n    column: dict[str, Any],\n) -> OxcLocalCallableValidatorSpec | None:\n    if str(column.get(\"column_type\") or \"\").strip() != \"validation\":\n        return None\n    if str(column.get(\"validator_type\") or \"\").strip() != \"local_callable\":\n        return None\n\n    params = column.get(\"validator_params\")\n    if not 
isinstance(params, dict):\n        return None\n\n    fn_raw = params.get(\"validation_function\")\n    fn_name = fn_raw.strip() if isinstance(fn_raw, str) else \"\"\n    if not fn_name.startswith(OXC_VALIDATION_FN_MARKER):\n        return None\n\n    name = str(column.get(\"name\") or \"\").strip()\n    if not name:\n        return None\n\n    target_columns_raw = column.get(\"target_columns\")\n    target_columns = (\n        [\n            value.strip()\n            for value in target_columns_raw\n            if isinstance(value, str) and value.strip()\n        ]\n        if isinstance(target_columns_raw, list)\n        else []\n    )\n    if not target_columns:\n        return None\n\n    code_lang, validation_mode, code_shape = _parse_oxc_validation_marker(fn_name)\n    batch_size = _parse_batch_size(column.get(\"batch_size\"))\n    drop = bool(column.get(\"drop\") is True)\n\n    return OxcLocalCallableValidatorSpec(\n        name = name,\n        drop = drop,\n        target_columns = target_columns,\n        batch_size = batch_size,\n        code_lang = code_lang,\n        validation_mode = validation_mode,\n        code_shape = code_shape,\n    )\n\n\ndef _parse_batch_size(value: Any) -> int:\n    try:\n        parsed = int(value)\n    except (TypeError, ValueError):\n        return 10\n    return parsed if parsed >= 1 else 10\n\n\ndef _parse_oxc_validation_marker(fn_name: str) -> tuple[str, str, str]:\n    marker = f\"{OXC_VALIDATION_FN_MARKER}:\"\n    if not fn_name.startswith(marker):\n        return \"javascript\", \"syntax\", \"auto\"\n    suffix = fn_name[len(marker) :]\n    parts = [part.strip() for part in suffix.split(\":\") if part.strip()]\n    if len(parts) < 2:\n        return \"javascript\", \"syntax\", \"auto\"\n    code_lang = parts[0] if parts[0] in _OXC_LANG_TO_NODE_LANG else \"javascript\"\n    mode = parts[1] if parts[1] in _OXC_VALIDATION_MODES else \"syntax\"\n    code_shape = (\n        parts[2] if len(parts) >= 3 and parts[2] in _OXC_CODE_SHAPES else \"auto\"\n    )\n    return code_lang, mode, code_shape\n\n\n@lru_cache(maxsize = 8)\ndef _build_oxc_validation_function(lang: str, validation_mode: str, code_shape: str):\n    node_lang = _OXC_LANG_TO_NODE_LANG.get(lang, \"js\")\n    mode = validation_mode if validation_mode in _OXC_VALIDATION_MODES else \"syntax\"\n    normalized_code_shape = code_shape if code_shape in _OXC_CODE_SHAPES else \"auto\"\n\n    def _validator(df):\n        import pandas as pd  # imported lazily for local callable runtime\n\n        row_count = int(len(df.index))\n        if row_count == 0:\n            return pd.DataFrame({\"is_valid\": []})\n\n        code_column = str(df.columns[0]) if len(df.columns) > 0 else \"\"\n        code_values = (\n            [\"\" for _ in range(row_count)]\n            if not code_column\n            else [\n                \"\" if value is None else str(value)\n                for value in df[code_column].tolist()\n            ]\n        )\n\n        results = _run_oxc_batch(\n            node_lang = node_lang,\n            validation_mode = mode,\n            code_shape = normalized_code_shape,\n            code_values = code_values,\n        )\n        if len(results) != row_count:\n            results = _fallback_results(\n                row_count,\n                \"OXC validator returned mismatched result size.\",\n            )\n        return pd.DataFrame(results)\n\n    _validator.__name__ = f\"{OXC_VALIDATION_FN_MARKER}_{node_lang}_{mode.replace('+', '_')}_{normalized_code_shape}\"\n    
return _validator\n\n\ndef _run_oxc_batch(\n    *,\n    node_lang: str,\n    validation_mode: str,\n    code_shape: str,\n    code_values: list[str],\n) -> list[dict[str, Any]]:\n    if not _OXC_RUNNER_PATH.exists():\n        return _fallback_results(\n            len(code_values),\n            f\"OXC runner missing at {_OXC_RUNNER_PATH}\",\n        )\n\n    payload = {\n        \"lang\": node_lang,\n        \"mode\": validation_mode,\n        \"code_shape\": code_shape,\n        \"codes\": code_values,\n    }\n    try:\n        tmp_dir = ensure_dir(oxc_validator_tmp_root())\n        env = dict(os.environ)\n        tmp_dir_str = str(tmp_dir)\n        env[\"TMPDIR\"] = tmp_dir_str\n        env[\"TMP\"] = tmp_dir_str\n        env[\"TEMP\"] = tmp_dir_str\n        proc = subprocess.run(\n            [\"node\", str(_OXC_RUNNER_PATH)],\n            cwd = str(_OXC_TOOL_DIR),\n            input = json.dumps(payload),\n            text = True,\n            capture_output = True,\n            check = False,\n            env = env,\n        )\n    except (OSError, ValueError) as exc:\n        logger.warning(\"OXC subprocess launch failed: %s\", exc)\n        return _fallback_results(len(code_values), f\"OXC launch failed: {exc}\")\n\n    if proc.returncode != 0:\n        message = (proc.stderr or proc.stdout or \"unknown error\").strip()\n        if len(message) > 300:\n            message = f\"{message[:300]}...\"\n        return _fallback_results(len(code_values), f\"OXC failed: {message}\")\n\n    try:\n        raw = json.loads(proc.stdout)\n    except json.JSONDecodeError:\n        return _fallback_results(len(code_values), \"OXC output parse failed.\")\n\n    if not isinstance(raw, list):\n        return _fallback_results(len(code_values), \"OXC output must be an array.\")\n\n    out: list[dict[str, Any]] = []\n    for item in raw:\n        if not isinstance(item, dict):\n            out.append(\n                {\n                    \"is_valid\": False,\n                    \"error_count\": 1,\n                    \"error_message\": \"Invalid OXC result entry.\",\n                    \"severity\": None,\n                    \"code\": None,\n                    \"labels\": [],\n                    \"codeframe\": None,\n                    \"warning_count\": 0,\n                }\n            )\n            continue\n        is_valid_raw = item.get(\"is_valid\")\n        error_count_raw = item.get(\"error_count\")\n        message_raw = item.get(\"error_message\")\n        severity_raw = item.get(\"severity\")\n        code_raw = item.get(\"code\")\n        labels_raw = item.get(\"labels\")\n        codeframe_raw = item.get(\"codeframe\")\n        warning_count_raw = item.get(\"warning_count\")\n        out.append(\n            {\n                \"is_valid\": bool(is_valid_raw)\n                if isinstance(is_valid_raw, bool)\n                else False,\n                \"error_count\": int(error_count_raw)\n                if isinstance(error_count_raw, int)\n                else 0,\n                \"error_message\": str(message_raw or \"\"),\n                \"severity\": str(severity_raw)\n                if isinstance(severity_raw, str)\n                else None,\n                \"code\": str(code_raw) if isinstance(code_raw, str) else None,\n                \"labels\": labels_raw if isinstance(labels_raw, list) else [],\n                \"codeframe\": str(codeframe_raw)\n                if isinstance(codeframe_raw, str)\n                else None,\n                \"warning_count\": 
int(warning_count_raw)\n                if isinstance(warning_count_raw, int)\n                else 0,\n            }\n        )\n    return out\n\n\ndef _fallback_results(row_count: int, message: str) -> list[dict[str, Any]]:\n    return [\n        {\n            \"is_valid\": False,\n            \"error_count\": 1,\n            \"error_message\": message,\n            \"severity\": None,\n            \"code\": None,\n            \"labels\": [],\n            \"codeframe\": None,\n            \"warning_count\": 0,\n        }\n        for _ in range(row_count)\n    ]\n"
  },
  {
    "path": "studio/backend/core/data_recipe/oxc-validator/package.json",
    "content": "{\n  \"name\": \"unsloth-oxc-validator-runtime\",\n  \"private\": true,\n  \"version\": \"0.0.1\",\n  \"type\": \"module\",\n  \"dependencies\": {\n    \"oxc-parser\": \"^0.116.0\",\n    \"oxlint\": \"^1.51.0\"\n  }\n}\n"
  },
  {
    "path": "studio/backend/core/data_recipe/oxc-validator/validate.mjs",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { spawnSync } from \"node:child_process\";\nimport { mkdtempSync, rmSync, writeFileSync } from \"node:fs\";\nimport { tmpdir } from \"node:os\";\nimport { basename, dirname, join } from \"node:path\";\nimport { fileURLToPath } from \"node:url\";\nimport { parseSync } from \"oxc-parser\";\n\nconst LANG_TO_EXT = {\n  js: \"js\",\n  jsx: \"jsx\",\n  ts: \"ts\",\n  tsx: \"tsx\",\n};\n\nconst VALIDATION_MODES = new Set([\"syntax\", \"lint\", \"syntax+lint\"]);\nconst CODE_SHAPES = new Set([\"auto\", \"module\", \"snippet\"]);\nconst SNIPPET_PREFIX = \"(() => {\\n\";\nconst SNIPPET_SUFFIX = \"\\n})();\\nexport {};\\n\";\nconst OXLINT_SUPPRESSED_RULES = [\"no-unused-vars\", \"no-new-array\"];\nconst TOOL_DIR = dirname(fileURLToPath(import.meta.url));\n\nfunction mapLang(value) {\n  const normalized = String(value || \"\").trim().toLowerCase();\n  if (normalized === \"javascript\" || normalized === \"js\") {\n    return \"js\";\n  }\n  if (normalized === \"typescript\" || normalized === \"ts\") {\n    return \"ts\";\n  }\n  if (normalized === \"jsx\") {\n    return \"jsx\";\n  }\n  if (normalized === \"tsx\") {\n    return \"tsx\";\n  }\n  return \"js\";\n}\n\nfunction mapMode(value) {\n  const normalized = String(value || \"\").trim().toLowerCase();\n  if (VALIDATION_MODES.has(normalized)) {\n    return normalized;\n  }\n  return \"syntax\";\n}\n\nfunction mapCodeShape(value) {\n  const normalized = String(value || \"\").trim().toLowerCase();\n  if (CODE_SHAPES.has(normalized)) {\n    return normalized;\n  }\n  return \"auto\";\n}\n\nfunction parseFileIndex(filePath) {\n  if (typeof filePath !== \"string\") {\n    return null;\n  }\n  const match = basename(filePath).match(/^snippet_(\\d+)\\./);\n  if (!match) {\n    return null;\n  }\n  const parsed = Number.parseInt(match[1], 10);\n  return Number.isFinite(parsed) ? parsed : null;\n}\n\nfunction toCodeString(code) {\n  return typeof code === \"string\" ? code : String(code ?? \"\");\n}\n\nfunction makeValidationEntry({ code, index, lang, codeShape }) {\n  const source = toCodeString(code);\n  if (codeShape === \"snippet\") {\n    return {\n      index,\n      lang,\n      code: `${SNIPPET_PREFIX}${source}${SNIPPET_SUFFIX}`,\n      offset: SNIPPET_PREFIX.length,\n    };\n  }\n  return {\n    index,\n    lang,\n    code: source,\n    offset: 0,\n  };\n}\n\nfunction shiftOffset(value, offset) {\n  if (!Number.isInteger(value)) {\n    return null;\n  }\n  const shifted = value - offset;\n  return shifted >= 0 ? shifted : null;\n}\n\nfunction remapDiagnosticOffsets(diagnostic, offset) {\n  if (!diagnostic || typeof diagnostic !== \"object\" || offset <= 0) {\n    return diagnostic;\n  }\n  return {\n    ...diagnostic,\n    labels: Array.isArray(diagnostic.labels)\n      ? 
diagnostic.labels.map((label) => ({\n          ...label,\n          start: shiftOffset(label.start, offset),\n          end: shiftOffset(label.end, offset),\n        }))\n      : [],\n  };\n}\n\nfunction normalizeParserError(error) {\n  if (typeof error === \"string\") {\n    return {\n      code: null,\n      message: error.trim() || \"Unknown parser error\",\n      severity: null,\n      labels: [],\n      codeframe: null,\n    };\n  }\n  if (!error || typeof error !== \"object\") {\n    return {\n      code: null,\n      message: \"Unknown parser error\",\n      severity: null,\n      labels: [],\n      codeframe: null,\n    };\n  }\n  const code = typeof error.code === \"string\" ? error.code : null;\n  const message = String(error.message || error.reason || \"\").trim() || \"Unknown parser error\";\n  const severity = typeof error.severity === \"string\" ? error.severity : null;\n  const labels = Array.isArray(error.labels)\n    ? error.labels.map((label) => ({\n        message:\n          label && typeof label === \"object\" && typeof label.message === \"string\"\n            ? label.message\n            : null,\n        start:\n          label && typeof label === \"object\" && Number.isInteger(label.start)\n            ? label.start\n            : null,\n        end:\n          label && typeof label === \"object\" && Number.isInteger(label.end)\n            ? label.end\n            : null,\n      }))\n    : [];\n  const codeframe = typeof error.codeframe === \"string\" ? error.codeframe : null;\n  return {\n    code,\n    message,\n    severity,\n    labels,\n    codeframe,\n  };\n}\n\nfunction normalizeLintDiagnostic(diagnostic) {\n  if (!diagnostic || typeof diagnostic !== \"object\") {\n    return null;\n  }\n\n  const readString = (value) =>\n    typeof value === \"string\" ? value : null;\n  const readInt = (value) =>\n    Number.isInteger(value) ? value : null;\n  const asObject = (value) =>\n    value && typeof value === \"object\" ? value : null;\n\n  const message = String(diagnostic.message || \"\").trim();\n  if (!message) {\n    return null;\n  }\n\n  const severityRaw = String(diagnostic.severity || \"\").trim().toLowerCase();\n  const severity = severityRaw === \"error\" ? \"error\" : \"warning\";\n\n  const labels = [];\n  if (Array.isArray(diagnostic.labels)) {\n    for (const label of diagnostic.labels) {\n      const labelObj = asObject(label);\n      const span = asObject(labelObj?.span);\n      const start = readInt(span?.offset);\n      const length = readInt(span?.length);\n      labels.push({\n        message: readString(labelObj?.label),\n        start,\n        end: start !== null && length !== null ? start + length : null,\n      });\n    }\n  }\n\n  const code = typeof diagnostic.code === \"string\" ? diagnostic.code : null;\n  return {\n    code,\n    message: code ? `${code}: ${message}` : message,\n    severity,\n    labels,\n    codeframe: null,\n  };\n}\n\nfunction makeResult({\n  isValid,\n  errorCount,\n  warningCount = 0,\n  message = \"\",\n  severity = null,\n  code = null,\n  labels = [],\n  codeframe = null,\n}) {\n  return {\n    is_valid: Boolean(isValid),\n    error_count: Number.isInteger(errorCount) ? errorCount : 0,\n    warning_count: Number.isInteger(warningCount) ? warningCount : 0,\n    error_message: String(message || \"\"),\n    severity: typeof severity === \"string\" ? severity : null,\n    code: typeof code === \"string\" ? code : null,\n    labels: Array.isArray(labels) ? 
labels : [],\n    codeframe: typeof codeframe === \"string\" ? codeframe : null,\n  };\n}\n\nfunction syntaxResultFromErrors(errors) {\n  const first = errors[0] ?? null;\n  return makeResult({\n    isValid: errors.length === 0,\n    errorCount: errors.length,\n    warningCount: 0,\n    message: errors.slice(0, 3).map((error) => error.message).join(\" | \"),\n    severity: first ? first.severity : null,\n    code: first ? first.code : null,\n    labels: first ? first.labels : [],\n    codeframe: first ? first.codeframe : null,\n  });\n}\n\nfunction runSyntaxParse(entry) {\n  const ext = LANG_TO_EXT[entry.lang] ?? \"js\";\n  const filename = `snippet_${entry.index}.${ext}`;\n  try {\n    const parsed = parseSync(filename, entry.code, {\n      lang: entry.lang,\n      sourceType: \"module\",\n      showSemanticErrors: true,\n    });\n    const errors = Array.isArray(parsed?.errors)\n      ? parsed.errors\n          .map(normalizeParserError)\n          .filter(Boolean)\n          .map((error) => remapDiagnosticOffsets(error, entry.offset))\n      : [];\n    return errors;\n  } catch (error) {\n    return [\n      remapDiagnosticOffsets(\n        normalizeParserError(error),\n        entry.offset,\n      ),\n    ];\n  }\n}\n\nfunction pickPreferredErrorList(firstErrors, secondErrors) {\n  if (secondErrors.length < firstErrors.length) {\n    return secondErrors;\n  }\n  return firstErrors;\n}\n\nfunction validateSyntaxOne({ code, lang, index, codeShape }) {\n  if (codeShape !== \"auto\") {\n    const lintEntry = makeValidationEntry({\n      code,\n      index,\n      lang,\n      codeShape,\n    });\n    const errors = runSyntaxParse(lintEntry);\n    return {\n      result: syntaxResultFromErrors(errors),\n      lintEntry,\n    };\n  }\n\n  const moduleEntry = makeValidationEntry({\n    code,\n    index,\n    lang,\n    codeShape: \"module\",\n  });\n  const moduleErrors = runSyntaxParse(moduleEntry);\n  if (moduleErrors.length === 0) {\n    return {\n      result: syntaxResultFromErrors(moduleErrors),\n      lintEntry: moduleEntry,\n    };\n  }\n\n  const snippetEntry = makeValidationEntry({\n    code,\n    index,\n    lang,\n    codeShape: \"snippet\",\n  });\n  const snippetErrors = runSyntaxParse(snippetEntry);\n  if (snippetErrors.length === 0) {\n    return {\n      result: syntaxResultFromErrors(snippetErrors),\n      lintEntry: snippetEntry,\n    };\n  }\n\n  const chosenErrors = pickPreferredErrorList(moduleErrors, snippetErrors);\n  const lintEntry = chosenErrors === snippetErrors ? 
snippetEntry : moduleEntry;\n  return {\n    result: syntaxResultFromErrors(chosenErrors),\n    lintEntry,\n  };\n}\n\nfunction resolveLintEntry({ code, lang, index, codeShape }) {\n  if (codeShape !== \"auto\") {\n    return makeValidationEntry({\n      code,\n      index,\n      lang,\n      codeShape,\n    });\n  }\n\n  const moduleEntry = makeValidationEntry({\n    code,\n    index,\n    lang,\n    codeShape: \"module\",\n  });\n  if (runSyntaxParse(moduleEntry).length === 0) {\n    return moduleEntry;\n  }\n\n  const snippetEntry = makeValidationEntry({\n    code,\n    index,\n    lang,\n    codeShape: \"snippet\",\n  });\n  if (runSyntaxParse(snippetEntry).length === 0) {\n    return snippetEntry;\n  }\n\n  return moduleEntry;\n}\n\nfunction fallbackLintResults(entries, message) {\n  return new Map(\n    entries.map((entry) => [\n      entry.index,\n      makeResult({\n        isValid: false,\n        errorCount: 1,\n        warningCount: 0,\n        message,\n        severity: \"error\",\n      }),\n    ]),\n  );\n}\n\nfunction runLintBatch(entries) {\n  if (entries.length === 0) {\n    return new Map();\n  }\n\n  const entryByIndex = new Map(entries.map((entry) => [entry.index, entry]));\n  const tempDir = mkdtempSync(join(tmpdir(), \"oxlint-\"));\n  try {\n    for (const entry of entries) {\n      const ext = LANG_TO_EXT[entry.lang] ?? \"js\";\n      const filePath = join(tempDir, `snippet_${entry.index}.${ext}`);\n      writeFileSync(filePath, entry.code, \"utf8\");\n    }\n\n    const oxlintBin = join(TOOL_DIR, \"node_modules\", \".bin\", \"oxlint\");\n    const oxlintArgs = [\n      ...OXLINT_SUPPRESSED_RULES.flatMap((rule) => [\"-A\", rule]),\n      \"--format\",\n      \"json\",\n      tempDir,\n    ];\n    const exec = spawnSync(oxlintBin, oxlintArgs, {\n      encoding: \"utf8\",\n      cwd: TOOL_DIR,\n    });\n    if (exec.error) {\n      return fallbackLintResults(\n        entries,\n        `oxlint execution failed: ${exec.error.message}`,\n      );\n    }\n    const stdout = String(exec.stdout || \"\").trim();\n    if (!stdout) {\n      const stderr = String(exec.stderr || \"\").trim();\n      return fallbackLintResults(\n        entries,\n        stderr || \"oxlint returned empty output\",\n      );\n    }\n\n    let parsed;\n    try {\n      parsed = JSON.parse(stdout);\n    } catch {\n      return fallbackLintResults(entries, \"oxlint JSON parse failed\");\n    }\n\n    const rawDiagnostics = Array.isArray(parsed?.diagnostics)\n      ? parsed.diagnostics\n      : [];\n    const byIndex = new Map();\n\n    for (const diag of rawDiagnostics) {\n      const filenameRaw =\n        typeof diag?.filename === \"string\" ? diag.filename : \"\";\n      const filename = filenameRaw.startsWith(\"file://\")\n        ? filenameRaw.replace(\"file://\", \"\")\n        : filenameRaw;\n      const index = parseFileIndex(filename);\n      if (index === null) {\n        continue;\n      }\n      const normalized = normalizeLintDiagnostic(diag);\n      if (!normalized) {\n        continue;\n      }\n      const entry = entryByIndex.get(index);\n      const remapped = remapDiagnosticOffsets(normalized, entry?.offset ?? 0);\n      const list = byIndex.get(index) ?? [];\n      list.push(remapped);\n      byIndex.set(index, list);\n    }\n\n    const results = new Map();\n    for (const entry of entries) {\n      const diagnostics = byIndex.get(entry.index) ?? 
[];\n      const errorDiagnostics = diagnostics.filter(\n        (diag) => diag.severity === \"error\",\n      );\n      const warningDiagnostics = diagnostics.filter(\n        (diag) => diag.severity !== \"error\",\n      );\n      const top = errorDiagnostics[0] ?? warningDiagnostics[0] ?? null;\n      const messageSource =\n        errorDiagnostics.length > 0 ? errorDiagnostics : warningDiagnostics;\n      results.set(\n        entry.index,\n        makeResult({\n          isValid: errorDiagnostics.length === 0,\n          errorCount: errorDiagnostics.length,\n          warningCount: warningDiagnostics.length,\n          message: messageSource\n            .slice(0, 3)\n            .map((diag) => diag.message)\n            .join(\" | \"),\n          severity: top ? top.severity : null,\n          code: top ? top.code : null,\n          labels: top ? top.labels : [],\n          codeframe: top ? top.codeframe : null,\n        }),\n      );\n    }\n    return results;\n  } catch (error) {\n    return fallbackLintResults(entries, `oxlint execution failed: ${error}`);\n  } finally {\n    rmSync(tempDir, { recursive: true, force: true });\n  }\n}\n\nfunction readStdin() {\n  return new Promise((resolve, reject) => {\n    let data = \"\";\n    process.stdin.setEncoding(\"utf8\");\n    process.stdin.on(\"data\", (chunk) => {\n      data += chunk;\n    });\n    process.stdin.on(\"end\", () => resolve(data));\n    process.stdin.on(\"error\", (error) => reject(error));\n  });\n}\n\nfunction runValidation({ codes, lang, mode, codeShape }) {\n  if (mode === \"syntax\") {\n    return codes.map((code, index) =>\n      validateSyntaxOne({ code, lang, index, codeShape }).result,\n    );\n  }\n\n  if (mode === \"lint\") {\n    const entries = codes.map((code, index) =>\n      resolveLintEntry({ code, lang, index, codeShape }),\n    );\n    const lintMap = runLintBatch(entries);\n    return entries.map(\n      (entry) =>\n        lintMap.get(entry.index) ??\n        makeResult({\n          isValid: true,\n          errorCount: 0,\n          warningCount: 0,\n        }),\n    );\n  }\n\n  const syntaxRuns = codes.map((code, index) =>\n    validateSyntaxOne({ code, lang, index, codeShape }),\n  );\n  const lintTargets = syntaxRuns\n    .filter((run) => run.result.is_valid === true)\n    .map((run) => run.lintEntry);\n  const lintMap = runLintBatch(lintTargets);\n\n  return syntaxRuns.map((run) => {\n    if (run.result.is_valid !== true) {\n      return run.result;\n    }\n    return (\n      lintMap.get(run.lintEntry.index) ??\n      makeResult({\n        isValid: true,\n        errorCount: 0,\n        warningCount: 0,\n      })\n    );\n  });\n}\n\nasync function main() {\n  const raw = await readStdin();\n  let payload;\n  try {\n    payload = JSON.parse(raw || \"{}\");\n  } catch {\n    process.stdout.write(\n      JSON.stringify([\n        makeResult({\n          isValid: false,\n          errorCount: 1,\n          warningCount: 0,\n          message: \"Invalid JSON payload\",\n          severity: \"error\",\n        }),\n      ]),\n    );\n    return;\n  }\n\n  const lang = mapLang(payload?.lang);\n  const mode = mapMode(payload?.mode);\n  const codeShape = mapCodeShape(payload?.code_shape);\n  const codes = Array.isArray(payload?.codes) ? payload.codes : [];\n  const out = runValidation({ codes, lang, mode, codeShape });\n  process.stdout.write(JSON.stringify(out));\n}\n\nmain().catch((error) => {\n  process.stderr.write(String(error?.stack || error));\n  process.exit(1);\n});\n"
  },
  {
    "path": "studio/backend/core/data_recipe/service.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport base64\nimport io\nimport os\nfrom pathlib import Path\nfrom typing import Any\n\nfrom .jsonable import to_jsonable\nfrom .local_callable_validators import (\n    register_oxc_local_callable_validators,\n    split_oxc_local_callable_validators,\n)\n\n_IMAGE_CONTEXT_PATCHED = False\n\n\ndef _encode_bytes_to_base64(value: bytes | bytearray) -> str:\n    return base64.b64encode(bytes(value)).decode(\"utf-8\")\n\n\ndef _load_image_file_to_base64(\n    path_value: str, *, base_path: str | None = None\n) -> str | None:\n    try:\n        path = Path(path_value)\n        candidates: list[Path] = []\n        if path.is_absolute():\n            candidates.append(path)\n        else:\n            if base_path:\n                candidates.append(Path(base_path) / path)\n            candidates.append(Path.cwd() / path)\n\n        for candidate in candidates:\n            if not candidate.exists() or not candidate.is_file():\n                continue\n            with candidate.open(\"rb\") as f:\n                return _encode_bytes_to_base64(f.read())\n    except (OSError, TypeError, ValueError):\n        return None\n    return None\n\n\ndef _pil_image_to_base64(value: Any) -> str | None:\n    try:\n        from PIL.Image import Image as PILImage  # type: ignore\n    except ImportError:\n        return None\n    if not isinstance(value, PILImage):\n        return None\n    buffer = io.BytesIO()\n    image_format = str(getattr(value, \"format\", \"\") or \"\").upper()\n    if image_format not in {\"PNG\", \"JPEG\", \"JPG\", \"WEBP\", \"GIF\"}:\n        image_format = \"PNG\"\n    value.save(buffer, format = image_format)\n    return _encode_bytes_to_base64(buffer.getvalue())\n\n\ndef _normalize_image_context_value(value: Any, *, base_path: str | None = None) -> Any:\n    if isinstance(value, str):\n        return value\n\n    if isinstance(value, (bytes, bytearray)):\n        return _encode_bytes_to_base64(value)\n\n    pil_base64 = _pil_image_to_base64(value)\n    if pil_base64 is not None:\n        return pil_base64\n\n    if isinstance(value, dict):\n        url = value.get(\"url\")\n        if isinstance(url, str):\n            return url\n\n        image_url = value.get(\"image_url\")\n        if isinstance(image_url, str):\n            return image_url\n        if isinstance(image_url, dict):\n            nested_url = image_url.get(\"url\")\n            if isinstance(nested_url, str):\n                return nested_url\n\n        inline_data = value.get(\"data\")\n        if isinstance(inline_data, str):\n            return inline_data\n\n        raw_bytes = value.get(\"bytes\")\n        if isinstance(raw_bytes, (bytes, bytearray)):\n            return _encode_bytes_to_base64(raw_bytes)\n        if isinstance(raw_bytes, str) and raw_bytes.strip():\n            return raw_bytes\n\n        path_value = value.get(\"path\")\n        if isinstance(path_value, str) and path_value.strip():\n            if as_base64 := _load_image_file_to_base64(path_value, base_path = base_path):\n                return as_base64\n            return path_value\n\n    return value\n\n\ndef _apply_data_designer_image_context_patch() -> None:\n    global _IMAGE_CONTEXT_PATCHED\n    if _IMAGE_CONTEXT_PATCHED:\n        return\n\n    try:\n        from data_designer.config.models import ImageContext\n    except 
ImportError:\n        return\n\n    if getattr(ImageContext, \"_unsloth_image_context_patch_applied\", False):\n        _IMAGE_CONTEXT_PATCHED = True\n        return\n\n    original_auto_resolve = ImageContext._auto_resolve_context_value\n\n    def _patched_auto_resolve(\n        self: Any, context_value: Any, base_path: str | None\n    ) -> Any:\n        normalized = _normalize_image_context_value(context_value, base_path = base_path)\n        return original_auto_resolve(self, normalized, base_path)\n\n    ImageContext._auto_resolve_context_value = _patched_auto_resolve\n    setattr(ImageContext, \"_unsloth_image_context_patch_applied\", True)\n    _IMAGE_CONTEXT_PATCHED = True\n\n\ndef build_model_providers(recipe: dict[str, Any]):\n    from data_designer.config.models import ModelProvider\n\n    providers: list[ModelProvider] = []\n    for provider in recipe.get(\"model_providers\", []):\n        api_key = provider.get(\"api_key\")\n        api_key_env = provider.get(\"api_key_env\")\n        if not api_key and api_key_env:\n            api_key = os.getenv(api_key_env)\n        providers.append(\n            ModelProvider(\n                name = provider[\"name\"],\n                endpoint = provider[\"endpoint\"],\n                provider_type = provider.get(\"provider_type\", \"openai\"),\n                api_key = api_key,\n                extra_headers = provider.get(\"extra_headers\"),\n                extra_body = provider.get(\"extra_body\"),\n            )\n        )\n\n    return providers\n\n\ndef _recipe_has_llm_columns(recipe: dict[str, Any]) -> bool:\n    for column in recipe.get(\"columns\", []):\n        if not isinstance(column, dict):\n            continue\n        column_type = column.get(\"column_type\")\n        if isinstance(column_type, str) and column_type.startswith(\"llm-\"):\n            return True\n    return False\n\n\ndef _validate_recipe_runtime_support(\n    recipe: dict[str, Any],\n    model_providers: list[Any],\n) -> None:\n    if not _recipe_has_llm_columns(recipe):\n        raise ValueError(\n            \"Recipe Studio currently requires at least one AI generation step.\"\n        )\n\n    if not model_providers:\n        raise ValueError(\"Add a Provider connection block before running this recipe.\")\n\n\ndef build_mcp_providers(\n    recipe: dict[str, Any],\n) -> list:\n    from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider\n\n    providers: list[MCPProvider | LocalStdioMCPProvider] = []\n    for provider in recipe.get(\"mcp_providers\", []):\n        if not isinstance(provider, dict):\n            continue\n        provider_type = provider.get(\"provider_type\")\n        if provider_type == \"stdio\":\n            env = provider.get(\"env\")\n            if not isinstance(env, dict):\n                env = {}\n            args = provider.get(\"args\")\n            if not isinstance(args, list):\n                args = []\n            providers.append(\n                LocalStdioMCPProvider(\n                    name = str(provider.get(\"name\", \"\")),\n                    command = str(provider.get(\"command\", \"\")),\n                    args = [str(value) for value in args],\n                    env = {str(key): str(value) for key, value in env.items()},\n                )\n            )\n            continue\n\n        if provider_type in {\"sse\", \"streamable_http\"}:\n            api_key = provider.get(\"api_key\")\n            api_key_env = provider.get(\"api_key_env\")\n            if not api_key and 
api_key_env:\n                api_key = os.getenv(str(api_key_env))\n            providers.append(\n                MCPProvider(\n                    name = str(provider.get(\"name\", \"\")),\n                    endpoint = str(provider.get(\"endpoint\", \"\")),\n                    provider_type = str(provider_type),\n                    api_key = str(api_key) if api_key else None,\n                )\n            )\n    return providers\n\n\ndef build_config_builder(recipe: dict[str, Any]):\n    _apply_data_designer_image_context_patch()\n    from data_designer.config import DataDesignerConfigBuilder\n    from data_designer.config.processors import ProcessorType\n\n    recipe_core = {\n        key: value\n        for key, value in recipe.items()\n        if key not in {\"model_providers\", \"mcp_providers\"}\n    }\n    recipe_core, oxc_local_callable_specs = split_oxc_local_callable_validators(\n        recipe_core\n    )\n    builder = DataDesignerConfigBuilder.from_config({\"data_designer\": recipe_core})\n    register_oxc_local_callable_validators(\n        builder = builder,\n        specs = oxc_local_callable_specs,\n    )\n\n    # DataDesignerConfigBuilder.from_config currently skips processors.\n    # Re-attach explicitly so drop_columns/schema_transform survive API payload.\n    for processor in recipe_core.get(\"processors\") or []:\n        if not isinstance(processor, dict):\n            continue\n        processor_type_raw = processor.get(\"processor_type\")\n        if not isinstance(processor_type_raw, str):\n            continue\n        kwargs = {k: v for k, v in processor.items() if k != \"processor_type\"}\n        builder.add_processor(\n            processor_type = ProcessorType(processor_type_raw),\n            **kwargs,\n        )\n\n    return builder\n\n\ndef create_data_designer(\n    recipe: dict[str, Any],\n    *,\n    artifact_path: str | None = None,\n):\n    _apply_data_designer_image_context_patch()\n    from data_designer.interface.data_designer import DataDesigner\n\n    model_providers = build_model_providers(recipe)\n    _validate_recipe_runtime_support(recipe, model_providers)\n\n    return DataDesigner(\n        artifact_path = artifact_path,\n        model_providers = model_providers,\n        mcp_providers = build_mcp_providers(recipe),\n    )\n\n\ndef validate_recipe(recipe: dict[str, Any]) -> None:\n    builder = build_config_builder(recipe)\n    designer = create_data_designer(recipe)\n    designer.validate(builder)\n\n\ndef preview_recipe(\n    recipe: dict[str, Any],\n    num_records: int,\n) -> tuple[list[dict[str, Any]], dict[str, Any] | None, dict[str, Any] | None]:\n    builder = build_config_builder(recipe)\n    designer = create_data_designer(recipe)\n    results = designer.preview(builder, num_records = num_records)\n\n    dataset: list[dict[str, Any]] = []\n    if results.dataset is not None:\n        raw_rows = results.dataset.to_dict(orient = \"records\")\n        dataset = [to_jsonable(row) for row in raw_rows]\n\n    artifacts = (\n        None\n        if results.processor_artifacts is None\n        else to_jsonable(results.processor_artifacts)\n    )\n    analysis = (\n        None\n        if results.analysis is None\n        else to_jsonable(results.analysis.model_dump(mode = \"json\"))\n    )\n\n    return dataset, artifacts, analysis\n"
  },
  {
    "path": "studio/backend/core/export/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nExport submodule - Model export operations\n\nThe default get_export_backend() returns an ExportOrchestrator that\ndelegates to a subprocess. The original ExportBackend runs inside\nthe subprocess and can be imported directly from .export when needed.\n\"\"\"\n\nfrom .orchestrator import ExportOrchestrator, get_export_backend\n\n# Expose ExportOrchestrator as ExportBackend for backward compat\nExportBackend = ExportOrchestrator\n\n__all__ = [\n    \"ExportBackend\",\n    \"ExportOrchestrator\",\n    \"get_export_backend\",\n]\n"
  },
  {
    "path": "studio/backend/core/export/export.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n# backend/export.py\n\"\"\"\nExport backend - handles model exporting in various formats\n\"\"\"\n\nimport glob\nimport json\nimport structlog\nfrom loggers import get_logger\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Optional, Tuple, List\nfrom peft import PeftModel, PeftModelForCausalLM\nfrom unsloth import FastLanguageModel, FastVisionModel\nfrom huggingface_hub import HfApi, ModelCard\nfrom transformers.modeling_utils import PushToHubMixin\nimport torch\nfrom utils.hardware import clear_gpu_cache\n\nfrom utils.models import is_vision_model, get_base_model_from_lora\nfrom utils.models.model_config import detect_audio_type\nfrom utils.paths import ensure_dir, outputs_root, resolve_export_dir, resolve_output_dir\nfrom core.inference import get_inference_backend\n\nlogger = get_logger(__name__)\n\n\ndef _is_wsl():\n    \"\"\"Detect if running under Windows Subsystem for Linux.\"\"\"\n    try:\n        return \"microsoft\" in open(\"/proc/version\").read().lower()\n    except Exception:\n        return False\n\n\ndef _apply_wsl_sudo_patch():\n    \"\"\"On WSL, monkey-patch do_we_need_sudo() to return False.\n\n    WSL doesn't have passwordless sudo, and do_we_need_sudo() runs\n    `sudo apt-get update` which hangs waiting for a stdin password\n    inside a non-interactive subprocess. setup.sh pre-installs the\n    build dependencies on WSL, so sudo is not needed at runtime.\n    \"\"\"\n    if not _is_wsl():\n        return\n\n    try:\n        import unsloth_zoo.llama_cpp as llama_cpp_module\n\n        def _wsl_do_we_need_sudo(system_type = \"debian\"):\n            logger.info(\n                \"WSL detected — skipping sudo check \"\n                \"(build deps pre-installed by setup.sh)\"\n            )\n            return False\n\n        llama_cpp_module.do_we_need_sudo = _wsl_do_we_need_sudo\n        logger.info(\n            \"Applied WSL sudo patch to \" \"unsloth_zoo.llama_cpp.do_we_need_sudo\"\n        )\n    except Exception as e:\n        logger.warning(f\"Could not apply WSL sudo patch: {e}\")\n\n\n# Model card template\nMODEL_CARD = \"\"\"---\nbase_model: {base_model}\ntags:\n- text-generation-inference\n- transformers\n- unsloth\n- {model_type}\n- {extra}\nlicense: apache-2.0\nlanguage:\n- en\n---\n\n# Uploaded finetuned {method} model\n\n- **Developed by:** {username}\n- **License:** apache-2.0\n- **Finetuned from model :** {base_model}\n\nThis {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.\n\n[<img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png\" width=\"200\"/>](https://github.com/unslothai/unsloth)\n\"\"\"\n\n\nclass ExportBackend:\n    \"\"\"Handles model export operations\"\"\"\n\n    def __init__(self):\n        self.inference_backend = get_inference_backend()\n        self.current_checkpoint = None\n        self.current_model = None\n        self.current_tokenizer = None\n        self.is_vision = False\n        self.is_peft = False\n        self._audio_type = None\n\n    def cleanup_memory(self):\n        \"\"\"Offload and delete all models from memory\"\"\"\n        try:\n            logger.info(\"Starting memory cleanup...\")\n\n            # Unload all models from inference backend\n            model_names = 
list(self.inference_backend.models.keys())\n            for model_name in model_names:\n                self.inference_backend.unload_model(model_name)\n\n            # Clear current export state\n            self.current_model = None\n            self.current_tokenizer = None\n            self.current_checkpoint = None\n            self._audio_type = None\n\n            # Clear GPU memory cache (handles gc + backend-specific cleanup)\n            clear_gpu_cache()\n\n            logger.info(\"Memory cleanup completed successfully\")\n            return True\n\n        except Exception as e:\n            logger.error(f\"Error during memory cleanup: {e}\")\n            return False\n\n    def scan_checkpoints(\n        self, outputs_dir: str = str(outputs_root())\n    ) -> List[Tuple[str, List[Tuple[str, str]]]]:\n        \"\"\"\n        Scan outputs folder for training runs and their checkpoints.\n\n        Returns:\n            List of tuples: [(model_name, [(display_name, checkpoint_path), ...]), ...]\n        \"\"\"\n        from utils.models.checkpoints import scan_checkpoints\n\n        return scan_checkpoints(outputs_dir = outputs_dir)\n\n    def load_checkpoint(\n        self,\n        checkpoint_path: str,\n        max_seq_length: int = 2048,\n        load_in_4bit: bool = True,\n        trust_remote_code: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"\n        Load a checkpoint for export.\n\n        Returns:\n            Tuple of (success: bool, message: str)\n        \"\"\"\n        try:\n            logger.info(f\"Loading checkpoint: {checkpoint_path}\")\n\n            # First, cleanup existing models\n            self.cleanup_memory()\n\n            checkpoint_path_obj = Path(checkpoint_path)\n\n            # Determine the model identity for type detection\n            adapter_config = checkpoint_path_obj / \"adapter_config.json\"\n            base_model = None\n            if adapter_config.exists():\n                base_model = get_base_model_from_lora(checkpoint_path)\n                if not base_model:\n                    return False, \"Could not determine base model for adapter\"\n\n            model_id = base_model or checkpoint_path\n\n            # Detect audio type and vision\n            self._audio_type = detect_audio_type(model_id)\n            self.is_vision = not self._audio_type and is_vision_model(model_id)\n\n            # Load model based on type\n            if self._audio_type == \"csm\":\n                from unsloth import FastModel\n                from transformers import CsmForConditionalGeneration\n\n                logger.info(\"Loading as CSM audio model...\")\n                model, tokenizer = FastModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    auto_model = CsmForConditionalGeneration,\n                    load_in_4bit = False,\n                    trust_remote_code = trust_remote_code,\n                )\n\n            elif self._audio_type == \"whisper\":\n                from unsloth import FastModel\n                from transformers import WhisperForConditionalGeneration\n\n                logger.info(\"Loading as Whisper audio model...\")\n                model, tokenizer = FastModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    dtype = None,\n                    load_in_4bit = False,\n                    auto_model = WhisperForConditionalGeneration,\n  
                  trust_remote_code = trust_remote_code,\n                )\n\n            elif self._audio_type == \"snac\":\n                logger.info(\"Loading as SNAC (Orpheus) audio model...\")\n                model, tokenizer = FastLanguageModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    load_in_4bit = load_in_4bit,\n                    trust_remote_code = trust_remote_code,\n                )\n\n            elif self._audio_type == \"bicodec\":\n                from unsloth import FastModel\n\n                logger.info(\"Loading as BiCodec (Spark-TTS) audio model...\")\n                model, tokenizer = FastModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    dtype = torch.float32,\n                    load_in_4bit = False,\n                    trust_remote_code = trust_remote_code,\n                )\n\n            elif self._audio_type == \"dac\":\n                from unsloth import FastModel\n\n                logger.info(\"Loading as DAC (OuteTTS) audio model...\")\n                model, tokenizer = FastModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    load_in_4bit = False,\n                    trust_remote_code = trust_remote_code,\n                )\n\n            elif self.is_vision:\n                logger.info(\"Loading as vision model...\")\n                model, processor = FastVisionModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    load_in_4bit = load_in_4bit,\n                    trust_remote_code = trust_remote_code,\n                )\n                tokenizer = processor  # For vision models, processor acts as tokenizer\n\n            else:\n                logger.info(\"Loading as text model...\")\n                model, tokenizer = FastLanguageModel.from_pretrained(\n                    model_name = checkpoint_path,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    load_in_4bit = load_in_4bit,\n                    trust_remote_code = trust_remote_code,\n                )\n\n            # Check if PEFT model\n            self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM))\n\n            # Store loaded model\n            self.current_model = model\n            self.current_tokenizer = tokenizer\n            self.current_checkpoint = checkpoint_path\n\n            if self._audio_type:\n                model_type = f\"Audio ({self._audio_type})\"\n            elif self.is_vision:\n                model_type = \"Vision\"\n            else:\n                model_type = \"Text\"\n            peft_info = \" (PEFT Adapter)\" if self.is_peft else \" (Merged Model)\"\n\n            logger.info(f\"Successfully loaded {model_type} model{peft_info}\")\n            return True, f\"Loaded {model_type} model{peft_info} successfully\"\n\n        except Exception as e:\n            logger.error(f\"Error loading checkpoint: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, f\"Failed to load checkpoint: {str(e)}\"\n\n    def _write_export_metadata(self, save_directory: str):\n     
   \"\"\"Write export_metadata.json with base model info for Chat page discovery.\"\"\"\n        try:\n            base_model = (\n                get_base_model_from_lora(self.current_checkpoint)\n                if self.current_checkpoint\n                else None\n            )\n            metadata = {\"base_model\": base_model}\n            metadata_path = os.path.join(save_directory, \"export_metadata.json\")\n            with open(metadata_path, \"w\") as f:\n                json.dump(metadata, f, indent = 2)\n            logger.info(f\"Wrote export metadata to {metadata_path}\")\n        except Exception as e:\n            logger.warning(f\"Could not write export metadata: {e}\")\n\n    def export_merged_model(\n        self,\n        save_directory: str,\n        format_type: str = \"16-bit (FP16)\",\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"\n        Export merged model (for PEFT models).\n\n        Args:\n            save_directory: Local directory to save model\n            format_type: \"16-bit (FP16)\" or \"4-bit (FP4)\"\n            push_to_hub: Whether to push to Hugging Face Hub\n            repo_id: Hub repository ID (username/model-name)\n            hf_token: Hugging Face token\n            private: Whether to make the repo private\n\n        Returns:\n            Tuple of (success: bool, message: str)\n        \"\"\"\n        if not self.current_model or not self.current_tokenizer:\n            return False, \"No model loaded. Please select a checkpoint first.\"\n\n        if not self.is_peft:\n            return False, \"This is not a PEFT model. Use 'Export Base Model' instead.\"\n\n        try:\n            # Determine save method\n            if format_type == \"4-bit (FP4)\":\n                save_method = \"merged_4bit_forced\"\n            elif self._audio_type == \"whisper\":\n                # Whisper uses save_method=None for local 16-bit merged save\n                save_method = None\n            else:  # 16-bit (FP16)\n                save_method = \"merged_16bit\"\n\n            # Save locally if requested\n            if save_directory:\n                save_directory = str(resolve_export_dir(save_directory))\n                logger.info(f\"Saving merged model locally to: {save_directory}\")\n                ensure_dir(Path(save_directory))\n\n                self.current_model.save_pretrained_merged(\n                    save_directory, self.current_tokenizer, save_method = save_method\n                )\n\n                # Write export metadata so the Chat page can identify the base model\n                self._write_export_metadata(save_directory)\n                logger.info(f\"Model saved successfully to {save_directory}\")\n\n            # Push to hub if requested\n            if push_to_hub:\n                if not repo_id or not hf_token:\n                    return (\n                        False,\n                        \"Repository ID and Hugging Face token required for Hub upload\",\n                    )\n\n                logger.info(f\"Pushing merged model to Hub: {repo_id}\")\n\n                # Whisper uses save_method=None for local but \"merged_16bit\" for hub push\n                hub_save_method = (\n                    save_method if save_method is not None else \"merged_16bit\"\n                )\n                self.current_model.push_to_hub_merged(\n                    
repo_id,\n                    self.current_tokenizer,\n                    save_method = hub_save_method,\n                    token = hf_token,\n                    private = private,\n                )\n                logger.info(f\"Model pushed successfully to {repo_id}\")\n\n            return True, \"Model exported successfully\"\n\n        except Exception as e:\n            logger.error(f\"Error exporting merged model: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, f\"Export failed: {str(e)}\"\n\n    def export_base_model(\n        self,\n        save_directory: str,\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n        base_model_id: Optional[str] = None,\n    ) -> Tuple[bool, str]:\n        \"\"\"\n        Export base model (for non-PEFT models).\n\n        Returns:\n            Tuple of (success: bool, message: str)\n        \"\"\"\n        if not self.current_model or not self.current_tokenizer:\n            return False, \"No model loaded. Please select a checkpoint first.\"\n\n        if self.is_peft:\n            return (\n                False,\n                \"This is a PEFT model. Use 'Merged Model' export type instead.\",\n            )\n\n        try:\n            # Save locally if requested\n            if save_directory:\n                save_directory = str(resolve_export_dir(save_directory))\n                logger.info(f\"Saving base model locally to: {save_directory}\")\n                ensure_dir(Path(save_directory))\n\n                self.current_model.save_pretrained(save_directory)\n                self.current_tokenizer.save_pretrained(save_directory)\n\n                # Write export metadata so the Chat page can identify the base model\n                self._write_export_metadata(save_directory)\n                logger.info(f\"Model saved successfully to {save_directory}\")\n\n            # Push to hub if requested\n            if push_to_hub:\n                if not repo_id or not hf_token:\n                    return (\n                        False,\n                        \"Repository ID and Hugging Face token required for Hub upload\",\n                    )\n\n                logger.info(f\"Pushing base model to Hub: {repo_id}\")\n\n                # Get base model name from request or model config\n                base_model = (\n                    base_model_id\n                    or self.current_model.config._name_or_path\n                    or \"unknown\"\n                )\n\n                # Create repo\n                hf_api = HfApi(token = hf_token)\n                repo_id = PushToHubMixin._create_repo(\n                    PushToHubMixin,\n                    repo_id = repo_id,\n                    private = private,\n                    token = hf_token,\n                )\n                username = repo_id.split(\"/\")[0]\n\n                # Create and push model card\n                content = MODEL_CARD.format(\n                    username = username,\n                    base_model = base_model,\n                    model_type = self.current_model.config.model_type,\n                    method = \"\",\n                    extra = \"unsloth\",\n                )\n                card = ModelCard(content)\n                card.push_to_hub(\n                    repo_id, token = hf_token, commit_message = \"Unsloth Model Card\"\n                )\n\n         
       # Upload model files\n                if save_directory:\n                    hf_api.upload_folder(\n                        folder_path = save_directory, repo_id = repo_id, repo_type = \"model\"\n                    )\n                    logger.info(f\"Model pushed successfully to {repo_id}\")\n                else:\n                    return False, \"Local save directory required for Hub upload\"\n\n            return True, \"Model exported successfully\"\n\n        except Exception as e:\n            logger.error(f\"Error exporting base model: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, f\"Export failed: {str(e)}\"\n\n    def export_gguf(\n        self,\n        save_directory: str,\n        quantization_method: str = \"Q4_K_M\",\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n    ) -> Tuple[bool, str]:\n        \"\"\"\n        Export model in GGUF format.\n\n        Args:\n            save_directory: Local directory to save model\n            quantization_method: GGUF quantization method (e.g., \"Q4_K_M\")\n            push_to_hub: Whether to push to Hugging Face Hub\n            repo_id: Hub repository ID\n            hf_token: Hugging Face token\n\n        Returns:\n            Tuple of (success: bool, message: str)\n        \"\"\"\n        if not self.current_model or not self.current_tokenizer:\n            return False, \"No model loaded. Please select a checkpoint first.\"\n\n        try:\n            # Convert quantization method to lowercase for unsloth\n            quant_method = quantization_method.lower()\n\n            # Save locally if requested\n            if save_directory:\n                save_directory = str(resolve_export_dir(save_directory))\n                # Resolve to absolute path so unsloth's relative-path internals\n                # (check_llama_cpp, use_local_gguf, _download_convert_hf_to_gguf)\n                # all resolve against the repo root cwd, NOT the export directory.\n                abs_save_dir = os.path.abspath(save_directory)\n                logger.info(f\"Saving GGUF model locally to: {abs_save_dir}\")\n\n                # Create the directory if it doesn't exist\n                ensure_dir(Path(abs_save_dir))\n\n                # On WSL, patch out sudo check before llama.cpp build\n                _apply_wsl_sudo_patch()\n\n                # Snapshot existing .gguf files in cwd before conversion.\n                # unsloth's convert_to_gguf writes output files relative to\n                # cwd (repo root), so we diff afterwards and relocate them.\n                cwd = os.getcwd()\n                pre_existing_ggufs = set(glob.glob(os.path.join(cwd, \"*.gguf\")))\n\n                # Pass absolute path — no os.chdir needed.\n                # unsloth saves intermediate HF model files into model_save_path.\n                # unsloth-zoo's check_llama_cpp() uses ~/.unsloth/llama.cpp by default.\n                model_save_path = os.path.join(abs_save_dir, \"model\")\n                self.current_model.save_pretrained_gguf(\n                    model_save_path,\n                    self.current_tokenizer,\n                    quantization_method = quant_method,\n                )\n\n                # Relocate GGUF artifacts into the export directory.\n                # convert_to_gguf writes .gguf files to cwd (repo root)\n                # because --outfile is a relative path like 
\"model.Q4_K_M.gguf\".\n                new_ggufs = (\n                    set(glob.glob(os.path.join(cwd, \"*.gguf\"))) - pre_existing_ggufs\n                )\n                for src in sorted(new_ggufs):\n                    dest = os.path.join(abs_save_dir, os.path.basename(src))\n                    shutil.move(src, dest)\n                    logger.info(\n                        f\"Relocated GGUF: {os.path.basename(src)} → {abs_save_dir}/\"\n                    )\n\n                # Flatten any .gguf files from subdirectories into abs_save_dir.\n                # save_pretrained_gguf may create subdirs (e.g. model_gguf/)\n                # with a name different from model_save_path.\n                for sub in list(Path(abs_save_dir).iterdir()):\n                    if not sub.is_dir():\n                        continue\n                    for src in sub.glob(\"*.gguf\"):\n                        dest = os.path.join(abs_save_dir, src.name)\n                        shutil.move(str(src), dest)\n                        logger.info(f\"Relocated GGUF: {src.name} → {abs_save_dir}/\")\n                    # Clean up the subdirectory (intermediate HF files, etc.)\n                    shutil.rmtree(str(sub), ignore_errors = True)\n                    logger.info(f\"Cleaned up subdirectory: {sub.name}\")\n\n                # Write export metadata so the Chat page can identify the base model\n                self._write_export_metadata(abs_save_dir)\n\n                # Log final file locations (after relocation) so it's clear\n                # where the GGUF files actually ended up.\n                final_ggufs = sorted(glob.glob(os.path.join(abs_save_dir, \"*.gguf\")))\n                logger.info(\n                    \"GGUF export complete. Final files in %s:\\n  %s\",\n                    abs_save_dir,\n                    \"\\n  \".join(os.path.basename(f) for f in final_ggufs) or \"(none)\",\n                )\n\n            # Push to hub if requested\n            if push_to_hub:\n                if not repo_id or not hf_token:\n                    return (\n                        False,\n                        \"Repository ID and Hugging Face token required for Hub upload\",\n                    )\n\n                logger.info(f\"Pushing GGUF model to Hub: {repo_id}\")\n\n                self.current_model.push_to_hub_gguf(\n                    repo_id,\n                    self.current_tokenizer,\n                    quantization_method = quant_method,\n                    token = hf_token,\n                )\n                logger.info(f\"GGUF model pushed successfully to {repo_id}\")\n\n            return True, f\"GGUF model exported successfully ({quantization_method})\"\n\n        except Exception as e:\n            logger.error(f\"Error exporting GGUF model: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, f\"GGUF export failed: {str(e)}\"\n\n    def export_lora_adapter(\n        self,\n        save_directory: str,\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"\n        Export LoRA adapter only (not merged).\n\n        Returns:\n            Tuple of (success: bool, message: str)\n        \"\"\"\n        if not self.current_model or not self.current_tokenizer:\n            return False, \"No model loaded. 
Please select a checkpoint first.\"\n\n        if not self.is_peft:\n            return False, \"This is not a PEFT model. No adapter to export.\"\n\n        try:\n            # Save locally if requested\n            if save_directory:\n                save_directory = str(resolve_export_dir(save_directory))\n                logger.info(f\"Saving LoRA adapter locally to: {save_directory}\")\n                ensure_dir(Path(save_directory))\n\n                self.current_model.save_pretrained(save_directory)\n                self.current_tokenizer.save_pretrained(save_directory)\n                logger.info(f\"Adapter saved successfully to {save_directory}\")\n\n            # Push to hub if requested\n            if push_to_hub:\n                if not repo_id or not hf_token:\n                    return (\n                        False,\n                        \"Repository ID and Hugging Face token required for Hub upload\",\n                    )\n\n                logger.info(f\"Pushing LoRA adapter to Hub: {repo_id}\")\n\n                self.current_model.push_to_hub(repo_id, token = hf_token, private = private)\n                self.current_tokenizer.push_to_hub(\n                    repo_id, token = hf_token, private = private\n                )\n                logger.info(f\"Adapter pushed successfully to {repo_id}\")\n\n            return True, \"LoRA adapter exported successfully\"\n\n        except Exception as e:\n            logger.error(f\"Error exporting LoRA adapter: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, f\"Adapter export failed: {str(e)}\"\n\n\n# Global export backend instance\n_export_backend = None\n\n\ndef get_export_backend() -> ExportBackend:\n    \"\"\"Get or create the global export backend instance\"\"\"\n    global _export_backend\n    if _export_backend is None:\n        _export_backend = ExportBackend()\n    return _export_backend\n"
  },
  {
    "path": "studio/backend/core/export/orchestrator.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nExport orchestrator — subprocess-based.\n\nProvides the same API as ExportBackend, but delegates all ML work\nto a persistent subprocess. The subprocess is spawned on first checkpoint\nload and stays alive for subsequent export operations.\n\nWhen switching between checkpoints that need different transformers versions,\nthe old subprocess is killed and a new one is spawned with the correct version.\n\nPattern follows core/inference/orchestrator.py.\n\"\"\"\n\nimport atexit\nimport structlog\nfrom loggers import get_logger\nimport multiprocessing as mp\nimport queue\nimport threading\nimport time\nfrom pathlib import Path\nfrom typing import Any, List, Optional, Tuple\nfrom utils.paths import outputs_root\n\nlogger = get_logger(__name__)\n\n_CTX = mp.get_context(\"spawn\")\n\n\nclass ExportOrchestrator:\n    \"\"\"\n    Export backend orchestrator — subprocess-based.\n\n    Exposes the same API surface as ExportBackend so routes/export.py\n    needs minimal changes. Internally, all heavy ML operations happen in\n    a persistent subprocess.\n    \"\"\"\n\n    def __init__(self):\n        # Subprocess state\n        self._proc: Optional[mp.Process] = None\n        self._cmd_queue: Any = None\n        self._resp_queue: Any = None\n        self._lock = threading.Lock()\n\n        # Local state mirrors (updated from subprocess responses)\n        self.current_checkpoint: Optional[str] = None\n        self.is_vision: bool = False\n        self.is_peft: bool = False\n\n        atexit.register(self._cleanup)\n        logger.info(\"ExportOrchestrator initialized (subprocess mode)\")\n\n    # ------------------------------------------------------------------\n    # Subprocess lifecycle\n    # ------------------------------------------------------------------\n\n    def _spawn_subprocess(self, config: dict) -> None:\n        \"\"\"Spawn a new export subprocess.\"\"\"\n        from .worker import run_export_process\n\n        self._cmd_queue = _CTX.Queue()\n        self._resp_queue = _CTX.Queue()\n\n        self._proc = _CTX.Process(\n            target = run_export_process,\n            kwargs = {\n                \"cmd_queue\": self._cmd_queue,\n                \"resp_queue\": self._resp_queue,\n                \"config\": config,\n            },\n            daemon = True,\n        )\n        self._proc.start()\n        logger.info(\"Export subprocess started (pid=%s)\", self._proc.pid)\n\n    def _shutdown_subprocess(self, timeout: float = 10.0) -> None:\n        \"\"\"Gracefully shut down the export subprocess.\"\"\"\n        if self._proc is None or not self._proc.is_alive():\n            self._proc = None\n            return\n\n        # 1. Drain stale responses\n        self._drain_queue()\n\n        # 2. Send shutdown command\n        try:\n            self._cmd_queue.put({\"type\": \"shutdown\"})\n        except (OSError, ValueError):\n            pass\n\n        # 3. Wait for graceful shutdown\n        try:\n            self._proc.join(timeout = timeout)\n        except Exception:\n            pass\n\n        # 4. 
Force kill if still alive\n        if self._proc is not None and self._proc.is_alive():\n            logger.warning(\"Export subprocess did not exit gracefully, terminating\")\n            try:\n                self._proc.terminate()\n                self._proc.join(timeout = 5)\n            except Exception:\n                pass\n            if self._proc is not None and self._proc.is_alive():\n                logger.warning(\"Subprocess still alive after terminate, killing\")\n                try:\n                    self._proc.kill()\n                    self._proc.join(timeout = 3)\n                except Exception:\n                    pass\n\n        self._proc = None\n        self._cmd_queue = None\n        self._resp_queue = None\n        logger.info(\"Export subprocess shut down\")\n\n    def _cleanup(self):\n        \"\"\"atexit handler.\"\"\"\n        self._shutdown_subprocess(timeout = 5.0)\n\n    def _ensure_subprocess_alive(self) -> bool:\n        \"\"\"Check if subprocess is alive.\"\"\"\n        return self._proc is not None and self._proc.is_alive()\n\n    # ------------------------------------------------------------------\n    # Queue helpers\n    # ------------------------------------------------------------------\n\n    def _send_cmd(self, cmd: dict) -> None:\n        \"\"\"Send a command to the subprocess.\"\"\"\n        if self._cmd_queue is None:\n            raise RuntimeError(\"No export subprocess running\")\n        try:\n            self._cmd_queue.put(cmd)\n        except (OSError, ValueError) as exc:\n            raise RuntimeError(f\"Failed to send command to subprocess: {exc}\")\n\n    def _read_resp(self, timeout: float = 1.0) -> Optional[dict]:\n        \"\"\"Read a response from the subprocess (non-blocking with timeout).\"\"\"\n        if self._resp_queue is None:\n            return None\n        try:\n            return self._resp_queue.get(timeout = timeout)\n        except queue.Empty:\n            return None\n        except (EOFError, OSError, ValueError):\n            return None\n\n    def _wait_response(self, expected_type: str, timeout: float = 3600.0) -> dict:\n        \"\"\"Block until a response of the expected type arrives.\n\n        Export operations can take a very long time — GGUF conversion for\n        large models (30B+) easily takes 20-30 minutes. 
Default timeout\n        is 1 hour.\n        \"\"\"\n        deadline = time.monotonic() + timeout\n\n        while time.monotonic() < deadline:\n            remaining = max(0.1, deadline - time.monotonic())\n            resp = self._read_resp(timeout = min(remaining, 2.0))\n\n            if resp is None:\n                # Check subprocess health\n                if not self._ensure_subprocess_alive():\n                    raise RuntimeError(\"Export subprocess crashed during wait\")\n                continue\n\n            rtype = resp.get(\"type\", \"\")\n\n            if rtype == expected_type:\n                return resp\n\n            if rtype == \"error\":\n                error_msg = resp.get(\"error\", \"Unknown error\")\n                raise RuntimeError(f\"Subprocess error: {error_msg}\")\n\n            if rtype == \"status\":\n                logger.info(\"Export subprocess status: %s\", resp.get(\"message\", \"\"))\n                continue\n\n            # Other response types during wait — skip\n            logger.debug(\n                \"Skipping response type '%s' while waiting for '%s'\",\n                rtype,\n                expected_type,\n            )\n\n        raise RuntimeError(\n            f\"Timeout waiting for '{expected_type}' response after {timeout}s\"\n        )\n\n    def _drain_queue(self) -> list:\n        \"\"\"Drain all pending responses.\"\"\"\n        events = []\n        if self._resp_queue is None:\n            return events\n        while True:\n            try:\n                events.append(self._resp_queue.get_nowait())\n            except queue.Empty:\n                return events\n            except (EOFError, OSError, ValueError):\n                return events\n\n    # ------------------------------------------------------------------\n    # Public API — same interface as ExportBackend\n    # ------------------------------------------------------------------\n\n    def load_checkpoint(\n        self,\n        checkpoint_path: str,\n        max_seq_length: int = 2048,\n        load_in_4bit: bool = True,\n        trust_remote_code: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"Load a checkpoint for export.\n\n        Always spawns a fresh subprocess to ensure a clean Python interpreter.\n        \"\"\"\n        sub_config = {\n            \"checkpoint_path\": checkpoint_path,\n            \"max_seq_length\": max_seq_length,\n            \"load_in_4bit\": load_in_4bit,\n            \"trust_remote_code\": trust_remote_code,\n        }\n\n        # Always kill existing subprocess and spawn fresh.\n        if self._ensure_subprocess_alive():\n            self._shutdown_subprocess()\n        elif self._proc is not None:\n            self._shutdown_subprocess(timeout = 2)\n\n        logger.info(\"Spawning fresh export subprocess for '%s'\", checkpoint_path)\n        self._spawn_subprocess(sub_config)\n\n        try:\n            resp = self._wait_response(\"loaded\", timeout = 300)\n        except RuntimeError as exc:\n            self._shutdown_subprocess(timeout = 5)\n            self.current_checkpoint = None\n            self.is_vision = False\n            self.is_peft = False\n            return False, str(exc)\n\n        if resp.get(\"success\"):\n            self.current_checkpoint = resp.get(\"checkpoint\")\n            self.is_vision = resp.get(\"is_vision\", False)\n            self.is_peft = resp.get(\"is_peft\", False)\n            logger.info(\"Checkpoint '%s' loaded in subprocess\", checkpoint_path)\n            
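# The local mirrors updated above (current_checkpoint / is_vision / is_peft) now match the worker's reported state, so route handlers can read them directly without another subprocess round trip.\n            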
return True, resp.get(\"message\", \"Loaded successfully\")\n        else:\n            error = resp.get(\"message\", \"Failed to load checkpoint\")\n            logger.error(\"Failed to load checkpoint: %s\", error)\n            self.current_checkpoint = None\n            self.is_vision = False\n            self.is_peft = False\n            return False, error\n\n    def export_merged_model(\n        self,\n        save_directory: str,\n        format_type: str = \"16-bit (FP16)\",\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"Export merged PEFT model.\"\"\"\n        return self._run_export(\n            \"merged\",\n            {\n                \"save_directory\": save_directory,\n                \"format_type\": format_type,\n                \"push_to_hub\": push_to_hub,\n                \"repo_id\": repo_id,\n                \"hf_token\": hf_token,\n                \"private\": private,\n            },\n        )\n\n    def export_base_model(\n        self,\n        save_directory: str,\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n        base_model_id: Optional[str] = None,\n    ) -> Tuple[bool, str]:\n        \"\"\"Export base model (non-PEFT).\"\"\"\n        return self._run_export(\n            \"base\",\n            {\n                \"save_directory\": save_directory,\n                \"push_to_hub\": push_to_hub,\n                \"repo_id\": repo_id,\n                \"hf_token\": hf_token,\n                \"private\": private,\n                \"base_model_id\": base_model_id,\n            },\n        )\n\n    def export_gguf(\n        self,\n        save_directory: str,\n        quantization_method: str = \"Q4_K_M\",\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n    ) -> Tuple[bool, str]:\n        \"\"\"Export model in GGUF format.\"\"\"\n        return self._run_export(\n            \"gguf\",\n            {\n                \"save_directory\": save_directory,\n                \"quantization_method\": quantization_method,\n                \"push_to_hub\": push_to_hub,\n                \"repo_id\": repo_id,\n                \"hf_token\": hf_token,\n            },\n        )\n\n    def export_lora_adapter(\n        self,\n        save_directory: str,\n        push_to_hub: bool = False,\n        repo_id: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        private: bool = False,\n    ) -> Tuple[bool, str]:\n        \"\"\"Export LoRA adapter only.\"\"\"\n        return self._run_export(\n            \"lora\",\n            {\n                \"save_directory\": save_directory,\n                \"push_to_hub\": push_to_hub,\n                \"repo_id\": repo_id,\n                \"hf_token\": hf_token,\n                \"private\": private,\n            },\n        )\n\n    def _run_export(self, export_type: str, params: dict) -> Tuple[bool, str]:\n        \"\"\"Send an export command to the subprocess and wait for result.\"\"\"\n        if not self._ensure_subprocess_alive():\n            return False, \"No export subprocess running. 
Load a checkpoint first.\"\n\n        cmd = {\"type\": \"export\", \"export_type\": export_type, **params}\n\n        try:\n            self._send_cmd(cmd)\n            resp = self._wait_response(\n                f\"export_{export_type}_done\",\n                timeout = 3600,  # GGUF for 30B+ models can take 30+ min\n            )\n            return resp.get(\"success\", False), resp.get(\"message\", \"\")\n        except RuntimeError as exc:\n            return False, str(exc)\n\n    def cleanup_memory(self) -> bool:\n        \"\"\"Cleanup export-related models from memory.\"\"\"\n        if not self._ensure_subprocess_alive():\n            # No subprocess — just clear local state\n            self.current_checkpoint = None\n            self.is_vision = False\n            self.is_peft = False\n            return True\n\n        try:\n            self._send_cmd({\"type\": \"cleanup\"})\n            resp = self._wait_response(\"cleanup_done\", timeout = 30)\n            success = resp.get(\"success\", False)\n        except RuntimeError:\n            success = False\n\n        # Shut down subprocess after cleanup — no model loaded\n        self._shutdown_subprocess()\n\n        self.current_checkpoint = None\n        self.is_vision = False\n        self.is_peft = False\n        return success\n\n    def scan_checkpoints(\n        self, outputs_dir: str = str(outputs_root())\n    ) -> List[Tuple[str, list]]:\n        \"\"\"Scan for checkpoints — no ML imports needed, runs locally.\"\"\"\n        from utils.models.checkpoints import scan_checkpoints\n\n        return scan_checkpoints(outputs_dir = outputs_dir)\n\n\n# ========== GLOBAL INSTANCE ==========\n_export_backend = None\n\n\ndef get_export_backend() -> ExportOrchestrator:\n    \"\"\"Get global export backend instance (orchestrator).\"\"\"\n    global _export_backend\n    if _export_backend is None:\n        _export_backend = ExportOrchestrator()\n    return _export_backend\n"
  },
  {
    "path": "studio/backend/core/export/worker.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nExport subprocess entry point.\n\nEach export session runs in a persistent subprocess (mp.get_context(\"spawn\")).\nThis gives us a clean Python interpreter with no stale module state —\nsolving the transformers version-switching problem completely.\n\nThe subprocess stays alive while a model is loaded, accepting commands\n(load, export_merged, export_base, export_gguf, export_lora, cleanup,\nshutdown) via mp.Queue.\n\nPattern follows core/inference/worker.py and core/training/worker.py.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport structlog\nfrom loggers import get_logger\nimport os\nimport sys\nimport time\nimport traceback\nfrom pathlib import Path\nfrom typing import Any\n\nlogger = get_logger(__name__)\n\n\ndef _activate_transformers_version(model_name: str) -> None:\n    \"\"\"Activate the correct transformers version BEFORE any ML imports.\n\n    If the model needs transformers 5.x, prepend the pre-installed .venv_t5/\n    directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).\n    \"\"\"\n    # Ensure backend is on path for utils imports\n    backend_path = str(Path(__file__).resolve().parent.parent.parent)\n    if backend_path not in sys.path:\n        sys.path.insert(0, backend_path)\n\n    from utils.transformers_version import (\n        needs_transformers_5,\n        _resolve_base_model,\n        _ensure_venv_t5_exists,\n        _VENV_T5_DIR,\n    )\n\n    resolved = _resolve_base_model(model_name)\n    if needs_transformers_5(resolved):\n        if not _ensure_venv_t5_exists():\n            raise RuntimeError(\n                f\"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}\"\n            )\n        if _VENV_T5_DIR not in sys.path:\n            sys.path.insert(0, _VENV_T5_DIR)\n        logger.info(\"Activated transformers 5.x from %s\", _VENV_T5_DIR)\n        # Propagate to child subprocesses (e.g. 
GGUF converter)\n        _pp = os.environ.get(\"PYTHONPATH\", \"\")\n        os.environ[\"PYTHONPATH\"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else \"\")\n    else:\n        logger.info(\"Using default transformers (4.57.x) for %s\", model_name)\n\n\ndef _send_response(resp_queue: Any, response: dict) -> None:\n    \"\"\"Send a response to the parent process.\"\"\"\n    try:\n        resp_queue.put(response)\n    except (OSError, ValueError) as exc:\n        logger.error(\"Failed to send response: %s\", exc)\n\n\ndef _handle_load(backend, cmd: dict, resp_queue: Any) -> None:\n    \"\"\"Handle a load_checkpoint command.\"\"\"\n    checkpoint_path = cmd[\"checkpoint_path\"]\n    max_seq_length = cmd.get(\"max_seq_length\", 2048)\n    load_in_4bit = cmd.get(\"load_in_4bit\", True)\n    trust_remote_code = cmd.get(\"trust_remote_code\", False)\n\n    try:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"status\",\n                \"message\": f\"Loading checkpoint: {checkpoint_path}\",\n                \"ts\": time.time(),\n            },\n        )\n\n        success, message = backend.load_checkpoint(\n            checkpoint_path = checkpoint_path,\n            max_seq_length = max_seq_length,\n            load_in_4bit = load_in_4bit,\n            trust_remote_code = trust_remote_code,\n        )\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"loaded\",\n                \"success\": success,\n                \"message\": message,\n                \"checkpoint\": checkpoint_path if success else None,\n                \"is_vision\": backend.is_vision if success else False,\n                \"is_peft\": backend.is_peft if success else False,\n                \"ts\": time.time(),\n            },\n        )\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"loaded\",\n                \"success\": False,\n                \"message\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_export(backend, cmd: dict, resp_queue: Any) -> None:\n    \"\"\"Handle any export command (merged, base, gguf, lora).\"\"\"\n    export_type = cmd[\"export_type\"]  # \"merged\", \"base\", \"gguf\", \"lora\"\n    response_type = f\"export_{export_type}_done\"\n\n    try:\n        if export_type == \"merged\":\n            success, message = backend.export_merged_model(\n                save_directory = cmd.get(\"save_directory\", \"\"),\n                format_type = cmd.get(\"format_type\", \"16-bit (FP16)\"),\n                push_to_hub = cmd.get(\"push_to_hub\", False),\n                repo_id = cmd.get(\"repo_id\"),\n                hf_token = cmd.get(\"hf_token\"),\n                private = cmd.get(\"private\", False),\n            )\n        elif export_type == \"base\":\n            success, message = backend.export_base_model(\n                save_directory = cmd.get(\"save_directory\", \"\"),\n                push_to_hub = cmd.get(\"push_to_hub\", False),\n                repo_id = cmd.get(\"repo_id\"),\n                hf_token = cmd.get(\"hf_token\"),\n                private = cmd.get(\"private\", False),\n                base_model_id = cmd.get(\"base_model_id\"),\n            )\n        elif export_type == \"gguf\":\n            success, message = backend.export_gguf(\n                save_directory = cmd.get(\"save_directory\", 
\"\"),\n                quantization_method = cmd.get(\"quantization_method\", \"Q4_K_M\"),\n                push_to_hub = cmd.get(\"push_to_hub\", False),\n                repo_id = cmd.get(\"repo_id\"),\n                hf_token = cmd.get(\"hf_token\"),\n            )\n        elif export_type == \"lora\":\n            success, message = backend.export_lora_adapter(\n                save_directory = cmd.get(\"save_directory\", \"\"),\n                push_to_hub = cmd.get(\"push_to_hub\", False),\n                repo_id = cmd.get(\"repo_id\"),\n                hf_token = cmd.get(\"hf_token\"),\n                private = cmd.get(\"private\", False),\n            )\n        else:\n            success, message = False, f\"Unknown export type: {export_type}\"\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": response_type,\n                \"success\": success,\n                \"message\": message,\n                \"ts\": time.time(),\n            },\n        )\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": response_type,\n                \"success\": False,\n                \"message\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_cleanup(backend, resp_queue: Any) -> None:\n    \"\"\"Handle a cleanup command.\"\"\"\n    try:\n        success = backend.cleanup_memory()\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"cleanup_done\",\n                \"success\": success,\n                \"ts\": time.time(),\n            },\n        )\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"cleanup_done\",\n                \"success\": False,\n                \"message\": str(exc),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef run_export_process(\n    *,\n    cmd_queue: Any,\n    resp_queue: Any,\n    config: dict,\n) -> None:\n    \"\"\"Subprocess entrypoint. Persistent — runs command loop until shutdown.\n\n    Args:\n        cmd_queue: mp.Queue for receiving commands from parent.\n        resp_queue: mp.Queue for sending responses to parent.\n        config: Initial configuration dict with checkpoint_path.\n    \"\"\"\n    import queue as _queue\n\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    os.environ[\"PYTHONWARNINGS\"] = (\n        \"ignore\"  # Suppress warnings at C-level before imports\n    )\n\n    import warnings\n    from loggers.config import LogConfig\n\n    if os.getenv(\"ENVIRONMENT_TYPE\", \"production\") == \"production\":\n        warnings.filterwarnings(\"ignore\")\n\n    LogConfig.setup_logging(\n        service_name = \"unsloth-studio-export-worker\",\n        env = os.getenv(\"ENVIRONMENT_TYPE\", \"production\"),\n    )\n\n    checkpoint_path = config[\"checkpoint_path\"]\n\n    # ── 1. Activate correct transformers version BEFORE any ML imports ──\n    try:\n        _activate_transformers_version(checkpoint_path)\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to activate transformers version: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 1b. 
On Windows, check Triton availability (must be before import torch) ──\n    if sys.platform == \"win32\":\n        try:\n            import triton  # noqa: F401\n\n            logger.info(\"Triton available — torch.compile enabled\")\n        except ImportError:\n            os.environ[\"TORCHDYNAMO_DISABLE\"] = \"1\"\n            logger.warning(\n                \"Triton not found on Windows — torch.compile disabled. \"\n                'Install for better performance: pip install \"triton-windows<3.7\"'\n            )\n\n    # ── 2. Import ML libraries (fresh in this clean process) ──\n    try:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"status\",\n                \"message\": \"Importing Unsloth...\",\n                \"ts\": time.time(),\n            },\n        )\n\n        backend_path = str(Path(__file__).resolve().parent.parent.parent)\n        if backend_path not in sys.path:\n            sys.path.insert(0, backend_path)\n\n        from core.export.export import ExportBackend\n\n        import transformers\n\n        logger.info(\n            \"Export subprocess loaded transformers %s\", transformers.__version__\n        )\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to import ML libraries: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 3. Create export backend and load initial checkpoint ──\n    try:\n        backend = ExportBackend()\n\n        _handle_load(backend, config, resp_queue)\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to initialize export backend: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 4. 
Command loop — process commands until shutdown ──\n    logger.info(\"Export subprocess ready, entering command loop\")\n\n    while True:\n        try:\n            cmd = cmd_queue.get(timeout = 1.0)\n        except _queue.Empty:\n            continue\n        except (EOFError, OSError):\n            logger.info(\"Command queue closed, shutting down\")\n            return\n\n        if cmd is None:\n            continue\n\n        cmd_type = cmd.get(\"type\", \"\")\n        logger.info(\"Received command: %s\", cmd_type)\n\n        try:\n            if cmd_type == \"load\":\n                # Load a new checkpoint (reusing this subprocess)\n                backend.cleanup_memory()\n                _handle_load(backend, cmd, resp_queue)\n\n            elif cmd_type == \"export\":\n                _handle_export(backend, cmd, resp_queue)\n\n            elif cmd_type == \"cleanup\":\n                _handle_cleanup(backend, resp_queue)\n\n            elif cmd_type == \"status\":\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"status_response\",\n                        \"checkpoint\": backend.current_checkpoint,\n                        \"is_vision\": backend.is_vision,\n                        \"is_peft\": backend.is_peft,\n                        \"ts\": time.time(),\n                    },\n                )\n\n            elif cmd_type == \"shutdown\":\n                logger.info(\"Shutdown command received, cleaning up and exiting\")\n                try:\n                    backend.cleanup_memory()\n                except Exception:\n                    pass\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"shutdown_ack\",\n                        \"ts\": time.time(),\n                    },\n                )\n                return\n\n            else:\n                logger.warning(\"Unknown command type: %s\", cmd_type)\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"error\",\n                        \"error\": f\"Unknown command type: {cmd_type}\",\n                        \"ts\": time.time(),\n                    },\n                )\n\n        except Exception as exc:\n            logger.error(\n                \"Error handling command '%s': %s\", cmd_type, exc, exc_info = True\n            )\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"error\",\n                    \"error\": f\"Command '{cmd_type}' failed: {exc}\",\n                    \"stack\": traceback.format_exc(limit = 20),\n                    \"ts\": time.time(),\n                },\n            )\n"
  },
  {
    "path": "studio/backend/core/inference/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference submodule - Inference backend for model loading and generation\n\nThe default get_inference_backend() returns an InferenceOrchestrator that\ndelegates to a subprocess. The original InferenceBackend runs inside\nthe subprocess and can be imported directly from .inference when needed.\n\"\"\"\n\nfrom .orchestrator import InferenceOrchestrator, get_inference_backend\nfrom .llama_cpp import LlamaCppBackend\n\n# Expose InferenceOrchestrator as InferenceBackend for backward compat\nInferenceBackend = InferenceOrchestrator\n\n__all__ = [\n    \"InferenceBackend\",\n    \"InferenceOrchestrator\",\n    \"get_inference_backend\",\n    \"LlamaCppBackend\",\n]\n"
  },
  {
    "path": "studio/backend/core/inference/audio_codecs.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nAudio codec loading and decoding for TTS inference.\nSupports: SNAC (Orpheus), CSM (Sesame), BiCodec (Spark), DAC (OuteTTS)\n\"\"\"\n\nimport io\nimport re\nimport wave\nimport structlog\nfrom loggers import get_logger\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\n\nlogger = get_logger(__name__)\n\n\ndef _numpy_to_wav_bytes(waveform: np.ndarray, sample_rate: int) -> bytes:\n    \"\"\"Convert a float32 numpy waveform to WAV bytes (16-bit PCM).\"\"\"\n    waveform = waveform.flatten()\n    peak = max(abs(waveform.max()), abs(waveform.min()))\n    if peak > 1.0:\n        waveform = waveform / peak\n    pcm = (waveform * 32767).astype(np.int16)\n\n    buf = io.BytesIO()\n    with wave.open(buf, \"wb\") as wf:\n        wf.setnchannels(1)\n        wf.setsampwidth(2)\n        wf.setframerate(sample_rate)\n        wf.writeframes(pcm.tobytes())\n\n    return buf.getvalue()\n\n\nclass AudioCodecManager:\n    \"\"\"Manages loading and caching of audio codec models for TTS decoding.\"\"\"\n\n    def __init__(self):\n        self._snac_model = None\n        self._bicodec_tokenizer = None\n        self._bicodec_repo_path = None\n        self._dac_audio_codec = None\n\n    def load_codec(\n        self,\n        audio_type: str,\n        device: str = \"cuda\",\n        model_repo_path: Optional[str] = None,\n    ) -> None:\n        \"\"\"Load the appropriate codec for the given audio type.\"\"\"\n        if audio_type == \"snac\":\n            self._load_snac(device)\n        elif audio_type == \"bicodec\":\n            self._load_bicodec(device, model_repo_path)\n        elif audio_type == \"dac\":\n            self._load_dac(device)\n        elif audio_type == \"csm\":\n            pass  # CSM decoding is built into the model (output_audio=True)\n        else:\n            raise ValueError(f\"Unknown audio_type: {audio_type}\")\n\n    # ── Lazy loaders ─────────────────────────────────────────────\n\n    def _load_snac(self, device: str) -> None:\n        if self._snac_model is not None:\n            return\n        from snac import SNAC\n\n        self._snac_model = (\n            SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\").to(device).eval()\n        )\n        logger.info(\"Loaded SNAC codec (24kHz)\")\n\n    def _load_bicodec(self, device: str, model_repo_path: Optional[str] = None) -> None:\n        if self._bicodec_tokenizer is not None:\n            return\n        import os\n        import sys\n        import subprocess\n\n        # Clone SparkAudio/Spark-TTS GitHub repo for the sparktts Python package\n        # (same approach as training — the HF model repos don't contain the package)\n        spark_code_dir = os.path.join(\n            os.path.dirname(model_repo_path or \".\"), \"Spark-TTS\"\n        )\n        sparktts_pkg = os.path.join(spark_code_dir, \"sparktts\")\n        if not os.path.isdir(sparktts_pkg):\n            logger.info(f\"Cloning SparkAudio/Spark-TTS to {spark_code_dir}...\")\n            subprocess.run(\n                [\n                    \"git\",\n                    \"clone\",\n                    \"--depth\",\n                    \"1\",\n                    \"https://github.com/SparkAudio/Spark-TTS\",\n                    spark_code_dir,\n                ],\n                check = True,\n            )\n\n        if spark_code_dir not in sys.path:\n            
sys.path.insert(0, spark_code_dir)\n\n        from sparktts.models.audio_tokenizer import BiCodecTokenizer\n\n        # BiCodecTokenizer needs the MODEL repo path (contains BiCodec/ weights)\n        tokenizer_path = model_repo_path or spark_code_dir\n        self._bicodec_repo_path = tokenizer_path\n        self._bicodec_tokenizer = BiCodecTokenizer(tokenizer_path, device)\n        logger.info(f\"Loaded BiCodec tokenizer from {tokenizer_path}\")\n\n    def _load_dac(self, device: str) -> None:\n        if self._dac_audio_codec is not None:\n            return\n        import os\n        import sys\n        import subprocess\n\n        # Clone OuteTTS repo (same pattern as Spark-TTS / BiCodec)\n        # The pip package has problematic dependencies; the notebook clones and\n        # removes gguf_model.py, interface.py, __init__.py before importing.\n        base_dir = os.path.dirname(os.path.abspath(__file__))\n        outetts_code_dir = os.path.join(base_dir, \"OuteTTS\")\n        outetts_pkg = os.path.join(outetts_code_dir, \"outetts\")\n        if not os.path.isdir(outetts_pkg):\n            logger.info(f\"Cloning edwko/OuteTTS to {outetts_code_dir}...\")\n            subprocess.run(\n                [\n                    \"git\",\n                    \"clone\",\n                    \"--depth\",\n                    \"1\",\n                    \"https://github.com/edwko/OuteTTS\",\n                    outetts_code_dir,\n                ],\n                check = True,\n            )\n            # Remove files that pull in heavy / incompatible dependencies\n            # (matches notebook: gguf_model.py is under models/, others under outetts/)\n            remove_paths = [\n                os.path.join(outetts_pkg, \"models\", \"gguf_model.py\"),\n                os.path.join(outetts_pkg, \"interface.py\"),\n                os.path.join(outetts_pkg, \"__init__.py\"),\n            ]\n            for fpath in remove_paths:\n                if os.path.exists(fpath):\n                    os.remove(fpath)\n                    logger.info(f\"Removed {fpath}\")\n\n        if outetts_code_dir not in sys.path:\n            sys.path.insert(0, outetts_code_dir)\n\n        from outetts.version.v3.audio_processor import AudioProcessor\n        from outetts.models.config import ModelConfig as OuteTTSModelConfig\n\n        dummy_config = OuteTTSModelConfig(\n            tokenizer_path = \"OuteAI/Llama-OuteTTS-1.0-1B\",\n            device = device,\n            audio_codec_path = None,\n        )\n        processor = AudioProcessor(config = dummy_config)\n        self._dac_audio_codec = processor.audio_codec\n        logger.info(\"Loaded DAC audio codec\")\n\n    # ── Decoders ─────────────────────────────────────────────────\n\n    def decode_snac(\n        self, generated_ids: torch.Tensor, device: str\n    ) -> Tuple[bytes, int]:\n        \"\"\"\n        Decode SNAC tokens (Orpheus) into WAV bytes.\n\n        generated_ids: full model output including prompt tokens.\n        Looks for START_OF_SPEECH (128257) marker, extracts codes after it,\n        strips EOS (128258), redistributes 7-per-frame codes into 3 SNAC layers.\n\n        Returns (wav_bytes, 24000).\n        \"\"\"\n        # Find START_OF_SPEECH token (128257)\n        token_indices = (generated_ids == 128257).nonzero(as_tuple = True)\n        if len(token_indices[1]) > 0:\n            cropped = generated_ids[:, token_indices[1][-1] + 1 :]\n        else:\n            # Gracefully fall back to using entire output if marker not found\n 
           logger.warning(\n                \"No START_OF_SPEECH token (128257) found — using full generated output\"\n            )\n            cropped = generated_ids\n        row = cropped[0]\n\n        # Remove EOS tokens (128258)\n        row = row[row != 128258]\n\n        # Trim to multiple of 7\n        row = row[: (len(row) // 7) * 7]\n        if len(row) == 0:\n            raise ValueError(\"No valid audio codes found after START_OF_SPEECH token\")\n\n        codes = [t.item() - 128266 for t in row]\n\n        # Redistribute into 3 SNAC layers (7 codes per frame → 1+2+4)\n        layer_1, layer_2, layer_3 = [], [], []\n        for i in range(len(codes) // 7):\n            layer_1.append(codes[7 * i])\n            layer_2.append(codes[7 * i + 1] - 4096)\n            layer_3.append(codes[7 * i + 2] - 8192)\n            layer_3.append(codes[7 * i + 3] - 12288)\n            layer_2.append(codes[7 * i + 4] - 16384)\n            layer_3.append(codes[7 * i + 5] - 20480)\n            layer_3.append(codes[7 * i + 6] - 24576)\n\n        snac_codes = [\n            torch.tensor(layer).unsqueeze(0).to(device)\n            for layer in [layer_1, layer_2, layer_3]\n        ]\n\n        with torch.no_grad():\n            audio = self._snac_model.decode(snac_codes)\n\n        waveform = audio.squeeze().cpu().numpy()\n        return _numpy_to_wav_bytes(waveform, 24000), 24000\n\n    def decode_csm(self, audio_values: torch.Tensor) -> Tuple[bytes, int]:\n        \"\"\"\n        Decode CSM output (already a waveform from model.generate(output_audio=True)).\n        Returns (wav_bytes, 24000).\n        \"\"\"\n        waveform = audio_values[0].to(torch.float32).cpu().numpy()\n        return _numpy_to_wav_bytes(waveform, 24000), 24000\n\n    def decode_bicodec(self, generated_text: str, device: str) -> Tuple[bytes, int]:\n        \"\"\"\n        Decode BiCodec tokens (Spark-TTS) from generated text.\n        Extracts bicodec_semantic_N and bicodec_global_N tokens via regex.\n        Returns (wav_bytes, sample_rate).\n        \"\"\"\n        semantic_matches = re.findall(r\"<\\|bicodec_semantic_(\\d+)\\|>\", generated_text)\n        global_matches = re.findall(r\"<\\|bicodec_global_(\\d+)\\|>\", generated_text)\n\n        logger.info(\n            f\"BiCodec decode: {len(global_matches)} global tokens, {len(semantic_matches)} semantic tokens\"\n        )\n        if len(global_matches) < 10:\n            logger.info(\n                f\"BiCodec generated text (first 500 chars): {generated_text[:500]}\"\n            )\n\n        if not semantic_matches:\n            raise ValueError(\"No bicodec_semantic tokens found in generated output\")\n\n        semantic_ids = (\n            torch.tensor([int(t) for t in semantic_matches]).long().unsqueeze(0)\n        )\n\n        # Speaker encoder expects exactly 32 global tokens (token_num=32 in BiCodec config).\n        # Pad with zeros or truncate to 32.\n        GLOBAL_TOKEN_NUM = 32\n        if global_matches:\n            raw = [int(t) for t in global_matches]\n        else:\n            raw = []\n        if len(raw) < GLOBAL_TOKEN_NUM:\n            raw = raw + [0] * (GLOBAL_TOKEN_NUM - len(raw))\n        raw = raw[:GLOBAL_TOKEN_NUM]\n        global_ids = torch.tensor(raw).long().unsqueeze(0)  # (1, 32)\n\n        self._bicodec_tokenizer.device = device\n        self._bicodec_tokenizer.model.to(device)\n\n        wav_np = self._bicodec_tokenizer.detokenize(\n            global_ids.to(device),\n            semantic_ids.to(device),\n        )\n        sr = 
self._bicodec_tokenizer.config.get(\"sample_rate\", 16000)\n        return _numpy_to_wav_bytes(wav_np, sr), sr\n\n    def decode_dac(self, generated_text: str, device: str) -> Tuple[bytes, int]:\n        \"\"\"\n        Decode DAC tokens (OuteTTS) from generated text.\n        Extracts c1_N and c2_N codec code tokens via regex.\n        Returns (wav_bytes, 24000).\n        \"\"\"\n        c1 = list(map(int, re.findall(r\"<\\|c1_(\\d+)\\|>\", generated_text)))\n        c2 = list(map(int, re.findall(r\"<\\|c2_(\\d+)\\|>\", generated_text)))\n\n        if not c1 or not c2:\n            raise ValueError(\"No DAC code tokens (c1/c2) found in generated output\")\n\n        t = min(len(c1), len(c2))\n        c1 = c1[:t]\n        c2 = c2[:t]\n\n        codes = torch.tensor([[c1, c2]], dtype = torch.int64).to(device)\n        with torch.no_grad():\n            audio = self._dac_audio_codec.decode(codes)\n\n        waveform = audio.squeeze().cpu().numpy()\n        return _numpy_to_wav_bytes(waveform, 24000), 24000\n\n    def decode(\n        self,\n        audio_type: str,\n        device: str,\n        token_ids: Optional[list] = None,\n        text: Optional[str] = None,\n    ) -> Tuple[bytes, int]:\n        \"\"\"Unified decode — dispatches to the right codec decoder.\"\"\"\n        if audio_type == \"snac\":\n            if not token_ids:\n                raise ValueError(\"SNAC decoding requires token_ids\")\n            return self.decode_snac(torch.tensor([token_ids], dtype = torch.long), device)\n        elif audio_type == \"bicodec\":\n            if not text:\n                raise ValueError(\"BiCodec decoding requires text\")\n            return self.decode_bicodec(text, device)\n        elif audio_type == \"dac\":\n            if not text:\n                raise ValueError(\"DAC decoding requires text\")\n            return self.decode_dac(text, device)\n        raise ValueError(f\"Cannot decode audio_type: {audio_type}\")\n\n    # ── Cleanup ──────────────────────────────────────────────────\n\n    def unload(self) -> None:\n        \"\"\"Release all codec models from memory.\"\"\"\n        if self._snac_model is not None:\n            del self._snac_model\n            self._snac_model = None\n        if self._bicodec_tokenizer is not None:\n            del self._bicodec_tokenizer\n            self._bicodec_tokenizer = None\n            self._bicodec_repo_path = None\n        if self._dac_audio_codec is not None:\n            del self._dac_audio_codec\n            self._dac_audio_codec = None\n        logger.info(\"Unloaded all audio codecs\")\n"
  },
  {
    "path": "studio/backend/core/inference/defaults.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Default model lists for inference, split by platform.\"\"\"\n\nimport utils.hardware.hardware as hw\n\nDEFAULT_MODELS_GGUF = [\n    \"unsloth/Llama-3.2-1B-Instruct-GGUF\",\n    \"unsloth/Llama-3.2-3B-Instruct-GGUF\",\n    \"unsloth/Llama-3.1-8B-Instruct-GGUF\",\n    \"unsloth/gemma-3-1b-it-GGUF\",\n    \"unsloth/gemma-3-4b-it-GGUF\",\n    \"unsloth/Qwen3-4B-GGUF\",\n]\n\nDEFAULT_MODELS_STANDARD = [\n    \"unsloth/Qwen3-4B-Instruct-2507\",\n    \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n    \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\",\n    \"unsloth/Phi-3.5-mini-instruct\",\n    \"unsloth/Gemma-3-4B-it\",\n    \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\",\n]\n\n\ndef get_default_models() -> list[str]:\n    hw.get_device()  # ensure detect_hardware() has run\n    if hw.CHAT_ONLY:\n        return list(DEFAULT_MODELS_GGUF)\n    return list(DEFAULT_MODELS_STANDARD)\n"
  },
  {
    "path": "studio/backend/core/inference/inference.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nCore inference backend - streamlined\n\"\"\"\n\nfrom unsloth import FastLanguageModel, FastVisionModel\nfrom unsloth.chat_templates import get_chat_template\nfrom transformers import TextStreamer\nfrom peft import PeftModel, PeftModelForCausalLM\n\nimport json\nimport sys\nimport torch\nfrom pathlib import Path\nfrom typing import Optional, Union, Generator, Tuple\nfrom utils.models import ModelConfig, get_base_model_from_lora\nfrom utils.paths import is_model_cached\nfrom utils.utils import format_error_message\nfrom utils.hardware import get_device, clear_gpu_cache, log_gpu_memory\nfrom core.inference.audio_codecs import AudioCodecManager\nfrom io import StringIO\nimport structlog\nfrom loggers import get_logger\n\n\nlogger = get_logger(__name__)\n\n\nclass HarmonyTextStreamer:\n    \"\"\"Streaming text decoder for gpt-oss harmony channel protocol.\n\n    gpt-oss models emit multi-channel output using special tokens like\n    ``<|channel|>analysis<|message|>...`` and ``<|channel|>final<|message|>...``.\n    A plain ``TextIteratorStreamer(skip_special_tokens=True)`` strips the special\n    tokens but leaves the channel names concatenated with content, producing\n    garbled output such as ``analysisWe need to respond...assistantfinalHello!``.\n\n    This streamer decodes with ``skip_special_tokens=False`` so the full\n    harmony markup is visible, then uses **stateful incremental** parsing\n    to emit properly-formatted text:\n\n    - ``<think>`` emitted once when the ``analysis`` channel is first seen\n    - Analysis content streamed incrementally\n    - ``</think>`` emitted once when the ``final`` channel is first seen\n    - Final content streamed incrementally\n\n    This avoids the delta-on-transformed bug where wrapping tags shift\n    position as content grows.\n\n    Implements the same ``put`` / ``end`` / iterator interface as\n    ``TextIteratorStreamer`` so ``generate_stream`` can use it as a drop-in\n    replacement.\n    \"\"\"\n\n    import re as _re\n\n    _HARMONY_RE = _re.compile(\n        r\"<\\|channel\\|>(\\w+)<\\|message\\|>(.*?)(?=<\\|end\\|>|<\\|channel\\|>|\\Z)\",\n        _re.DOTALL,\n    )\n\n    def __init__(self, tokenizer, *, skip_prompt: bool = True, timeout: float = 0.2):\n        import queue\n\n        self.tokenizer = tokenizer\n        self.skip_prompt = skip_prompt\n        self.timeout = timeout\n\n        self._queue: queue.Queue = queue.Queue()\n        self._token_ids: list = []\n        self._prompt_len: int = 0\n        self._is_first_put: bool = True\n        self._stop: bool = False\n\n        # Stateful channel tracking — avoids delta-on-transformed bugs\n        self._emitted_think_open: bool = False\n        self._emitted_think_close: bool = False\n        self._analysis_emitted: int = 0  # chars of analysis content emitted\n        self._final_emitted: int = 0  # chars of final content emitted\n\n    # ------------------------------------------------------------------\n    # put / end — called from the generation thread\n    # ------------------------------------------------------------------\n\n    def put(self, value):\n        \"\"\"Receive new token IDs from model.generate().\"\"\"\n        import torch\n\n        if isinstance(value, torch.Tensor):\n            # value shape: (batch, seq) — take first batch element\n            ids = value[0].tolist() if value.dim() > 
1 else value.tolist()\n        elif isinstance(value, (list, tuple)):\n            ids = list(value)\n        else:\n            ids = [value]\n\n        if self._is_first_put and self.skip_prompt:\n            # First call contains the full prompt; remember its length\n            self._prompt_len = len(ids)\n            self._token_ids = list(ids)\n            self._is_first_put = False\n            return\n\n        self._token_ids.extend(ids)\n\n        # Decode only the generated part (after the prompt)\n        gen_ids = self._token_ids[self._prompt_len :]\n        raw = self.tokenizer.decode(gen_ids, skip_special_tokens = False)\n        self._process_incremental(raw)\n\n    def end(self):\n        \"\"\"Signal generation is complete.\"\"\"\n        # Final decode to capture any remaining content\n        gen_ids = self._token_ids[self._prompt_len :]\n        if gen_ids:\n            raw = self.tokenizer.decode(gen_ids, skip_special_tokens = False)\n            self._process_incremental(raw)\n\n        # Close any open think tags\n        if self._emitted_think_open and not self._emitted_think_close:\n            self._queue.put(\"</think>\")\n            self._emitted_think_close = True\n\n        self._stop = True\n        self._queue.put(None)  # sentinel\n\n    # ------------------------------------------------------------------\n    # Iterator interface — consumed by the streaming loop\n    # ------------------------------------------------------------------\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        from queue import Empty\n\n        while True:\n            try:\n                val = self._queue.get(timeout = self.timeout)\n            except Empty:\n                if self._stop:\n                    raise StopIteration\n                raise  # propagate Empty so caller can check thread liveness\n            if val is None:\n                raise StopIteration\n            return val\n\n    # ------------------------------------------------------------------\n    # Stateful incremental harmony protocol parsing\n    # ------------------------------------------------------------------\n\n    def _process_incremental(self, raw: str) -> None:\n        \"\"\"Parse harmony channels and emit deltas per-channel.\n\n        Instead of transforming the entire raw text and computing a string\n        delta (which breaks when wrapping ``<think>`` tags shift position),\n        this tracks per-channel content lengths and emits:\n\n        - ``<think>`` once when analysis channel first appears\n        - analysis content deltas (computed on channel content directly)\n        - ``</think>`` once when final channel first appears\n        - final content deltas\n        \"\"\"\n        # If raw contains <|channel|> but no complete channel+message pair yet,\n        # buffer silently — don't emit partial channel names as text.\n        has_channel_token = \"<|channel|>\" in raw\n        matches = list(self._HARMONY_RE.finditer(raw))\n\n        if has_channel_token and not matches:\n            # Partial harmony markup still building — wait for more tokens\n            return\n\n        if not has_channel_token and not matches:\n            # No harmony protocol at all — should not happen for gpt-oss\n            # but handle gracefully by not emitting anything\n            return\n\n        for m in matches:\n            channel = m.group(1).lower()\n            content = m.group(2)\n\n            if channel == \"analysis\":\n                if not 
self._emitted_think_open:\n                    self._queue.put(\"<think>\")\n                    self._emitted_think_open = True\n\n                new_content = content[self._analysis_emitted :]\n                if new_content:\n                    self._analysis_emitted = len(content)\n                    self._queue.put(new_content)\n\n            elif channel in (\"final\", \"assistant\"):\n                if self._emitted_think_open and not self._emitted_think_close:\n                    self._queue.put(\"</think>\")\n                    self._emitted_think_close = True\n\n                new_content = content[self._final_emitted :]\n                if new_content:\n                    self._final_emitted = len(content)\n                    self._queue.put(new_content)\n\n\nclass InferenceBackend:\n    \"\"\"Unified inference backend supporting text, vision, and LoRA models\"\"\"\n\n    def __init__(self):\n        self.models = {}\n        self.active_model_name = None\n        self.loading_models = set()\n        self.loaded_local_models = []  # [(display_name, path), ...]\n        from core.inference.defaults import get_default_models\n\n        self.default_models = get_default_models()\n        self.device = get_device().value\n        self._audio_codec_manager = AudioCodecManager()\n\n        # Thread safety — _generation_lock serializes model.generate() calls.\n        # Must be a regular Lock (NOT RLock) because in async FastAPI, multiple\n        # requests share the same event-loop thread, so RLock reentrancy lets\n        # concurrent compare-mode requests race on the GPU.  The lock is\n        # acquired by the *background generation thread*, not the event-loop.\n        import threading\n\n        self._generation_lock = threading.Lock()\n        self._model_state_lock = threading.Lock()\n\n        logger.info(f\"InferenceBackend initialized on {self.device}\")\n\n    @staticmethod\n    def _normalize_top_k(top_k: int) -> int:\n        # API supports -1 as \"disable top-k\"; transformers expects 0 to disable.\n        return 0 if top_k < 0 else top_k\n\n    def load_model(\n        self,\n        config: ModelConfig,\n        max_seq_length: int = 2048,\n        dtype = None,\n        load_in_4bit: bool = True,\n        hf_token: Optional[str] = None,\n        trust_remote_code: bool = False,\n    ) -> bool:\n        \"\"\"\n        Load any model: base, LoRA adapter, text, or vision.\n        \"\"\"\n        try:\n            model_name = config.identifier\n\n            # Check if already loaded\n            if model_name in self.models and self.models[model_name].get(\"model\"):\n                logger.info(f\"Model {model_name} already loaded\")\n                self.active_model_name = model_name\n                return True\n\n            # Check if currently loading\n            if model_name in self.loading_models:\n                logger.info(f\"Model {model_name} is already being loaded\")\n                return False\n\n            self.loading_models.add(model_name)\n\n            self.models[model_name] = {\n                \"is_vision\": config.is_vision,\n                \"is_lora\": config.is_lora,\n                \"is_audio\": config.is_audio,\n                \"audio_type\": config.audio_type,\n                \"has_audio_input\": config.has_audio_input,\n                \"model_path\": config.path,\n                \"base_model\": config.base_model if config.is_lora else None,\n                \"loaded_adapters\": {},\n                
\"active_adapter\": None,\n            }\n\n            # ── Audio model loading path ──────────────────────────\n            if config.is_audio:\n                audio_type = config.audio_type\n                adapter_info = \" (LoRA adapter)\" if config.is_lora else \"\"\n                logger.info(\n                    f\"Loading audio ({audio_type}) model{adapter_info}: {model_name}\"\n                )\n                log_gpu_memory(f\"Before loading {model_name}\")\n\n                if audio_type == \"csm\":\n                    from unsloth import FastModel\n                    from transformers import CsmForConditionalGeneration\n\n                    model, processor = FastModel.from_pretrained(\n                        config.path,\n                        auto_model = CsmForConditionalGeneration,\n                        load_in_4bit = False,\n                        token = hf_token if hf_token and hf_token.strip() else None,\n                        trust_remote_code = trust_remote_code,\n                    )\n                    FastModel.for_inference(model)\n                    self.models[model_name][\"model\"] = model\n                    self.models[model_name][\"tokenizer\"] = processor\n                    self.models[model_name][\"processor\"] = processor\n                elif audio_type == \"bicodec\":\n                    import os\n                    from unsloth import FastModel\n\n                    if config.is_lora and config.base_model:\n                        # LoRA adapter: load from local adapter path.\n                        # base_model is e.g. /home/.../Spark-TTS-0.5B/LLM\n                        # The BiCodec weights are in the parent dir (Spark-TTS-0.5B/).\n                        base_path = config.base_model\n                        if os.path.isdir(base_path):\n                            abs_repo_path = os.path.abspath(os.path.dirname(base_path))\n                        else:\n                            # base_model is an HF ID — download it\n                            from huggingface_hub import snapshot_download\n\n                            local_dir = base_path.split(\"/\")[-1]\n                            repo_path = snapshot_download(\n                                base_path, local_dir = local_dir\n                            )\n                            abs_repo_path = os.path.abspath(repo_path)\n\n                        logger.info(\n                            f\"Spark-TTS LoRA: loading adapter from {config.path}, BiCodec from {abs_repo_path}\"\n                        )\n                        model, tokenizer = FastModel.from_pretrained(\n                            config.path,\n                            dtype = torch.float32,\n                            load_in_4bit = False,\n                            token = hf_token if hf_token and hf_token.strip() else None,\n                            trust_remote_code = trust_remote_code,\n                        )\n                    else:\n                        # Base model: download full HF repo, then load from /LLM subfolder\n                        from huggingface_hub import snapshot_download\n\n                        hf_repo = config.path\n                        local_dir = hf_repo.split(\"/\")[-1]\n                        repo_path = snapshot_download(hf_repo, local_dir = local_dir)\n                        abs_repo_path = os.path.abspath(repo_path)\n                        llm_path = os.path.join(abs_repo_path, \"LLM\")\n                        logger.info(\n  
                          f\"Spark-TTS: downloaded repo to {repo_path}, loading LLM from {llm_path}\"\n                        )\n\n                        model, tokenizer = FastModel.from_pretrained(\n                            llm_path,\n                            dtype = torch.float32,\n                            load_in_4bit = False,\n                            token = hf_token if hf_token and hf_token.strip() else None,\n                            trust_remote_code = trust_remote_code,\n                        )\n\n                    FastModel.for_inference(model)\n                    self.models[model_name][\"model\"] = model\n                    self.models[model_name][\"tokenizer\"] = tokenizer\n                    self.models[model_name][\"model_repo_path\"] = abs_repo_path\n                elif audio_type == \"dac\":\n                    # OuteTTS uses FastModel (not FastLanguageModel)\n                    from unsloth import FastModel\n\n                    model, tokenizer = FastModel.from_pretrained(\n                        config.path,\n                        max_seq_length = max_seq_length,\n                        load_in_4bit = False,\n                        token = hf_token if hf_token and hf_token.strip() else None,\n                        trust_remote_code = trust_remote_code,\n                    )\n                    FastModel.for_inference(model)\n                    self.models[model_name][\"model\"] = model\n                    self.models[model_name][\"tokenizer\"] = tokenizer\n                elif audio_type == \"whisper\":\n                    # Whisper ASR — uses FastModel with WhisperForConditionalGeneration\n                    from unsloth import FastModel\n                    from transformers import WhisperForConditionalGeneration\n\n                    model, tokenizer = FastModel.from_pretrained(\n                        config.path,\n                        auto_model = WhisperForConditionalGeneration,\n                        whisper_language = \"English\",\n                        whisper_task = \"transcribe\",\n                        load_in_4bit = False,\n                        token = hf_token if hf_token and hf_token.strip() else None,\n                        trust_remote_code = trust_remote_code,\n                    )\n                    FastModel.for_inference(model)\n                    model.eval()\n\n                    # Create ASR pipeline (per notebook)\n                    from transformers import pipeline as hf_pipeline\n\n                    whisper_pipe = hf_pipeline(\n                        \"automatic-speech-recognition\",\n                        model = model,\n                        tokenizer = tokenizer.tokenizer,\n                        feature_extractor = tokenizer.feature_extractor,\n                        processor = tokenizer,\n                        return_language = True,\n                        torch_dtype = torch.float16,\n                    )\n                    self.models[model_name][\"model\"] = model\n                    self.models[model_name][\"tokenizer\"] = tokenizer\n                    self.models[model_name][\"whisper_pipeline\"] = whisper_pipe\n                else:\n                    # SNAC (Orpheus) uses FastLanguageModel\n                    model, tokenizer = FastLanguageModel.from_pretrained(\n                        model_name = config.path,\n                        max_seq_length = max_seq_length,\n                        load_in_4bit = False,\n                        token 
= hf_token if hf_token and hf_token.strip() else None,\n                        trust_remote_code = trust_remote_code,\n                    )\n                    FastLanguageModel.for_inference(model)\n                    self.models[model_name][\"model\"] = model\n                    self.models[model_name][\"tokenizer\"] = tokenizer\n\n                # Load the external codec for TTS audio types\n                # (Whisper is ASR, audio_vlm is audio input — neither needs a codec)\n                if audio_type not in (\"whisper\", \"audio_vlm\"):\n                    model_repo_path = self.models[model_name].get(\"model_repo_path\")\n                    self._audio_codec_manager.load_codec(\n                        audio_type, self.device, model_repo_path = model_repo_path\n                    )\n\n                self.active_model_name = model_name\n                self.loading_models.discard(model_name)\n                logger.info(f\"Successfully loaded audio model: {model_name}\")\n                log_gpu_memory(f\"After loading {model_name}\")\n                return True\n\n            model_type = \"vision\" if config.is_vision else \"text\"\n            adapter_info = (\n                \" (LoRA adapter)\" if self.models[model_name][\"is_lora\"] else \"\"\n            )\n            logger.info(f\"Loading {model_type} model{adapter_info}: {model_name}\")\n            log_gpu_memory(f\"Before loading {model_name}\")\n\n            # Load model - same approach for base models and LoRA adapters\n            if config.is_vision:\n                # Vision model (or vision LoRA adapter)\n                model, processor = FastVisionModel.from_pretrained(\n                    model_name = config.path,  # Can be base model OR LoRA adapter path\n                    max_seq_length = max_seq_length,\n                    dtype = dtype,\n                    load_in_4bit = load_in_4bit,\n                    token = hf_token if hf_token and hf_token.strip() else None,\n                    trust_remote_code = trust_remote_code,\n                )\n\n                # Apply inference optimization\n                FastVisionModel.for_inference(model)\n\n                # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast)\n                # instead of a proper Processor for some models (e.g. Gemma-3).\n                # In that case, load the real processor from the base model.\n                from transformers import ProcessorMixin\n\n                if not (\n                    isinstance(processor, ProcessorMixin)\n                    or hasattr(processor, \"image_processor\")\n                ):\n                    # For LoRA adapters, use the base model. 
For local merged exports,\n                    # read export_metadata.json to find the original base model.\n                    processor_source = (\n                        config.base_model if config.is_lora else config.identifier\n                    )\n                    if not config.is_lora and config.is_local:\n                        _meta_path = Path(config.path) / \"export_metadata.json\"\n                        try:\n                            if _meta_path.exists():\n                                _meta = json.loads(_meta_path.read_text())\n                                if _meta.get(\"base_model\"):\n                                    processor_source = _meta[\"base_model\"]\n                        except Exception:\n                            pass\n                    logger.warning(\n                        f\"FastVisionModel returned {type(processor).__name__} (no image_processor) \"\n                        f\"for '{model_name}' — loading proper processor from '{processor_source}'\"\n                    )\n                    from transformers import AutoProcessor\n\n                    processor = AutoProcessor.from_pretrained(\n                        processor_source,\n                        token = hf_token if hf_token and hf_token.strip() else None,\n                        trust_remote_code = trust_remote_code,\n                    )\n                    logger.info(\n                        f\"Loaded {type(processor).__name__} from {processor_source}\"\n                    )\n\n                self.models[model_name][\"model\"] = model\n                self.models[model_name][\"tokenizer\"] = processor\n                self.models[model_name][\"processor\"] = processor\n\n            else:\n                # Text model (or text LoRA adapter)\n                model, tokenizer = FastLanguageModel.from_pretrained(\n                    model_name = config.path,  # Can be base model OR LoRA adapter path\n                    max_seq_length = max_seq_length,\n                    dtype = dtype,\n                    load_in_4bit = load_in_4bit,\n                    token = hf_token if hf_token and hf_token.strip() else None,\n                    trust_remote_code = trust_remote_code,\n                )\n\n                # Apply inference optimization\n                FastLanguageModel.for_inference(model)\n\n                self.models[model_name][\"model\"] = model\n                self.models[model_name][\"tokenizer\"] = tokenizer\n\n            # Load chat template info\n            self._load_chat_template_info(model_name)\n\n            self.active_model_name = model_name\n            self.loading_models.discard(model_name)\n\n            logger.info(f\"Successfully loaded model: {model_name}\")\n            log_gpu_memory(f\"After loading {model_name}\")\n            return True\n\n        except Exception as e:\n            logger.error(f\"Failed to load model: {e}\")\n            error_msg = format_error_message(e, config.identifier)\n\n            # Cleanup on failure\n            if model_name in self.models:\n                del self.models[model_name]\n            self.loading_models.discard(model_name)\n\n            raise Exception(error_msg)\n\n    def unload_model(self, model_name: str) -> bool:\n        \"\"\"\n        Completely removes a model from the registry and clears GPU memory.\n        \"\"\"\n        if model_name in self.models:\n            try:\n                # If this was an audio model, clean up codecs\n                if 
self.models[model_name].get(\"is_audio\"):\n                    self._audio_codec_manager.unload()\n\n                logger.info(f\"Unloading model '{model_name}' from memory.\")\n                # Delete the model entry from our registry\n                del self.models[model_name]\n\n                # Clear the active model if it was the one being unloaded\n                if self.active_model_name == model_name:\n                    self.active_model_name = None\n\n                # Clear GPU memory cache\n                clear_gpu_cache()\n\n                # Remove stale compiled cache so the next model gets a fresh one\n                from utils.cache_cleanup import clear_unsloth_compiled_cache\n\n                clear_unsloth_compiled_cache()\n\n                logger.info(f\"Model '{model_name}' successfully unloaded.\")\n                return True\n            except Exception as e:\n                logger.error(f\"Error while unloading model '{model_name}': {e}\")\n                return False\n        else:\n            logger.warning(\n                f\"Attempted to unload model '{model_name}', but it was not found in the registry.\"\n            )\n            return True\n\n    def revert_to_base_model(self, base_model_name: str) -> bool:\n        \"\"\"\n        Reverts the model to its pristine base state by unloading AND\n        deleting all adapter configurations, as instructed.\n        \"\"\"\n        if base_model_name not in self.models:\n            return False\n\n        model = self.models[base_model_name].get(\"model\")\n\n        try:\n            # Step 1: Unload the adapter weights if model is a PeftModel.\n            if isinstance(model, (PeftModel, PeftModelForCausalLM)):\n                logger.info(f\"Unloading LoRA adapters from '{base_model_name}'...\")\n                unwrapped_base_model = model.unload()\n                self.models[base_model_name][\"model\"] = unwrapped_base_model\n                model = unwrapped_base_model\n\n            # Step 2: Clear any lingering peft_config from the unwrapped model.\n            # After model.unload(), the base model may still carry a peft_config\n            # attribute. Removing it ensures PeftModel.from_pretrained() gets\n            # a clean base model without \"multiple adapters\" warnings.\n            if hasattr(model, \"peft_config\"):\n                del model.peft_config\n\n            logger.info(f\"Model '{base_model_name}' reverted to clean base state.\")\n            return True\n\n        except Exception as e:\n            logger.error(f\"Failed to revert model to base state: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False\n\n    def load_for_eval(\n        self,\n        lora_path: str,\n        max_seq_length: int = 2048,\n        dtype = None,\n        load_in_4bit: bool = True,\n        hf_token: Optional[str] = None,\n    ) -> Tuple[bool, Optional[str], Optional[str]]:\n        \"\"\"\n        Final Corrected Version:\n        Ensures the base model and the specified adapter are loaded.\n        This function is idempotent and handles all states correctly.\n        \"\"\"\n        try:\n            from utils.models import ModelConfig\n\n            lora_config = ModelConfig.from_lora_path(lora_path, hf_token)\n            if not lora_config:\n                return False, None, None\n\n            base_model_name = lora_config.base_model\n\n            # 1. 
Load the base model if it's not already in memory\n            if base_model_name not in self.models or not self.models[\n                base_model_name\n            ].get(\"model\"):\n                logger.info(f\"Base model '{base_model_name}' not loaded, loading now.\")\n                base_config = ModelConfig.from_ui_selection(\n                    base_model_name, None, is_lora = False\n                )\n                if not self.load_model(\n                    base_config, max_seq_length, dtype, load_in_4bit, hf_token\n                ):\n                    return False, None, None\n\n            self.active_model_name = base_model_name\n\n            # 2. Determine the required adapter name from the user's selection\n            adapter_name = lora_path.split(\"/\")[-1].replace(\".\", \"_\")\n\n            # 3. Call our robust load_adapter function to ensure this specific adapter is loaded.\n            # It will only load from disk if the model doesn't already have it.\n            adapter_success = self.load_adapter(\n                base_model_name = base_model_name,\n                adapter_path = lora_path,\n                adapter_name = adapter_name,\n            )\n            if not adapter_success:\n                return False, base_model_name, None\n\n            # 4. Return the correct, verified adapter name for the UI logic to use.\n            return True, base_model_name, adapter_name\n\n        except Exception as e:\n            logger.error(f\"Error during load_for_eval: {e}\")\n            import traceback\n\n            logger.error(traceback.format_exc())\n            return False, None, None\n\n    def load_adapter(\n        self, base_model_name: str, adapter_path: str, adapter_name: str\n    ) -> bool:\n        \"\"\"\n        Loads an adapter onto the model ONLY if it's not already attached.\n        \"\"\"\n        model = self.models[base_model_name].get(\"model\")\n\n        # Check if this adapter name is already part of the model's config. This is the most reliable check.\n        if hasattr(model, \"peft_config\") and adapter_name in model.peft_config:\n            logger.info(\n                f\"Adapter '{adapter_name}' is already attached to the model. Skipping load.\"\n            )\n            return True\n\n        try:\n            logger.info(\n                f\"Loading new adapter '{adapter_name}' from '{adapter_path}' onto {base_model_name}\"\n            )\n            model.load_adapter(adapter_path, adapter_name = adapter_name)\n\n            # Update our internal registry ONLY after a successful load.\n            if \"loaded_adapters\" not in self.models[base_model_name]:\n                self.models[base_model_name][\"loaded_adapters\"] = {}\n            self.models[base_model_name][\"loaded_adapters\"][adapter_name] = adapter_path\n\n            total_adapters = len(getattr(model, \"peft_config\", {}))\n            logger.info(\n                f\"Adapter '{adapter_name}' loaded successfully. (Total unique adapters on model: {total_adapters})\"\n            )\n            return True\n        except Exception as e:\n            logger.error(f\"Failed to load adapter '{adapter_name}': {e}\")\n            return False\n\n    def set_active_adapter(self, base_model_name: str, adapter_name: str) -> bool:\n        \"\"\"\n        Sets the active adapter for generation. 
This replaces the flawed 'enable_adapter'.\n        \"\"\"\n        model = self.models[base_model_name].get(\"model\")\n        try:\n            logger.info(f\"Setting active adapter to: '{adapter_name}'\")\n            model.set_adapter(adapter_name)\n            self.models[base_model_name][\"active_adapter\"] = adapter_name\n            return True\n        except Exception as e:\n            # This will catch the \"adapter not found\" error if something goes wrong.\n            logger.error(f\"Failed to set active adapter to '{adapter_name}': {e}\")\n            return False\n\n    def _apply_adapter_state(self, use_adapter: Optional[Union[bool, str]]) -> None:\n        \"\"\"\n        Apply adapter state before generation. Must be called under _generation_lock.\n\n        Uses PEFT's disable_adapter_layers() / enable_adapter_layers() which toggle\n        a boolean flag on each LoRA layer. Unsloth's fast_linear_forward checks this\n        flag (proj.disable_adapters) and skips LoRA computation when True.\n        This is non-destructive — no model unloading/reloading needed.\n\n        Args:\n            use_adapter: None = no change, False = disable (base model),\n                         True = enable current adapter, str = enable specific adapter.\n        \"\"\"\n        if use_adapter is None:\n            return\n\n        base = self.active_model_name\n        if not base or base not in self.models:\n            return\n\n        model_info = self.models[base]\n        model = model_info.get(\"model\")\n        if model is None:\n            return\n\n        if use_adapter is False:\n            # Disable LoRA layers → base model output\n            if isinstance(model, (PeftModel, PeftModelForCausalLM)):\n                logger.info(\n                    f\"Compare mode: disabling adapters on '{base}' for base model generation\"\n                )\n                model.base_model.disable_adapter_layers()\n            else:\n                logger.info(\n                    f\"Compare mode: model '{base}' is not a PeftModel, already base\"\n                )\n\n        elif use_adapter is True:\n            # Re-enable LoRA layers → adapter output\n            if isinstance(model, (PeftModel, PeftModelForCausalLM)):\n                logger.info(\n                    f\"Compare mode: enabling adapters on '{base}' for LoRA generation\"\n                )\n                model.base_model.enable_adapter_layers()\n            else:\n                logger.warning(\"use_adapter=true but model is not a PeftModel\")\n\n        elif isinstance(use_adapter, str):\n            # Enable adapters and set the specific one active\n            if isinstance(model, (PeftModel, PeftModelForCausalLM)):\n                logger.info(\n                    f\"Compare mode: enabling adapter '{use_adapter}' on '{base}'\"\n                )\n                model.base_model.enable_adapter_layers()\n                self.set_active_adapter(base, use_adapter)\n            else:\n                logger.warning(\n                    f\"use_adapter='{use_adapter}' but model is not a PeftModel\"\n                )\n\n    def generate_with_adapter_control(\n        self,\n        use_adapter: Optional[Union[bool, str]] = None,\n        cancel_event = None,\n        **gen_kwargs,\n    ) -> Generator[str, None, None]:\n        \"\"\"\n        Thread-safe generation with optional adapter toggling.\n\n        The adapter toggle + model.generate() are serialized by _generation_lock\n        inside the 
background generation thread — NOT in the event-loop thread.\n        This prevents the RLock-reentrant race that occurs when two async SSE\n        handlers share the same event-loop thread.\n\n        Args:\n            use_adapter: Adapter control (None/False/True/str). See _apply_adapter_state.\n            **gen_kwargs: Forwarded to generate_chat_response.\n        \"\"\"\n        yield from self._generate_chat_response_inner(\n            cancel_event = cancel_event, _adapter_state = use_adapter, **gen_kwargs\n        )\n\n    def generate_chat_response(\n        self,\n        messages: list,\n        system_prompt: str,\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"\n        Generate response for text or vision models.\n        The generation lock is acquired by the background generation thread.\n        \"\"\"\n        yield from self._generate_chat_response_inner(\n            messages = messages,\n            system_prompt = system_prompt,\n            image = image,\n            temperature = temperature,\n            top_p = top_p,\n            top_k = top_k,\n            min_p = min_p,\n            max_new_tokens = max_new_tokens,\n            repetition_penalty = repetition_penalty,\n            cancel_event = cancel_event,\n        )\n\n    def _generate_chat_response_inner(\n        self,\n        messages: list,\n        system_prompt: str = \"\",\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n        _adapter_state = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"\n        Inner generation logic. Called by both generate_chat_response\n        and generate_with_adapter_control.\n\n        _adapter_state is passed to generate_stream/vision so the background\n        thread can toggle adapters under the generation lock.\n        \"\"\"\n        if not self.active_model_name:\n            yield \"Error: No active model\"\n            return\n\n        model_info = self.models[self.active_model_name]\n        is_vision = model_info.get(\"is_vision\", False)\n        tokenizer = model_info.get(\"tokenizer\") or model_info.get(\"processor\")\n        # Unwrap processor → raw tokenizer for VLMs on the text path\n        tokenizer = getattr(tokenizer, \"tokenizer\", tokenizer)\n        top_k = self._normalize_top_k(top_k)\n\n        if is_vision and image:\n            # Vision model generation (only when an image is actually provided)\n            # Check that the stored processor can actually handle images.\n            # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast)\n            # instead of a proper ProcessorMixin for some models (e.g. 
Gemma-3).\n            from transformers import ProcessorMixin\n\n            processor = model_info.get(\"processor\")\n            has_image_processing = processor is not None and (\n                isinstance(processor, ProcessorMixin)\n                or hasattr(processor, \"image_processor\")\n            )\n            if has_image_processing:\n                yield from self._generate_vision_response(\n                    messages,\n                    system_prompt,\n                    image,\n                    temperature,\n                    top_p,\n                    top_k,\n                    min_p,\n                    max_new_tokens,\n                    repetition_penalty,\n                    cancel_event = cancel_event,\n                )\n                return\n            else:\n                logger.warning(\n                    f\"Model '{self.active_model_name}' is marked as vision but its processor \"\n                    f\"({type(processor).__name__}) has no image_processor — \"\n                    f\"falling back to text-only generation (image will be ignored).\"\n                )\n\n        # Text path: Use training pipeline approach\n        # Messages are already in ChatML format from eval.py\n\n        # Step 1: Apply get_chat_template if model is in mapper\n        try:\n            from utils.datasets import (\n                MODEL_TO_TEMPLATE_MAPPER,\n                get_tokenizer_chat_template,\n            )\n\n            model_name_lower = self.active_model_name.lower()\n\n            # Check if model has a registered template\n            if model_name_lower in MODEL_TO_TEMPLATE_MAPPER:\n                template_name = MODEL_TO_TEMPLATE_MAPPER[model_name_lower]\n                logger.info(\n                    f\"Applying chat template '{template_name}' for {self.active_model_name}\"\n                )\n\n                # This modifies the tokenizer with the correct template\n                tokenizer = get_chat_template(\n                    tokenizer,\n                    chat_template = template_name,\n                )\n            else:\n                logger.info(\n                    f\"No registered Unsloth template for {self.active_model_name}, using tokenizer default\"\n                )\n        except Exception as e:\n            logger.warning(f\"Could not apply get_chat_template: {e}\")\n\n        # Step 2: Format with tokenizer.apply_chat_template()\n        try:\n            if not (hasattr(tokenizer, \"chat_template\") and tokenizer.chat_template):\n                raise ValueError(\n                    f\"Model '{self.active_model_name}' has no chat_template set in its \"\n                    f\"tokenizer_config.json. This is usually a problem with the model's \"\n                    f\"HuggingFace repository — it is missing a 'chat_template' key. 
\"\n                    f\"Please use a model that includes a chat template, or manually set \"\n                    f\"one via tokenizer.chat_template before inference.\"\n                )\n            formatted_prompt = tokenizer.apply_chat_template(\n                messages, tokenize = False, add_generation_prompt = True\n            )\n            logger.debug(f\"Formatted prompt: {formatted_prompt[:200]}...\")\n        except Exception as e:\n            logger.error(f\"Error applying chat template: {e}\")\n            # Fallback to manual formatting\n            formatted_prompt = self.format_chat_prompt(messages, system_prompt)\n\n        # Step 3: Generate\n        yield from self.generate_stream(\n            formatted_prompt,\n            temperature,\n            top_p,\n            top_k,\n            min_p,\n            max_new_tokens,\n            repetition_penalty,\n            cancel_event = cancel_event,\n            _adapter_state = _adapter_state,\n        )\n\n    def _generate_vision_response(\n        self,\n        messages,\n        system_prompt,\n        image,\n        temperature,\n        top_p,\n        top_k,\n        min_p,\n        max_new_tokens,\n        repetition_penalty,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Handle vision model generation with true token-by-token streaming.\"\"\"\n        model_info = self.models[self.active_model_name]\n        model = model_info[\"model\"]\n        processor = model_info[\"processor\"]\n        # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast)\n        # instead of a Processor for some models. Safe unwrap for tokenize-only ops.\n        raw_tokenizer = getattr(processor, \"tokenizer\", processor)\n\n        # Extract user message\n        user_message = \"\"\n        if messages and messages[-1][\"role\"] == \"user\":\n            import re\n\n            user_message = messages[-1][\"content\"]\n            user_message = re.sub(r\"<img[^>]*>\", \"\", user_message).strip()\n\n        if not user_message:\n            user_message = \"Describe this image.\" if image else \"Hello\"\n\n        # Prepare vision messages\n        if image:\n            vision_messages = [\n                {\n                    \"role\": \"user\",\n                    \"content\": [\n                        {\"type\": \"image\"},\n                        {\"type\": \"text\", \"text\": user_message},\n                    ],\n                }\n            ]\n\n            input_text = processor.apply_chat_template(\n                vision_messages, add_generation_prompt = True, tokenize = False\n            )\n            inputs = processor(\n                image,\n                input_text,\n                add_special_tokens = False,\n                return_tensors = \"pt\",\n            ).to(self.device)\n        else:\n            # Text-only for vision model\n            formatted_prompt = self.format_chat_prompt(messages, system_prompt)\n            inputs = raw_tokenizer(formatted_prompt, return_tensors = \"pt\").to(\n                self.device\n            )\n\n        # Stream with TextIteratorStreamer + background thread\n        try:\n            from transformers import TextIteratorStreamer\n            import threading\n\n            streamer = TextIteratorStreamer(\n                raw_tokenizer,\n                skip_prompt = True,\n                skip_special_tokens = True,\n                timeout = 0.2,\n            )\n\n            
generation_kwargs = dict(\n                **inputs,\n                streamer = streamer,\n                max_new_tokens = max_new_tokens,\n                use_cache = True,\n                do_sample = temperature > 0,\n                temperature = temperature,\n                top_p = top_p,\n                top_k = top_k,\n                min_p = min_p,\n            )\n\n            err: dict[str, str] = {}\n\n            def generate_fn():\n                with self._generation_lock:\n                    try:\n                        model.generate(**generation_kwargs)\n                    except Exception as e:\n                        err[\"msg\"] = str(e)\n                        logger.error(f\"Vision generation error in thread: {e}\")\n                    finally:\n                        try:\n                            streamer.end()\n                        except Exception:\n                            pass\n\n            thread = threading.Thread(target = generate_fn)\n            thread.start()\n\n            output = \"\"\n            from queue import Empty\n\n            generation_complete = False\n            try:\n                while True:\n                    if cancel_event is not None and cancel_event.is_set():\n                        break\n                    try:\n                        new_token = next(streamer)\n                    except StopIteration:\n                        generation_complete = True\n                        break\n                    except Empty:\n                        if not thread.is_alive():\n                            generation_complete = True\n                            break\n                        continue\n                    if new_token:\n                        output += new_token\n                        cleaned = self._clean_generated_text(output)\n                        yield cleaned\n            finally:\n                if cancel_event is not None and not generation_complete:\n                    cancel_event.set()\n                thread.join(timeout = 10)\n                if thread.is_alive():\n                    logger.warning(\n                        \"Vision generation thread did not exit after cancel/join timeout\"\n                    )\n\n            if err.get(\"msg\"):\n                yield f\"Error: {err['msg']}\"\n\n        except Exception as e:\n            logger.error(f\"Vision generation error: {e}\")\n            yield f\"Error: {str(e)}\"\n\n    def generate_audio_input_response(\n        self,\n        messages,\n        system_prompt,\n        audio_array,\n        temperature,\n        top_p,\n        top_k,\n        min_p,\n        max_new_tokens,\n        repetition_penalty,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Handle audio input (ASR) generation — accepts audio numpy array, streams text output.\n\n        Uses processor.apply_chat_template with audio embedded in messages (Gemma 3n pattern).\n        \"\"\"\n        import threading\n        import numpy as np\n\n        model_info = self.models[self.active_model_name]\n        model = model_info[\"model\"]\n        processor = model_info.get(\"processor\") or model_info.get(\"tokenizer\")\n        raw_tokenizer = getattr(processor, \"tokenizer\", processor)\n\n        # Extract last user text — default matches notebook prompt\n        user_text = \"Please transcribe this audio.\"\n        if messages:\n            for msg in reversed(messages):\n                if msg[\"role\"] == 
\"user\" and msg.get(\"content\"):\n                    user_text = msg[\"content\"]\n                    break\n\n        # Use ASR-specific system prompt if user hasn't set a custom one\n        if not system_prompt:\n            system_prompt = \"You are an assistant that transcribes speech accurately.\"\n\n        # Build messages in Gemma 3n format — audio goes INTO apply_chat_template\n        audio_messages = [\n            {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"audio\", \"audio\": audio_array},\n                    {\"type\": \"text\", \"text\": user_text},\n                ],\n            },\n        ]\n\n        # apply_chat_template handles audio embedding + tokenization in one step\n        inputs = processor.apply_chat_template(\n            audio_messages,\n            add_generation_prompt = True,\n            tokenize = True,\n            return_dict = True,\n            return_tensors = \"pt\",\n            truncation = False,\n        ).to(self.device)\n\n        try:\n            from transformers import TextIteratorStreamer\n            from queue import Empty\n\n            streamer = TextIteratorStreamer(\n                raw_tokenizer,\n                skip_prompt = True,\n                skip_special_tokens = True,\n                timeout = 0.2,\n            )\n\n            # Notebook uses do_sample=False for ASR (greedy decoding for accuracy)\n            generation_kwargs = dict(\n                **inputs,\n                streamer = streamer,\n                max_new_tokens = max_new_tokens,\n                use_cache = True,\n                do_sample = False,\n            )\n\n            err: dict[str, str] = {}\n\n            def generate_fn():\n                with self._generation_lock:\n                    try:\n                        model.generate(**generation_kwargs)\n                    except Exception as e:\n                        err[\"msg\"] = str(e)\n                        logger.error(f\"Audio input generation error in thread: {e}\")\n                    finally:\n                        try:\n                            streamer.end()\n                        except Exception:\n                            pass\n\n            thread = threading.Thread(target = generate_fn)\n            thread.start()\n\n            output = \"\"\n            try:\n                while True:\n                    if cancel_event is not None and cancel_event.is_set():\n                        break\n                    try:\n                        new_token = next(streamer)\n                    except StopIteration:\n                        break\n                    except Empty:\n                        if not thread.is_alive():\n                            break\n                        continue\n                    if new_token:\n                        output += new_token\n                        yield new_token\n            finally:\n                if cancel_event is not None:\n                    cancel_event.set()\n                thread.join(timeout = 10)\n                if thread.is_alive():\n                    logger.warning(\n                        \"Audio input generation thread did not exit after cancel/join timeout\"\n                    )\n\n            if err.get(\"msg\"):\n                yield f\"Error: {err['msg']}\"\n\n        except Exception as e:\n            
logger.error(f\"Audio input generation error: {e}\")\n            yield f\"Error: {str(e)}\"\n\n    def generate_whisper_response(\n        self, audio_array, cancel_event = None\n    ) -> Generator[str, None, None]:\n        \"\"\"Whisper ASR — takes audio numpy array, yields transcribed text.\n\n        Uses the pre-built transformers pipeline (created during model loading).\n        \"\"\"\n        model_info = self.models[self.active_model_name]\n        whisper_pipe = model_info.get(\"whisper_pipeline\")\n        if not whisper_pipe:\n            yield \"Error: Whisper pipeline not initialized\"\n            return\n\n        try:\n            with self._generation_lock:\n                result = whisper_pipe({\"raw\": audio_array, \"sampling_rate\": 16000})\n\n            text = result.get(\"text\", \"\") if isinstance(result, dict) else str(result)\n            if text:\n                yield text\n        except Exception as e:\n            logger.error(f\"Whisper ASR error: {e}\")\n            yield f\"Error: {str(e)}\"\n\n    def _is_gpt_oss_model(self, model_name: str = None) -> bool:\n        \"\"\"Check if the given (or active) model uses the gpt-oss harmony protocol.\"\"\"\n        name = (model_name or self.active_model_name or \"\").lower()\n        try:\n            from utils.datasets import MODEL_TO_TEMPLATE_MAPPER\n\n            # Exact match\n            if MODEL_TO_TEMPLATE_MAPPER.get(name) == \"gpt-oss\":\n                return True\n            # Partial match (e.g. name-bnb-4bit variants)\n            for key, tmpl in MODEL_TO_TEMPLATE_MAPPER.items():\n                if tmpl == \"gpt-oss\" and (key in name or name in key):\n                    return True\n        except Exception:\n            pass\n        return \"gpt-oss\" in name\n\n    def generate_stream(\n        self,\n        prompt: str,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n        _adapter_state = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Generate streaming text response (text models only).\n\n        _adapter_state: if not None, the background thread toggles adapters\n        before model.generate(), all under _generation_lock.\n        \"\"\"\n        if not self.active_model_name:\n            yield \"Error: No active model\"\n            return\n\n        model_info = self.models[self.active_model_name]\n        model = model_info[\"model\"]\n        # For VLMs the stored \"tokenizer\" is actually the processor.\n        # Unwrap to get the real tokenizer so TextIteratorStreamer's\n        # skip_prompt / skip_special_tokens work correctly.\n        tokenizer = model_info[\"tokenizer\"]\n        tokenizer = getattr(tokenizer, \"tokenizer\", tokenizer)\n\n        try:\n            inputs = tokenizer(prompt, return_tensors = \"pt\").to(model.device)\n\n            from transformers import TextIteratorStreamer\n            import threading\n\n            # Use HarmonyTextStreamer for gpt-oss models to properly parse\n            # the multi-channel harmony protocol into <think> tags\n            if self._is_gpt_oss_model():\n                try:\n                    streamer = HarmonyTextStreamer(\n                        tokenizer,\n                        skip_prompt = True,\n                        timeout = 0.2,\n                    )\n                except Exception as e:\n                  
  logger.warning(\n                        f\"HarmonyTextStreamer init failed, falling back: {e}\"\n                    )\n                    streamer = TextIteratorStreamer(\n                        tokenizer,\n                        skip_prompt = True,\n                        skip_special_tokens = True,\n                        timeout = 0.2,\n                    )\n            else:\n                streamer = TextIteratorStreamer(\n                    tokenizer,\n                    skip_prompt = True,\n                    skip_special_tokens = True,\n                    timeout = 0.2,\n                )\n\n            generation_kwargs = dict(\n                **inputs,\n                streamer = streamer,\n                max_new_tokens = max_new_tokens,\n                temperature = temperature,\n                top_p = top_p,\n                top_k = top_k,\n                min_p = min_p,\n                repetition_penalty = repetition_penalty,\n                do_sample = temperature > 0,\n                eos_token_id = tokenizer.eos_token_id,\n                pad_token_id = tokenizer.eos_token_id\n                if tokenizer.pad_token_id is None\n                else tokenizer.pad_token_id,\n            )\n            if cancel_event is not None:\n                from transformers.generation.stopping_criteria import (\n                    StoppingCriteria,\n                    StoppingCriteriaList,\n                )\n\n                class _CancelCriteria(StoppingCriteria):\n                    def __init__(self, ev):\n                        self.ev = ev\n\n                    def __call__(self, input_ids, scores, **kwargs):\n                        return self.ev.is_set()\n\n                generation_kwargs[\"stopping_criteria\"] = StoppingCriteriaList(\n                    [_CancelCriteria(cancel_event)]\n                )\n\n            def generate_fn():\n                with self._generation_lock:\n                    try:\n                        if _adapter_state is not None:\n                            self._apply_adapter_state(_adapter_state)\n                        model.generate(**generation_kwargs)\n                    except Exception as e:\n                        err[\"msg\"] = str(e)\n                        logger.error(f\"Generation error: {e}\")\n                    finally:\n                        try:\n                            streamer.end()\n                        except Exception:\n                            pass\n\n            err: dict[str, str] = {}\n            thread = threading.Thread(target = generate_fn)\n            thread.start()\n\n            output = \"\"\n            from queue import Empty\n\n            generation_complete = False\n            try:\n                while True:\n                    if cancel_event is not None and cancel_event.is_set():\n                        break\n                    try:\n                        new_token = next(streamer)\n                    except StopIteration:\n                        generation_complete = True\n                        break\n                    except Empty:\n                        if not thread.is_alive():\n                            generation_complete = True\n                            break\n                        continue\n                    if new_token:\n                        output += new_token\n                        cleaned = self._clean_generated_text(output)\n                        yield cleaned\n            finally:\n                # Only set 
cancel_event when we exited early (user cancel),\n                # NOT on normal completion.  cancel_event is a shared mp.Event\n                # — setting it unconditionally would leave a stale cancel\n                # signal that could interfere with the next serialized\n                # generation request (e.g. in compare mode).\n                if cancel_event is not None and not generation_complete:\n                    cancel_event.set()\n                thread.join(timeout = 10)\n                if thread.is_alive():\n                    logger.warning(\n                        \"Generation thread did not exit after cancel/join timeout\"\n                    )\n\n            if err.get(\"msg\"):\n                yield f\"Error: {err['msg']}\"\n\n        except Exception as e:\n            logger.error(f\"Error during generation: {e}\")\n            yield f\"Error: {str(e)}\"\n\n    # ── Audio (TTS) Generation ────────────────────────────────────\n\n    def generate_audio_response(\n        self,\n        text: str,\n        temperature: float = 0.6,\n        top_p: float = 0.95,\n        top_k: int = 50,\n        min_p: float = 0.0,\n        max_new_tokens: int = 2048,\n        repetition_penalty: float = 1.0,\n        use_adapter: Optional[Union[bool, str]] = None,\n    ) -> Tuple[bytes, int]:\n        \"\"\"\n        Generate audio from text for TTS models.\n        Returns (wav_bytes, sample_rate).\n        Blocking — generates complete audio before returning.\n        \"\"\"\n        if not self.active_model_name:\n            raise RuntimeError(\"No active model\")\n\n        model_info = self.models[self.active_model_name]\n        audio_type = model_info.get(\"audio_type\")\n        model = model_info[\"model\"]\n        tokenizer = model_info.get(\"tokenizer\")\n\n        if not audio_type:\n            raise RuntimeError(f\"Model {self.active_model_name} is not an audio model\")\n\n        top_k = self._normalize_top_k(top_k)\n\n        with self._generation_lock:\n            if use_adapter is not None:\n                self._apply_adapter_state(use_adapter)\n\n            if audio_type == \"snac\":\n                return self._generate_snac(\n                    model,\n                    tokenizer,\n                    text,\n                    temperature,\n                    top_p,\n                    max_new_tokens,\n                    repetition_penalty,\n                )\n            elif audio_type == \"csm\":\n                processor = model_info.get(\"processor\", tokenizer)\n                return self._generate_csm(model, processor, text, max_new_tokens)\n            elif audio_type == \"bicodec\":\n                return self._generate_bicodec(\n                    model, tokenizer, text, temperature, top_k, max_new_tokens\n                )\n            elif audio_type == \"dac\":\n                return self._generate_dac(\n                    model,\n                    tokenizer,\n                    text,\n                    temperature,\n                    top_k,\n                    top_p,\n                    min_p,\n                    max_new_tokens,\n                    repetition_penalty,\n                )\n            else:\n                raise RuntimeError(f\"Unknown audio_type: {audio_type}\")\n\n    def _generate_snac(\n        self,\n        model,\n        tokenizer,\n        text,\n        temperature,\n        top_p,\n        max_new_tokens,\n        repetition_penalty,\n    ):\n        \"\"\"Generate audio using SNAC 
codec (Orpheus).\"\"\"\n        device = model.device\n        start_token = torch.tensor([[128259]], device = device)  # START_OF_HUMAN\n        end_tokens = torch.tensor(\n            [[128009, 128260]], device = device\n        )  # EOT, END_OF_HUMAN\n        text_ids = tokenizer(text, return_tensors = \"pt\").input_ids.to(device)\n        input_ids = torch.cat([start_token, text_ids, end_tokens], dim = 1)\n        attention_mask = torch.ones_like(input_ids)\n\n        generated = model.generate(\n            input_ids = input_ids,\n            attention_mask = attention_mask,\n            max_new_tokens = max_new_tokens,\n            do_sample = True,\n            temperature = temperature,\n            top_p = top_p,\n            repetition_penalty = repetition_penalty,\n            eos_token_id = 128258,  # END_OF_SPEECH\n            use_cache = True,\n        )\n        return self._audio_codec_manager.decode_snac(generated, str(device))\n\n    def _generate_csm(self, model, processor, text, max_new_tokens):\n        \"\"\"Generate audio using CSM (Sesame).\"\"\"\n        speaker_id = 0\n        inputs = processor(\n            f\"[{speaker_id}]{text}\", add_special_tokens = True, return_tensors = \"pt\"\n        ).to(model.device)\n        audio_values = model.generate(\n            **inputs, max_new_tokens = max_new_tokens, output_audio = True\n        )\n        return self._audio_codec_manager.decode_csm(audio_values)\n\n    def _generate_bicodec(\n        self, model, tokenizer, text, temperature, top_k, max_new_tokens\n    ):\n        \"\"\"Generate audio using BiCodec (Spark-TTS).\"\"\"\n        prompt = (\n            \"<|task_tts|><|start_content|>\"\n            + text\n            + \"<|end_content|><|start_global_token|>\"\n        )\n        inputs = tokenizer([prompt], return_tensors = \"pt\").to(model.device)\n        generated = model.generate(\n            **inputs,\n            max_new_tokens = max_new_tokens,\n            do_sample = True,\n            temperature = temperature,\n            top_k = top_k,\n            eos_token_id = tokenizer.eos_token_id,\n            pad_token_id = tokenizer.pad_token_id,\n        )\n        new_tokens = generated[:, inputs.input_ids.shape[1] :]\n        decoded_text = tokenizer.batch_decode(new_tokens, skip_special_tokens = False)[0]\n        return self._audio_codec_manager.decode_bicodec(decoded_text, str(model.device))\n\n    def _generate_dac(\n        self,\n        model,\n        tokenizer,\n        text,\n        temperature,\n        top_k,\n        top_p,\n        min_p,\n        max_new_tokens,\n        repetition_penalty,\n    ):\n        \"\"\"Generate audio using DAC (OuteTTS). 
Follows Oute_TTS_(1B).ipynb exactly.\"\"\"\n        # Monkey-patch RepetitionPenaltyLogitsProcessor with a 64-token penalty\n        # window (same as the OuteTTS notebook) to avoid degenerate repetition.\n        self._patch_repetition_penalty_processor()\n\n        prompt = (\n            \"<|im_start|>\\n<|text_start|>\"\n            + text\n            + \"<|text_end|>\\n<|audio_start|><|global_features_start|>\\n\"\n        )\n        with torch.inference_mode():\n            with torch.amp.autocast(\"cuda\", dtype = model.dtype):\n                inputs = tokenizer([prompt], return_tensors = \"pt\").to(model.device)\n                generated = model.generate(\n                    **inputs,\n                    temperature = temperature,\n                    top_k = top_k,\n                    top_p = top_p,\n                    min_p = min_p,\n                    repetition_penalty = repetition_penalty,\n                    max_new_tokens = max_new_tokens,\n                )\n        decoded_text = tokenizer.batch_decode(generated, skip_special_tokens = False)[0]\n        return self._audio_codec_manager.decode_dac(decoded_text, str(model.device))\n\n    _repetition_penalty_patched = False\n\n    @classmethod\n    def _patch_repetition_penalty_processor(cls):\n        \"\"\"\n        Monkey-patch transformers' RepetitionPenaltyLogitsProcessor with a\n        64-token sliding window variant (from the OuteTTS notebook).\n        Only applied once per process.\n        \"\"\"\n        if cls._repetition_penalty_patched:\n            return\n        cls._repetition_penalty_patched = True\n\n        from transformers import LogitsProcessor\n        import transformers.generation.utils as generation_utils\n\n        class RepetitionPenaltyLogitsProcessorPatch(LogitsProcessor):\n            def __init__(self, penalty: float):\n                self.penalty_last_n = 64\n                if not isinstance(penalty, float) or penalty <= 0:\n                    raise ValueError(\n                        f\"`penalty` has to be a positive float, but is {penalty}\"\n                    )\n                self.penalty = penalty\n\n            @torch.no_grad()\n            def __call__(\n                self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n            ) -> torch.FloatTensor:\n                if self.penalty_last_n == 0 or self.penalty == 1.0:\n                    return scores\n                batch_size, seq_len = input_ids.shape\n                vocab_size = scores.shape[-1]\n                for b in range(batch_size):\n                    start_index = max(0, seq_len - self.penalty_last_n)\n                    window_indices = input_ids[b, start_index:]\n                    if window_indices.numel() == 0:\n                        continue\n                    for token_id in set(window_indices.tolist()):\n                        if token_id >= vocab_size:\n                            continue\n                        logit = scores[b, token_id]\n                        scores[b, token_id] = (\n                            logit * self.penalty if logit <= 0 else logit / self.penalty\n                        )\n                return scores\n\n        generation_utils.RepetitionPenaltyLogitsProcessor = (\n            RepetitionPenaltyLogitsProcessorPatch\n        )\n        logger.info(\n            \"Patched RepetitionPenaltyLogitsProcessor with 64-token window for OuteTTS\"\n        )\n\n    def format_chat_prompt(self, messages: list, system_prompt: str = None) -> str:\n    
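    \"\"\"Format chat messages into a prompt string via the tokenizer's chat template, falling back to manual formatting when that fails.\"\"\"\n    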
    if not self.active_model_name or self.active_model_name not in self.models:\n            logger.error(\"No active model available\")\n            return \"\"\n\n        if self.models[self.active_model_name].get(\"tokenizer\") is None:\n            logger.error(\"Tokenizer not loaded for active model\")\n            return \"\"\n\n        chat_template_info = self.models[self.active_model_name].get(\n            \"chat_template_info\", {}\n        )\n        tokenizer = self.models[self.active_model_name][\"tokenizer\"]\n        tokenizer = getattr(tokenizer, \"tokenizer\", tokenizer)\n\n        chat_messages = []\n\n        if system_prompt:\n            chat_messages.append({\"role\": \"system\", \"content\": system_prompt})\n\n        last_role = \"system\" if system_prompt else None\n\n        for msg in messages:\n            role = msg.get(\"role\", \"\")\n            content = msg.get(\"content\", \"\")\n\n            if role in [\"system\", \"user\", \"assistant\"] and content.strip():\n                if role == last_role:\n                    logger.debug(\n                        f\"Skipping consecutive {role} message to maintain alternation\"\n                    )\n                    continue\n\n                if role == \"user\":\n                    import re\n\n                    clean_content = re.sub(r\"<[^>]+>\", \"\", content).strip()\n                    if clean_content:\n                        chat_messages.append({\"role\": role, \"content\": clean_content})\n                        last_role = role\n                elif role == \"assistant\" and content.strip():\n                    chat_messages.append({\"role\": role, \"content\": content})\n                    last_role = role\n                elif role == \"system\":\n                    continue\n\n        if chat_messages and chat_messages[-1][\"role\"] == \"assistant\":\n            logger.debug(\n                \"Removing final assistant message to ensure proper alternation\"\n            )\n            chat_messages.pop()\n\n        logger.info(f\"Sending {len(chat_messages)} messages to tokenizer:\")\n        for i, msg in enumerate(chat_messages):\n            logger.info(f\"  {i}: {msg['role']} - {msg['content'][:50]}...\")\n\n        try:\n            formatted_prompt = tokenizer.apply_chat_template(\n                chat_messages, tokenize = False, add_generation_prompt = True\n            )\n            logger.info(f\"Successfully applied tokenizer's native chat template\")\n            return formatted_prompt\n        except Exception as e:\n            error_msg = str(e).lower()\n            if (\n                \"chat_template is not set\" in error_msg\n                or \"no template argument\" in error_msg\n            ):\n                logger.info(\n                    f\"Base model detected - no built-in chat template available, using fallback formatting\"\n                )\n            else:\n                logger.warning(f\"Failed to apply tokenizer chat template: {e}\")\n            logger.debug(\n                f\"\"\"Failed with messages: {[f\"{m['role']}: {m['content'][:30]}...\" for m in chat_messages]}\"\"\"\n            )\n\n        if chat_template_info.get(\"has_template\", False):\n            logger.info(\n                \"Falling back to manual template formatting based on detected patterns\"\n            )\n            template_type = chat_template_info.get(\"format_type\", \"generic\")\n            manual_prompt = self._format_chat_manual(\n                
chat_messages,\n                template_type,\n                chat_template_info.get(\"special_tokens\", {}),\n            )\n            logger.info(f\"Manual template result: {manual_prompt[:200]}...\")\n            return manual_prompt\n        else:\n            logger.info(\"Using generic chat formatting for base model\")\n            return self._format_generic_template(chat_messages, {})\n\n    def _format_chat_manual(\n        self, messages: list, template_type: str, special_tokens: dict\n    ) -> str:\n        \"\"\"\n        Manual chat formatting fallback for when tokenizer template fails\n\n        Args:\n            messages: List of message dictionaries\n            template_type: Detected template type\n            special_tokens: Dictionary of special tokens\n\n        Returns:\n            str: Manually formatted prompt\n        \"\"\"\n        if template_type == \"llama3\":\n            return self._format_llama3_template(messages, special_tokens)\n        elif template_type == \"mistral\":\n            return self._format_mistral_template(messages, special_tokens)\n        elif template_type == \"chatml\":\n            return self._format_chatml_template(messages, special_tokens)\n        elif template_type == \"alpaca\":\n            return self._format_alpaca_template(messages, special_tokens)\n        else:\n            return self._format_generic_template(messages, special_tokens)\n\n    def _format_llama3_template(self, messages: list, special_tokens: dict) -> str:\n        \"\"\"Format messages using Llama 3 template\"\"\"\n        bos_token = special_tokens.get(\"bos_token\", \"<|begin_of_text|>\")\n        formatted = bos_token\n\n        for msg in messages:\n            role = msg[\"role\"]\n            content = msg[\"content\"]\n            formatted += (\n                f\"<|start_header_id|>{role}<|end_header_id|>\\n\\n{content}<|eot_id|>\"\n            )\n\n        formatted += \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        return formatted\n\n    def _format_mistral_template(self, messages: list, special_tokens: dict) -> str:\n        \"\"\"Format messages using Mistral template\"\"\"\n        bos_token = special_tokens.get(\"bos_token\", \"<s>\")\n        formatted = bos_token\n\n        system_msg = None\n        conversation = []\n\n        for msg in messages:\n            if msg[\"role\"] == \"system\":\n                system_msg = msg[\"content\"]\n            else:\n                conversation.append(msg)\n\n        i = 0\n        while i < len(conversation):\n            if conversation[i][\"role\"] == \"user\":\n                user_content = conversation[i][\"content\"]\n\n                if system_msg and i == 0:\n                    user_content = f\"{system_msg}\\n\\n{user_content}\"\n\n                formatted += f\"[INST] {user_content} [/INST]\"\n\n                if (\n                    i + 1 < len(conversation)\n                    and conversation[i + 1][\"role\"] == \"assistant\"\n                ):\n                    formatted += f\" {conversation[i + 1]['content']}</s>\"\n                    i += 2\n                else:\n                    formatted += \" \"\n                    break\n            else:\n                i += 1\n\n        return formatted\n\n    def _format_chatml_template(self, messages: list, special_tokens: dict) -> str:\n        \"\"\"Format messages using ChatML template\"\"\"\n        formatted = \"\"\n\n        for msg in messages:\n            role = msg[\"role\"]\n            
content = msg[\"content\"]\n            formatted += f\"<|im_start|>{role}\\n{content}<|im_end|>\\n\"\n\n        formatted += \"<|im_start|>assistant\\n\"\n        return formatted\n\n    def _format_alpaca_template(self, messages: list, special_tokens: dict) -> str:\n        \"\"\"Format messages using Alpaca template\"\"\"\n        formatted = \"\"\n        system_msg = None\n\n        for msg in messages:\n            if msg[\"role\"] == \"system\":\n                system_msg = msg[\"content\"]\n            elif msg[\"role\"] == \"user\":\n                if system_msg:\n                    formatted += f\"### Instruction:\\n{system_msg}\\n\\n### Input:\\n{msg['content']}\\n\\n### Response:\\n\"\n                    system_msg = None\n                else:\n                    formatted += f\"### Human:\\n{msg['content']}\\n\\n### Assistant:\\n\"\n            elif msg[\"role\"] == \"assistant\":\n                formatted += f\"{msg['content']}\\n\\n\"\n\n        return formatted\n\n    def _format_generic_template(self, messages: list, special_tokens: dict) -> str:\n        \"\"\"Generic fallback formatting\"\"\"\n        formatted = \"\"\n\n        for msg in messages:\n            role = msg[\"role\"].title()\n            content = msg[\"content\"]\n            formatted += f\"{role}: {content}\\n\"\n\n        formatted += \"Assistant: \"\n        return formatted\n\n    def check_vision_model_compatibility(self) -> bool:\n        \"\"\"\n        Check if current model supports vision.\n\n        Returns:\n            bool: True if current model supports vision, False otherwise\n        \"\"\"\n        current_model = self.get_current_model()\n        if current_model and current_model in self.models:\n            return self.models[current_model].get(\"is_vision\", False)\n        return False\n\n    def _reset_model_generation_state(self, model_name: str):\n        \"\"\"Reset generation state for a specific model to prevent contamination.\"\"\"\n        if model_name not in self.models:\n            return\n\n        model = self.models[model_name].get(\"model\")\n        if not model:\n            return\n\n        try:\n            # This is a common pattern for Unsloth/Hugging Face models\n            if hasattr(model, \"past_key_values\"):\n                model.past_key_values = None\n            if hasattr(model, \"generation_config\"):\n                if hasattr(model.generation_config, \"past_key_values\"):\n                    model.generation_config.past_key_values = None\n\n            logger.debug(f\"Reset generation state for model: {model_name}\")\n        except Exception as e:\n            logger.warning(f\"Could not fully reset model state for {model_name}: {e}\")\n\n    def reset_generation_state(self):\n        \"\"\"Reset any cached generation state to prevent hanging after errors\"\"\"\n        try:\n            # Clear cached states for ALL loaded models\n            for model_name in self.models.keys():\n                self._reset_model_generation_state(model_name)\n\n            clear_gpu_cache()\n            logger.debug(\"Cleared GPU cache\")\n\n            import gc\n\n            gc.collect()\n            logger.info(\"Performed comprehensive generation state reset\")\n\n        except Exception as e:\n            logger.warning(f\"Could not fully reset generation state: {e}\")\n\n    def resize_image(self, img, max_size: int = 800):\n        \"\"\"Resize image while maintaining aspect ratio if either dimension exceeds max_size\"\"\"\n        if img 
is None:\n            return None\n        if img.size[0] > max_size or img.size[1] > max_size:\n            from PIL import Image\n\n            ratio = min(max_size / img.size[0], max_size / img.size[1])\n            new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))\n            return img.resize(new_size, Image.Resampling.LANCZOS)\n        return img\n\n    def _clean_generated_text(self, text: str) -> str:\n        \"\"\"Strip leaked special tokens using the tokenizer's own token list.\"\"\"\n        if self._is_gpt_oss_model():\n            # HarmonyTextStreamer produces clean <think>...</think> output.\n            # Strip harmony protocol tokens and other gpt-oss added tokens\n            # (e.g. <|return|>) that may leak past the streamer.\n            import re\n\n            text = re.sub(r\"<\\|[a-z_]+\\|>\", \"\", text)\n            return text.strip()\n\n        tokenizer = self.models.get(self.active_model_name, {}).get(\"tokenizer\")\n        if tokenizer:\n            for token in getattr(tokenizer, \"all_special_tokens\", []):\n                if token in text:\n                    text = text.replace(token, \"\")\n        return text.strip()\n\n    def _load_chat_template_info(self, model_name: str):\n        if model_name not in self.models or not self.models[model_name].get(\n            \"tokenizer\"\n        ):\n            return\n\n        tokenizer = self.models[model_name][\"tokenizer\"]\n        chat_template_info = {\n            \"has_template\": False,\n            \"template\": None,\n            \"format_type\": \"generic\",\n            \"special_tokens\": {},\n            \"template_name\": None,\n        }\n\n        try:\n            from utils.datasets import MODEL_TO_TEMPLATE_MAPPER\n\n            # Try exact match first\n            model_name_lower = model_name.lower()\n            if model_name_lower in MODEL_TO_TEMPLATE_MAPPER:\n                chat_template_info[\"template_name\"] = MODEL_TO_TEMPLATE_MAPPER[\n                    model_name_lower\n                ]\n                logger.info(\n                    f\"Detected template '{chat_template_info['template_name']}' for {model_name} from mapper\"\n                )\n            else:\n                # Try partial match (for variants like model_name-bnb-4bit)\n                for key in MODEL_TO_TEMPLATE_MAPPER:\n                    if key in model_name_lower or model_name_lower in key:\n                        chat_template_info[\"template_name\"] = MODEL_TO_TEMPLATE_MAPPER[\n                            key\n                        ]\n                        logger.info(\n                            f\"Detected template '{chat_template_info['template_name']}' for {model_name} (partial match)\"\n                        )\n                        break\n        except Exception as e:\n            logger.warning(\n                f\"Could not detect template from mapper for {model_name}: {e}\"\n            )\n\n        try:\n            if hasattr(tokenizer, \"chat_template\") and tokenizer.chat_template:\n                chat_template_info[\"has_template\"] = True\n                chat_template_info[\"template\"] = tokenizer.chat_template\n\n                template_str = tokenizer.chat_template.lower()\n\n                if (\n                    \"start_header_id\" in template_str\n                    and \"end_header_id\" in template_str\n                ):\n                    chat_template_info[\"format_type\"] = \"llama3\"\n                elif \"[inst]\" in template_str 
and \"[/inst]\" in template_str:\n                    chat_template_info[\"format_type\"] = \"mistral\"\n                elif \"<|im_start|>\" in template_str and \"<|im_end|>\" in template_str:\n                    chat_template_info[\"format_type\"] = \"chatml\"\n                elif \"### instruction:\" in template_str or \"### human:\" in template_str:\n                    chat_template_info[\"format_type\"] = \"alpaca\"\n                else:\n                    chat_template_info[\"format_type\"] = \"custom\"\n\n                logger.info(\n                    f\"Loaded chat template for {model_name} (detected as {chat_template_info['format_type']} format)\"\n                )\n                logger.debug(f\"Template preview: {tokenizer.chat_template[:200]}...\")\n\n                special_tokens = {}\n                if hasattr(tokenizer, \"bos_token\") and tokenizer.bos_token:\n                    special_tokens[\"bos_token\"] = tokenizer.bos_token\n                if hasattr(tokenizer, \"eos_token\") and tokenizer.eos_token:\n                    special_tokens[\"eos_token\"] = tokenizer.eos_token\n                if hasattr(tokenizer, \"pad_token\") and tokenizer.pad_token:\n                    special_tokens[\"pad_token\"] = tokenizer.pad_token\n\n                chat_template_info[\"special_tokens\"] = special_tokens\n\n            else:\n                logger.info(\n                    f\"No chat template found for {model_name}, will use generic formatting\"\n                )\n\n        except Exception as e:\n            logger.error(f\"Error loading chat template info for {model_name}: {e}\")\n\n        self.models[model_name][\"chat_template_info\"] = chat_template_info\n\n        if chat_template_info[\"has_template\"]:\n            logger.info(\n                f\"Chat template loaded for {model_name}: {chat_template_info['format_type']} format\"\n            )\n        else:\n            logger.info(\n                f\"No built-in chat template for {model_name}, will use generic formatting\"\n            )\n\n    def get_current_model(self) -> Optional[str]:\n        \"\"\"Get currently active model name\"\"\"\n        return self.active_model_name\n\n    def is_model_loading(self) -> bool:\n        \"\"\"Check if any model is currently loading\"\"\"\n        return len(self.loading_models) > 0\n\n    def get_loading_model(self) -> Optional[str]:\n        \"\"\"Get name of currently loading model\"\"\"\n        return next(iter(self.loading_models)) if self.loading_models else None\n\n    def load_model_simple(\n        self,\n        model_path: str,\n        hf_token: Optional[str] = None,\n        max_seq_length: int = 2048,\n        load_in_4bit: bool = True,\n    ) -> bool:\n        \"\"\"\n        Simple model loading wrapper for chat interface.\n        Accepts model path as string and handles ModelConfig creation internally.\n\n        Args:\n            model_path: Model name or path (e.g., \"unsloth/llama-3-8b\")\n            hf_token: HuggingFace token for gated models\n            max_seq_length: Maximum sequence length\n            load_in_4bit: Whether to use 4-bit quantization\n\n        Returns:\n            bool: True if successful, False otherwise\n        \"\"\"\n        try:\n            # Create config from string path\n            config = ModelConfig.from_ui_selection(\n                model_path,\n                lora_path = None,  # No LoRA for chat\n                is_lora = False,\n            )\n\n            # Call existing load_model 
with config\n            return self.load_model(\n                config = config,\n                max_seq_length = max_seq_length,\n                dtype = None,  # Auto-detect\n                load_in_4bit = load_in_4bit,\n                hf_token = hf_token,\n            )\n\n        except Exception as e:\n            logger.error(f\"Error in load_model_simple: {e}\")\n            return False\n\n\n# Global inference backend instance\ninference_backend = InferenceBackend()\n\n\ndef get_inference_backend() -> InferenceBackend:\n    return inference_backend\n"
  },
  {
    "path": "studio/backend/core/inference/llama_cpp.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nllama-server inference backend for GGUF models.\n\nManages a llama-server subprocess and proxies chat completions\nthrough its OpenAI-compatible /v1/chat/completions endpoint.\n\"\"\"\n\nimport atexit\nimport contextlib\nimport json\nimport struct\nimport structlog\nfrom loggers import get_logger\nimport shutil\nimport signal\nimport socket\nimport subprocess\nimport threading\nimport time\nfrom pathlib import Path\nfrom typing import Generator, Optional\n\nimport httpx\n\nlogger = get_logger(__name__)\n\n\nclass LlamaCppBackend:\n    \"\"\"\n    Manages a llama-server subprocess for GGUF model inference.\n\n    Lifecycle:\n        1. load_model()  — starts llama-server with the GGUF file\n        2. generate_chat_completion() — proxies to /v1/chat/completions, streams back\n        3. unload_model() — terminates llama-server subprocess\n    \"\"\"\n\n    def __init__(self):\n        self._process: Optional[subprocess.Popen] = None\n        self._port: Optional[int] = None\n        self._model_identifier: Optional[str] = None\n        self._gguf_path: Optional[str] = None\n        self._hf_repo: Optional[str] = None\n        self._hf_variant: Optional[str] = None\n        self._is_vision: bool = False\n        self._healthy = False\n        self._context_length: Optional[int] = None\n        self._chat_template: Optional[str] = None\n        self._supports_reasoning: bool = False\n        self._supports_tools: bool = False\n        self._cache_type_kv: Optional[str] = None\n        self._reasoning_default: bool = True\n        self._lock = threading.Lock()\n        self._stdout_lines: list[str] = []\n        self._stdout_thread: Optional[threading.Thread] = None\n        self._cancel_event = threading.Event()\n\n        self._kill_orphaned_servers()\n        atexit.register(self._cleanup)\n\n    # ── Properties ────────────────────────────────────────────────\n\n    @property\n    def is_loaded(self) -> bool:\n        return self._process is not None and self._healthy\n\n    @property\n    def is_active(self) -> bool:\n        \"\"\"True if a llama-server process exists (loading or loaded).\"\"\"\n        return self._process is not None\n\n    @property\n    def base_url(self) -> str:\n        return f\"http://127.0.0.1:{self._port}\"\n\n    @property\n    def model_identifier(self) -> Optional[str]:\n        return self._model_identifier\n\n    @property\n    def is_vision(self) -> bool:\n        return self._is_vision\n\n    @property\n    def hf_variant(self) -> Optional[str]:\n        return self._hf_variant\n\n    @property\n    def context_length(self) -> Optional[int]:\n        return self._context_length\n\n    @property\n    def chat_template(self) -> Optional[str]:\n        return self._chat_template\n\n    @property\n    def supports_reasoning(self) -> bool:\n        return self._supports_reasoning\n\n    @property\n    def reasoning_default(self) -> bool:\n        return self._reasoning_default\n\n    @property\n    def supports_tools(self) -> bool:\n        return self._supports_tools\n\n    @property\n    def cache_type_kv(self) -> Optional[str]:\n        return self._cache_type_kv\n\n    # ── Binary discovery ──────────────────────────────────────────\n\n    @staticmethod\n    def _find_llama_server_binary() -> Optional[str]:\n        \"\"\"\n        Locate the llama-server binary.\n\n        Search 
order:\n        1.  LLAMA_SERVER_PATH environment variable (direct path to binary)\n        1b. UNSLOTH_LLAMA_CPP_PATH env var (custom llama.cpp install dir)\n        2.  ~/.unsloth/llama.cpp/llama-server        (make build, root dir)\n        3.  ~/.unsloth/llama.cpp/build/bin/llama-server  (cmake build, Linux)\n        4.  ~/.unsloth/llama.cpp/build/bin/Release/llama-server.exe  (cmake build, Windows)\n        5.  ./llama.cpp/llama-server                 (legacy: make build, root dir)\n        6.  ./llama.cpp/build/bin/llama-server        (legacy: cmake in-tree build)\n        7.  llama-server on PATH                     (system install)\n        8.  ./bin/llama-server                       (legacy: extracted binary)\n        \"\"\"\n        import os\n        import sys\n\n        binary_name = \"llama-server.exe\" if sys.platform == \"win32\" else \"llama-server\"\n\n        # 1. Env var — direct path to binary\n        env_path = os.environ.get(\"LLAMA_SERVER_PATH\")\n        if env_path and Path(env_path).is_file():\n            return env_path\n\n        # 1b. UNSLOTH_LLAMA_CPP_PATH — custom llama.cpp install directory\n        custom_llama_cpp = os.environ.get(\"UNSLOTH_LLAMA_CPP_PATH\")\n        if custom_llama_cpp:\n            custom_dir = Path(custom_llama_cpp)\n            # Root dir (make builds)\n            root_bin = custom_dir / binary_name\n            if root_bin.is_file():\n                return str(root_bin)\n            # build/bin/ (cmake builds on Linux)\n            cmake_bin = custom_dir / \"build\" / \"bin\" / binary_name\n            if cmake_bin.is_file():\n                return str(cmake_bin)\n            # build/bin/Release/ (cmake builds on Windows)\n            if sys.platform == \"win32\":\n                win_bin = custom_dir / \"build\" / \"bin\" / \"Release\" / binary_name\n                if win_bin.is_file():\n                    return str(win_bin)\n\n        # 2–4. ~/.unsloth/llama.cpp (primary — setup.sh / setup.ps1 build here)\n        unsloth_home = Path.home() / \".unsloth\" / \"llama.cpp\"\n        # Root dir (make builds copy binaries here)\n        home_root = unsloth_home / binary_name\n        if home_root.is_file():\n            return str(home_root)\n        # build/bin/ (cmake builds on Linux)\n        home_linux = unsloth_home / \"build\" / \"bin\" / binary_name\n        if home_linux.is_file():\n            return str(home_linux)\n\n        # 3. Windows MSVC build has Release subdir\n        if sys.platform == \"win32\":\n            home_win = unsloth_home / \"build\" / \"bin\" / \"Release\" / binary_name\n            if home_win.is_file():\n                return str(home_win)\n\n        # 5–6. Legacy: in-tree build (older setup.sh / setup.ps1 versions)\n        project_root = Path(__file__).resolve().parents[4]\n        # Root dir (make builds)\n        root_path = project_root / \"llama.cpp\" / binary_name\n        if root_path.is_file():\n            return str(root_path)\n        # build/bin/ (cmake builds)\n        build_path = project_root / \"llama.cpp\" / \"build\" / \"bin\" / binary_name\n        if build_path.is_file():\n            return str(build_path)\n        if sys.platform == \"win32\":\n            win_path = (\n                project_root / \"llama.cpp\" / \"build\" / \"bin\" / \"Release\" / binary_name\n            )\n            if win_path.is_file():\n                return str(win_path)\n\n        # 7. 
System PATH\n        system_path = shutil.which(\"llama-server\")\n        if system_path:\n            return system_path\n\n        # 8. Legacy: extracted to bin/\n        bin_path = project_root / \"bin\" / binary_name\n        if bin_path.is_file():\n            return str(bin_path)\n\n        return None\n\n    # ── GPU allocation ────────────────────────────────────────────\n\n    @staticmethod\n    def _get_gguf_size_bytes(model_path: str) -> int:\n        \"\"\"Get total GGUF size in bytes, including split shards.\"\"\"\n        import re\n\n        main = Path(model_path)\n        total = main.stat().st_size\n\n        # Check for split shards (e.g., model-00001-of-00003.gguf)\n        shard_pat = re.compile(r\"^(.*)-(\\d{5})-of-(\\d{5})\\.gguf$\")\n        m = shard_pat.match(main.name)\n        if m:\n            prefix, _, num_total = m.group(1), m.group(2), m.group(3)\n            sibling_pat = re.compile(\n                r\"^\"\n                + re.escape(prefix)\n                + r\"-\\d{5}-of-\"\n                + re.escape(num_total)\n                + r\"\\.gguf$\"\n            )\n            for sibling in main.parent.iterdir():\n                if sibling != main and sibling_pat.match(sibling.name):\n                    total += sibling.stat().st_size\n\n        return total\n\n    @staticmethod\n    def _get_gpu_free_memory() -> list[tuple[int, int]]:\n        \"\"\"Query free memory per GPU via nvidia-smi.\n\n        Returns list of (gpu_index, free_mib) sorted by index.\n        Respects CUDA_VISIBLE_DEVICES if set.\n        Returns empty list if nvidia-smi is not available.\n        \"\"\"\n        import os\n\n        try:\n            result = subprocess.run(\n                [\n                    \"nvidia-smi\",\n                    \"--query-gpu=index,memory.free\",\n                    \"--format=csv,noheader,nounits\",\n                ],\n                capture_output = True,\n                text = True,\n                timeout = 10,\n            )\n            if result.returncode != 0:\n                return []\n\n            # Parse which GPUs are allowed by existing CUDA_VISIBLE_DEVICES\n            allowed = None\n            cvd = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n            if cvd is not None and cvd.strip():\n                try:\n                    allowed = set(int(x.strip()) for x in cvd.split(\",\"))\n                except ValueError:\n                    pass  # Non-numeric (e.g., \"GPU-uuid\"), ignore filter\n\n            gpus = []\n            for line in result.stdout.strip().splitlines():\n                parts = line.split(\",\")\n                if len(parts) == 2:\n                    idx = int(parts[0].strip())\n                    free_mib = int(parts[1].strip())\n                    if allowed is not None and idx not in allowed:\n                        continue\n                    gpus.append((idx, free_mib))\n            return gpus\n        except Exception:\n            return []\n\n    @staticmethod\n    def _select_gpus(\n        model_size_bytes: int,\n        gpus: list[tuple[int, int]],\n    ) -> tuple[Optional[list[int]], bool]:\n        \"\"\"Pick GPU(s) for a model based on file size and free memory.\n\n        Uses GGUF file size as a rough proxy for VRAM usage (actual usage\n        is higher due to KV cache and compute buffers, but 70% threshold\n        accounts for that).\n\n        Returns (gpu_indices, use_fit):\n          - ([1], False)       model fits on 1 GPU at 70% of free\n          - ([1, 
2], False)    model needs 2 GPUs\n          - (None, True)       model too large, let --fit handle it\n        \"\"\"\n        if not gpus:\n            return None, True\n\n        model_size_mib = model_size_bytes / (1024 * 1024)\n\n        # Sort GPUs by free memory descending\n        ranked = sorted(gpus, key = lambda g: g[1], reverse = True)\n\n        # Try fitting on 1 GPU (70% of free memory threshold)\n        if ranked[0][1] * 0.70 >= model_size_mib:\n            return [ranked[0][0]], False\n\n        # Try fitting on N GPUs (accumulate free memory from most-free)\n        cumulative = 0\n        selected = []\n        for idx, free_mib in ranked:\n            selected.append(idx)\n            cumulative += free_mib * 0.70\n            if cumulative >= model_size_mib:\n                return sorted(selected), False\n\n        # Model is too large even for all GPUs, let --fit handle it\n        return None, True\n\n    # ── Variant fallback ────────────────────────────────────────────\n\n    @staticmethod\n    def _find_smallest_fitting_variant(\n        hf_repo: str,\n        free_bytes: int,\n        hf_token: Optional[str] = None,\n    ) -> Optional[tuple[str, int]]:\n        \"\"\"Find the smallest GGUF variant (including all shards) that fits.\n\n        Groups split shards by variant prefix and sums their sizes.\n        For example, UD-Q4_K_XL with 9 shards of 50 GB each = 450 GB total.\n\n        Returns (first_shard_filename, total_size_bytes) or None if nothing fits.\n        \"\"\"\n        import re\n\n        try:\n            from huggingface_hub import get_paths_info, list_repo_files\n\n            files = list_repo_files(hf_repo, token = hf_token)\n            gguf_files = [\n                f for f in files if f.endswith(\".gguf\") and \"mmproj\" not in f.lower()\n            ]\n            if not gguf_files:\n                return None\n\n            # Get sizes for all GGUF files\n            path_infos = list(get_paths_info(hf_repo, gguf_files, token = hf_token))\n            size_map = {p.path: (p.size or 0) for p in path_infos}\n\n            # Group files by variant: shards share a prefix before -NNNNN-of-NNNNN\n            shard_pat = re.compile(r\"^(.*)-\\d{5}-of-\\d{5}\\.gguf$\")\n            variants: dict[str, list[str]] = {}\n            for f in gguf_files:\n                m = shard_pat.match(f)\n                key = m.group(1) if m else f\n                variants.setdefault(key, []).append(f)\n\n            # Sum shard sizes per variant, track the first shard (for download)\n            variant_sizes: list[tuple[str, int, list[str]]] = []\n            for key, shard_files in variants.items():\n                total = sum(size_map.get(f, 0) for f in shard_files)\n                first = sorted(shard_files)[0]\n                variant_sizes.append((first, total, shard_files))\n\n            # Sort by total size ascending and pick the smallest that fits\n            variant_sizes.sort(key = lambda x: x[1])\n            for first_file, total_size, _ in variant_sizes:\n                if total_size > 0 and total_size <= free_bytes:\n                    return first_file, total_size\n\n            return None\n        except Exception:\n            return None\n\n    # ── Port allocation ───────────────────────────────────────────\n\n    @staticmethod\n    def _find_free_port() -> int:\n        \"\"\"Find an available TCP port.\"\"\"\n        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n            s.bind((\"127.0.0.1\", 0))\n        
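    # Port 0 asks the OS for an unused ephemeral port; it is read back before the socket closes.\n        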
    return s.getsockname()[1]\n\n    # ── Stdout drain (prevents pipe deadlock on Windows) ─────────\n\n    def _drain_stdout(self):\n        \"\"\"\n        Read lines from the subprocess stdout in a background thread.\n\n        This prevents a pipe-buffer deadlock on Windows where the default\n        pipe buffer is only ~4 KB.  Without draining, llama-server blocks\n        on writes and never becomes healthy.\n        \"\"\"\n        try:\n            for line in self._process.stdout:\n                line = line.rstrip()\n                if line:\n                    self._stdout_lines.append(line)\n                    logger.debug(f\"[llama-server] {line}\")\n        except (ValueError, OSError):\n            # Pipe closed — process is terminating\n            pass\n\n    # GGUF KV type sizes for fast skipping\n    _GGUF_TYPE_SIZE = {\n        0: 1,\n        1: 1,\n        2: 2,\n        3: 2,\n        4: 4,\n        5: 4,\n        6: 4,\n        7: 1,\n        10: 8,\n        11: 8,\n        12: 8,\n    }\n\n    @staticmethod\n    def _gguf_skip_value(f, vtype: int) -> None:\n        \"\"\"Skip a GGUF KV value without reading it.\"\"\"\n        sz = LlamaCppBackend._GGUF_TYPE_SIZE.get(vtype)\n        if sz is not None:\n            f.seek(sz, 1)\n        elif vtype == 8:  # STRING\n            slen = struct.unpack(\"<Q\", f.read(8))[0]\n            f.seek(slen, 1)\n        elif vtype == 9:  # ARRAY\n            atype = struct.unpack(\"<I\", f.read(4))[0]\n            alen = struct.unpack(\"<Q\", f.read(8))[0]\n            elem_sz = LlamaCppBackend._GGUF_TYPE_SIZE.get(atype)\n            if elem_sz is not None:\n                f.seek(elem_sz * alen, 1)\n            elif atype == 8:\n                for _ in range(alen):\n                    slen = struct.unpack(\"<Q\", f.read(8))[0]\n                    f.seek(slen, 1)\n            else:\n                for _ in range(alen):\n                    LlamaCppBackend._gguf_skip_value(f, atype)\n\n    def _read_gguf_metadata(self, gguf_path: str) -> None:\n        \"\"\"Read context_length and chat_template from a GGUF file's KV header.\n\n        Parses only the KV pairs we need (~30ms even for multi-GB files).\n        For split GGUFs, metadata is always in shard 1.\n        \"\"\"\n        # Reset metadata from any previously loaded model so stale flags\n        # (eg _supports_reasoning) do not carry over when switching models.\n        self._context_length = None\n        self._chat_template = None\n        self._supports_reasoning = False\n        self._supports_tools = False\n\n        try:\n            WANTED = {\"general.architecture\", \"tokenizer.chat_template\"}\n            arch = None\n            ctx_key = None\n\n            with open(gguf_path, \"rb\") as f:\n                magic = struct.unpack(\"<I\", f.read(4))[0]\n                if magic != 0x46554747:  # b\"GGUF\" as little-endian u32\n                    return\n                _version = struct.unpack(\"<I\", f.read(4))[0]\n                _tensor_count, kv_count = struct.unpack(\"<QQ\", f.read(16))\n\n                for _ in range(kv_count):\n                    key_len = struct.unpack(\"<Q\", f.read(8))[0]\n                    key = f.read(key_len).decode(\"utf-8\")\n                    vtype = struct.unpack(\"<I\", f.read(4))[0]\n\n                    if key in WANTED or (ctx_key and key == ctx_key):\n                        # Read this value\n                        if vtype == 8:  # STRING\n                            slen = struct.unpack(\"<Q\", 
f.read(8))[0]\n                            val_s = f.read(slen).decode(\"utf-8\")\n                            if key == \"general.architecture\":\n                                arch = val_s\n                                ctx_key = f\"{arch}.context_length\"\n                            elif key == \"tokenizer.chat_template\":\n                                self._chat_template = val_s\n                        elif vtype == 4:  # UINT32\n                            val_i = struct.unpack(\"<I\", f.read(4))[0]\n                            if ctx_key and key == ctx_key:\n                                self._context_length = val_i\n                        elif vtype == 10:  # UINT64\n                            val_i = struct.unpack(\"<Q\", f.read(8))[0]\n                            if ctx_key and key == ctx_key:\n                                self._context_length = val_i\n                        else:\n                            self._gguf_skip_value(f, vtype)\n                    else:\n                        self._gguf_skip_value(f, vtype)\n\n            if self._context_length:\n                logger.info(f\"GGUF metadata: context_length={self._context_length}\")\n            if self._chat_template:\n                logger.info(\n                    f\"GGUF metadata: chat_template={len(self._chat_template)} chars\"\n                )\n                # Detect thinking/reasoning support from chat template\n                tpl = self._chat_template\n                if \"enable_thinking\" in tpl:\n                    self._supports_reasoning = True\n                    logger.info(\n                        \"GGUF metadata: model supports reasoning (enable_thinking)\"\n                    )\n                elif \"thinking\" in tpl:\n                    # DeepSeek uses 'thinking' instead of 'enable_thinking'\n                    normalized_id = (self._model_identifier or \"\").lower()\n                    if \"deepseek\" in normalized_id:\n                        self._supports_reasoning = True\n                        logger.info(\n                            \"GGUF metadata: model supports reasoning (DeepSeek thinking)\"\n                        )\n                # Detect tool calling support from chat template\n                tool_markers = [\n                    \"{%- if tools %}\",\n                    \"{% if tools %}\",\n                    '\"role\" == \"tool\"',\n                    \"'role' == 'tool'\",\n                    'message.role == \"tool\"',\n                    \"message.role == 'tool'\",\n                ]\n                if any(marker in tpl for marker in tool_markers):\n                    self._supports_tools = True\n                    logger.info(\"GGUF metadata: model supports tool calling\")\n        except Exception as e:\n            logger.warning(f\"Failed to read GGUF metadata: {e}\")\n\n    # ── HF download (no lock held) ───────────────────────────────\n\n    def _download_gguf(\n        self,\n        *,\n        hf_repo: str,\n        hf_variant: Optional[str] = None,\n        hf_token: Optional[str] = None,\n    ) -> str:\n        \"\"\"Download GGUF file(s) from HuggingFace. Returns local path.\n\n        Runs WITHOUT self._lock so that unload_model() can set\n        _cancel_event at any time. 
Checks _cancel_event between\n        each shard download.\n        \"\"\"\n        try:\n            from huggingface_hub import hf_hub_download\n        except ImportError:\n            raise RuntimeError(\n                \"huggingface_hub is required for HF model loading. \"\n                \"Install it with: pip install huggingface_hub\"\n            )\n\n        # Determine the filename from the variant\n        gguf_filename = None\n        gguf_extra_shards: list[str] = []\n        if hf_variant:\n            try:\n                import re\n                from huggingface_hub import list_repo_files\n\n                files = list_repo_files(hf_repo, token = hf_token)\n                variant_lower = hf_variant.lower()\n                boundary = re.compile(\n                    r\"(?<![a-zA-Z0-9])\" + re.escape(variant_lower) + r\"(?![a-zA-Z0-9])\"\n                )\n                gguf_files = sorted(\n                    f\n                    for f in files\n                    if f.endswith(\".gguf\") and boundary.search(f.lower())\n                )\n                if gguf_files:\n                    gguf_filename = gguf_files[0]\n                    shard_pat = re.compile(r\"^(.*)-\\d{5}-of-(\\d{5})\\.gguf$\")\n                    m = shard_pat.match(gguf_filename)\n                    if m:\n                        prefix = m.group(1)\n                        total = m.group(2)\n                        sibling_pat = re.compile(\n                            r\"^\"\n                            + re.escape(prefix)\n                            + r\"-\\d{5}-of-\"\n                            + re.escape(total)\n                            + r\"\\.gguf$\"\n                        )\n                        gguf_extra_shards = [\n                            f for f in gguf_files[1:] if sibling_pat.match(f)\n                        ]\n            except Exception as e:\n                logger.warning(f\"Could not list repo files: {e}\")\n\n            if not gguf_filename:\n                repo_name = hf_repo.split(\"/\")[-1].replace(\"-GGUF\", \"\")\n                gguf_filename = f\"{repo_name}-{hf_variant}.gguf\"\n\n        # Check disk space and fall back to a smaller variant if needed\n        all_gguf_files = [gguf_filename] + gguf_extra_shards\n        try:\n            import os\n\n            from huggingface_hub import get_paths_info\n\n            path_infos = list(get_paths_info(hf_repo, all_gguf_files, token = hf_token))\n            total_download_bytes = sum((p.size or 0) for p in path_infos)\n\n            if total_download_bytes > 0:\n                cache_dir = os.environ.get(\n                    \"HF_HUB_CACHE\",\n                    str(Path.home() / \".cache\" / \"huggingface\" / \"hub\"),\n                )\n                Path(cache_dir).mkdir(parents = True, exist_ok = True)\n                free_bytes = shutil.disk_usage(cache_dir).free\n\n                total_gb = total_download_bytes / (1024**3)\n                free_gb = free_bytes / (1024**3)\n\n                logger.info(\n                    f\"GGUF download: {total_gb:.1f} GB needed, \"\n                    f\"{free_gb:.1f} GB free on disk\"\n                )\n\n                if total_download_bytes > free_bytes:\n                    smaller = self._find_smallest_fitting_variant(\n                        hf_repo,\n                        free_bytes,\n                        hf_token,\n                    )\n                    if smaller:\n                        fallback_file, 
fallback_size = smaller\n                        logger.info(\n                            f\"Selected variant too large ({total_gb:.1f} GB), \"\n                            f\"falling back to {fallback_file} ({fallback_size / (1024**3):.1f} GB)\"\n                        )\n                        gguf_filename = fallback_file\n                        import re as _re\n\n                        _shard_pat = _re.compile(r\"^(.*)-\\d{5}-of-\\d{5}\\.gguf$\")\n                        _m = _shard_pat.match(gguf_filename)\n                        _prefix = _m.group(1) if _m else None\n                        if _prefix:\n                            # The fallback variant's sibling shards are not in\n                            # all_gguf_files (that list only covers the originally\n                            # selected variant), so re-list the repo to find them.\n                            try:\n                                from huggingface_hub import (\n                                    list_repo_files as _list_repo_files,\n                                )\n\n                                _candidates = _list_repo_files(\n                                    hf_repo, token = hf_token\n                                )\n                            except Exception:\n                                _candidates = all_gguf_files\n                            gguf_extra_shards = sorted(\n                                f\n                                for f in _candidates\n                                if f.endswith(\".gguf\")\n                                and f.startswith(_prefix)\n                                and f != gguf_filename\n                                and \"mmproj\" not in f.lower()\n                            )\n                        else:\n                            gguf_extra_shards = []\n                    else:\n                        raise RuntimeError(\n                            f\"Not enough disk space to download any variant. \"\n                            f\"Only {free_gb:.1f} GB free in {cache_dir}\"\n                        )\n        except RuntimeError:\n            raise\n        except Exception as e:\n            logger.warning(f\"Could not check disk space: {e}\")\n\n        gguf_label = f\"{hf_repo}/{gguf_filename}\" + (\n            f\" (+{len(gguf_extra_shards)} shards)\" if gguf_extra_shards else \"\"\n        )\n        logger.info(f\"Resolving GGUF: {gguf_label}\")\n        try:\n            if self._cancel_event.is_set():\n                raise RuntimeError(\"Cancelled\")\n            dl_start = time.monotonic()\n            local_path = hf_hub_download(\n                repo_id = hf_repo,\n                filename = gguf_filename,\n                token = hf_token,\n            )\n            for shard in gguf_extra_shards:\n                if self._cancel_event.is_set():\n                    raise RuntimeError(\"Cancelled\")\n                logger.info(f\"Resolving GGUF shard: {shard}\")\n                hf_hub_download(\n                    repo_id = hf_repo,\n                    filename = shard,\n                    token = hf_token,\n                )\n        except RuntimeError as e:\n            if \"Cancelled\" in str(e):\n                raise\n            raise RuntimeError(\n                f\"Failed to download GGUF file '{gguf_filename}' from {hf_repo}: {e}\"\n            )\n        except Exception as e:\n            raise RuntimeError(\n                f\"Failed to download GGUF file '{gguf_filename}' from {hf_repo}: {e}\"\n            )\n\n        dl_elapsed = time.monotonic() - dl_start\n        if dl_elapsed < 2.0:\n            logger.info(f\"GGUF resolved from cache: {local_path}\")\n        else:\n            logger.info(f\"GGUF downloaded in {dl_elapsed:.1f}s: {local_path}\")\n        return local_path\n\n    def _download_mmproj(\n        self,\n        *,\n        hf_repo: str,\n        hf_token: Optional[str] = None,\n    ) -> Optional[str]:\n        \"\"\"Download the mmproj (vision projection) file from a GGUF repo.\n\n        Prefers mmproj-F16.gguf, falls back to any mmproj*.gguf file.\n        Returns the local path, or None if no mmproj file exists.\n        \"\"\"\n        try:\n            from huggingface_hub import hf_hub_download, list_repo_files\n\n            
files = list_repo_files(hf_repo, token = hf_token)\n            mmproj_files = sorted(\n                f for f in files if f.endswith(\".gguf\") and \"mmproj\" in f.lower()\n            )\n            if not mmproj_files:\n                return None\n\n            # Prefer F16 variant\n            target = None\n            for f in mmproj_files:\n                if \"f16\" in f.lower():\n                    target = f\n                    break\n            if target is None:\n                target = mmproj_files[0]\n\n            logger.info(f\"Downloading mmproj: {hf_repo}/{target}\")\n            local_path = hf_hub_download(\n                repo_id = hf_repo,\n                filename = target,\n                token = hf_token,\n            )\n            return local_path\n        except Exception as e:\n            logger.warning(f\"Could not download mmproj: {e}\")\n            return None\n\n    # ── Lifecycle ─────────────────────────────────────────────────\n\n    def load_model(\n        self,\n        *,\n        # Local mode: pass a path to a .gguf file\n        gguf_path: Optional[str] = None,\n        # Vision projection (mmproj) for local vision models\n        mmproj_path: Optional[str] = None,\n        # HF mode: let llama-server download via -hf \"repo:quant\"\n        hf_repo: Optional[str] = None,\n        hf_variant: Optional[str] = None,\n        hf_token: Optional[str] = None,\n        # Common\n        model_identifier: str,\n        is_vision: bool = False,\n        n_ctx: int = 4096,\n        chat_template_override: Optional[str] = None,\n        cache_type_kv: Optional[str] = None,\n        n_threads: Optional[int] = None,\n        n_gpu_layers: Optional[int] = None,  # Accepted for caller compat, unused\n    ) -> bool:\n        \"\"\"\n        Start llama-server with a GGUF model.\n\n        Two modes:\n        - Local: ``gguf_path=\"/path/to/model.gguf\"`` → uses ``-m``\n        - HF:    ``hf_repo=\"unsloth/gemma-3-4b-it-GGUF\", hf_variant=\"Q4_K_M\"`` → uses ``-hf``\n\n        In HF mode, llama-server handles downloading, caching, and\n        auto-loading mmproj files for vision models.\n\n        Returns True if server started and health check passed.\n        \"\"\"\n        self._cancel_event.clear()\n\n        # ── Phase 1: kill old process (under lock, fast) ──────────\n        with self._lock:\n            self._kill_process()\n\n        binary = self._find_llama_server_binary()\n        if not binary:\n            raise RuntimeError(\n                \"llama-server binary not found. 
\"\n                \"Run setup.sh to build it, install llama.cpp, \"\n                \"or set LLAMA_SERVER_PATH environment variable.\"\n            )\n\n        # ── Phase 2: download (NO lock held, so cancel can proceed) ──\n        if hf_repo:\n            model_path = self._download_gguf(\n                hf_repo = hf_repo,\n                hf_variant = hf_variant,\n                hf_token = hf_token,\n            )\n            # Auto-download mmproj for vision models\n            if is_vision and not mmproj_path:\n                mmproj_path = self._download_mmproj(\n                    hf_repo = hf_repo,\n                    hf_token = hf_token,\n                )\n        elif gguf_path:\n            if not Path(gguf_path).is_file():\n                raise FileNotFoundError(f\"GGUF file not found: {gguf_path}\")\n            model_path = gguf_path\n        else:\n            raise ValueError(\"Either gguf_path or hf_repo must be provided\")\n\n        # Set identifier early so _read_gguf_metadata can use it for DeepSeek detection\n        self._model_identifier = model_identifier\n\n        # Read GGUF metadata (context_length, chat_template) -- fast, header only\n        self._read_gguf_metadata(model_path)\n\n        # Check cancel after download\n        if self._cancel_event.is_set():\n            logger.info(\"Load cancelled after download phase\")\n            return False\n\n        # ── Phase 3: start llama-server (under lock) ──────────────\n        with self._lock:\n            # Re-check cancel inside lock\n            if self._cancel_event.is_set():\n                logger.info(\"Load cancelled before server start\")\n                return False\n\n            self._port = self._find_free_port()\n\n            # Select GPU(s) based on model size and free memory\n            try:\n                model_size = self._get_gguf_size_bytes(model_path)\n                gpus = self._get_gpu_free_memory()\n                gpu_indices, use_fit = self._select_gpus(model_size, gpus)\n                logger.info(\n                    f\"GGUF size: {model_size / (1024**3):.1f} GB, \"\n                    f\"GPUs free: {gpus}, selected: {gpu_indices}, fit: {use_fit}\"\n                )\n            except Exception as e:\n                logger.warning(f\"GPU selection failed ({e}), using --fit on\")\n                gpu_indices, use_fit = None, True\n\n            cmd = [\n                binary,\n                \"-m\",\n                model_path,\n                \"--port\",\n                str(self._port),\n                \"-c\",\n                \"0\",  # 0 = use model's native context size\n                \"--parallel\",\n                \"1\",  # Single-user studio, saves VRAM\n                \"--flash-attn\",\n                \"on\",  # Force flash attention for speed\n            ]\n\n            if use_fit:\n                cmd.extend([\"--fit\", \"on\"])\n\n            if n_threads is not None:\n                cmd.extend([\"--threads\", str(n_threads)])\n\n            # Always enable Jinja chat template rendering for proper template support\n            cmd.extend([\"--jinja\"])\n\n            # KV cache data type\n            _valid_cache_types = {\n                \"f16\",\n                \"bf16\",\n                \"q8_0\",\n                \"q4_0\",\n                \"q4_1\",\n                \"q5_0\",\n                \"q5_1\",\n                \"iq4_nl\",\n                \"f32\",\n            }\n            if cache_type_kv and cache_type_kv in 
_valid_cache_types:\n                cmd.extend(\n                    [\"--cache-type-k\", cache_type_kv, \"--cache-type-v\", cache_type_kv]\n                )\n                self._cache_type_kv = cache_type_kv\n                logger.info(f\"KV cache type: {cache_type_kv}\")\n            else:\n                self._cache_type_kv = None\n\n            # Apply custom chat template override if provided\n            if chat_template_override:\n                import tempfile\n\n                self._chat_template_file = tempfile.NamedTemporaryFile(\n                    mode = \"w\",\n                    suffix = \".jinja\",\n                    delete = False,\n                    prefix = \"unsloth_chat_template_\",\n                )\n                self._chat_template_file.write(chat_template_override)\n                self._chat_template_file.close()\n                cmd.extend([\"--chat-template-file\", self._chat_template_file.name])\n                logger.info(\n                    f\"Using custom chat template file: {self._chat_template_file.name}\"\n                )\n\n            # For reasoning models, set default thinking mode.\n            # Qwen3.5 models below 9B (0.8B, 2B, 4B) disable thinking by default.\n            # Only 9B and larger enable thinking.\n            if self._supports_reasoning:\n                import re\n\n                thinking_default = True\n                mid = (model_identifier or \"\").lower()\n                if \"qwen3.5\" in mid:\n                    # Extract size like \"0.8b\", \"4b\", \"35b\" etc.\n                    size_match = re.search(r\"(\\d+\\.?\\d*)\\s*b\", mid)\n                    if size_match:\n                        size_val = float(size_match.group(1))\n                        if size_val < 9:\n                            thinking_default = False\n                self._reasoning_default = thinking_default\n                cmd.extend(\n                    [\n                        \"--chat-template-kwargs\",\n                        json.dumps({\"enable_thinking\": thinking_default}),\n                    ]\n                )\n                logger.info(\n                    f\"Reasoning model: enable_thinking={thinking_default} by default\"\n                )\n\n            if mmproj_path:\n                if not Path(mmproj_path).is_file():\n                    logger.warning(f\"mmproj file not found: {mmproj_path}\")\n                else:\n                    cmd.extend([\"--mmproj\", mmproj_path])\n                    logger.info(f\"Using mmproj for vision: {mmproj_path}\")\n\n            logger.info(f\"Starting llama-server: {' '.join(cmd)}\")\n\n            # Set library paths so llama-server can find its shared libs and CUDA DLLs\n            import os\n            import sys\n\n            env = os.environ.copy()\n            binary_dir = str(Path(binary).parent)\n\n            if sys.platform == \"win32\":\n                # On Windows, CUDA DLLs (cublas64_12.dll, cudart64_12.dll, etc.)\n                # must be on PATH. 
Add CUDA_PATH\\bin if available.\n                path_dirs = [binary_dir]\n                cuda_path = os.environ.get(\"CUDA_PATH\", \"\")\n                if cuda_path:\n                    cuda_bin = os.path.join(cuda_path, \"bin\")\n                    if os.path.isdir(cuda_bin):\n                        path_dirs.append(cuda_bin)\n                    # Some CUDA installs put DLLs in bin\\x64\n                    cuda_bin_x64 = os.path.join(cuda_path, \"bin\", \"x64\")\n                    if os.path.isdir(cuda_bin_x64):\n                        path_dirs.append(cuda_bin_x64)\n                existing_path = env.get(\"PATH\", \"\")\n                env[\"PATH\"] = \";\".join(path_dirs) + \";\" + existing_path\n            else:\n                # Linux: set LD_LIBRARY_PATH for shared libs next to the binary\n                # and CUDA runtime libs (libcudart, libcublas, etc.)\n                import platform\n\n                lib_dirs = [binary_dir]\n                _arch = platform.machine()  # x86_64, aarch64, etc.\n                for cuda_lib in [\n                    \"/usr/local/cuda/lib64\",\n                    f\"/usr/local/cuda/targets/{_arch}-linux/lib\",\n                    # Fallback CUDA compat paths (e.g. binary built with\n                    # CUDA 12 on a system where default /usr/local/cuda\n                    # points to CUDA 13+).\n                    \"/usr/local/cuda-12/lib64\",\n                    \"/usr/local/cuda-12.8/lib64\",\n                    f\"/usr/local/cuda-12/targets/{_arch}-linux/lib\",\n                    f\"/usr/local/cuda-12.8/targets/{_arch}-linux/lib\",\n                ]:\n                    if os.path.isdir(cuda_lib):\n                        lib_dirs.append(cuda_lib)\n                existing_ld = env.get(\"LD_LIBRARY_PATH\", \"\")\n                new_ld = \":\".join(lib_dirs)\n                env[\"LD_LIBRARY_PATH\"] = (\n                    f\"{new_ld}:{existing_ld}\" if existing_ld else new_ld\n                )\n\n            # Pin to selected GPU(s) via CUDA_VISIBLE_DEVICES\n            if gpu_indices is not None:\n                env[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(i) for i in gpu_indices)\n\n            self._stdout_lines = []\n            self._process = subprocess.Popen(\n                cmd,\n                stdout = subprocess.PIPE,\n                stderr = subprocess.STDOUT,\n                text = True,\n                env = env,\n            )\n\n            # Start background thread to drain stdout and prevent pipe deadlock\n            self._stdout_thread = threading.Thread(\n                target = self._drain_stdout, daemon = True, name = \"llama-stdout\"\n            )\n            self._stdout_thread.start()\n\n            self._gguf_path = gguf_path\n            self._hf_repo = hf_repo\n            self._hf_variant = hf_variant\n            self._is_vision = is_vision\n            self._model_identifier = model_identifier\n\n            # Wait for llama-server to become healthy\n            if not self._wait_for_health(timeout = 120.0):\n                self._kill_process()\n                raise RuntimeError(\n                    \"llama-server failed to start. 
\"\n                    \"Check that the GGUF file is valid and you have enough memory.\"\n                )\n\n            self._healthy = True\n\n            logger.info(\n                f\"llama-server ready on port {self._port} \"\n                f\"for model '{model_identifier}'\"\n            )\n            return True\n\n    def unload_model(self) -> bool:\n        \"\"\"Terminate the llama-server subprocess and cancel any in-flight download.\"\"\"\n        self._cancel_event.set()\n        with self._lock:\n            self._kill_process()\n            logger.info(f\"Unloaded GGUF model: {self._model_identifier}\")\n            self._model_identifier = None\n            self._gguf_path = None\n            self._hf_repo = None\n            self._hf_variant = None\n            self._is_vision = False\n            self._is_audio = False\n            self._audio_type = None\n            self._port = None\n            self._healthy = False\n            self._context_length = None\n            self._chat_template = None\n            self._supports_reasoning = False\n            self._supports_tools = False\n            self._cache_type_kv = None\n            # Clean up temp chat template file\n            if hasattr(self, \"_chat_template_file\") and self._chat_template_file:\n                try:\n                    import os\n\n                    os.unlink(self._chat_template_file.name)\n                except Exception:\n                    pass\n                self._chat_template_file = None\n            # Free audio codec GPU memory\n            if LlamaCppBackend._codec_mgr is not None:\n                LlamaCppBackend._codec_mgr.unload()\n                LlamaCppBackend._codec_mgr = None\n                import torch\n\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n            return True\n\n    def _kill_process(self):\n        \"\"\"Terminate the subprocess if running.\"\"\"\n        if self._process is None:\n            return\n        try:\n            self._process.terminate()\n            self._process.wait(timeout = 5)\n        except subprocess.TimeoutExpired:\n            logger.warning(\"llama-server did not exit on SIGTERM, sending SIGKILL\")\n            self._process.kill()\n            self._process.wait(timeout = 5)\n        except Exception as e:\n            logger.warning(f\"Error killing llama-server process: {e}\")\n        finally:\n            self._process = None\n            if self._stdout_thread is not None:\n                self._stdout_thread.join(timeout = 2)\n                self._stdout_thread = None\n\n    @staticmethod\n    def _kill_orphaned_servers():\n        \"\"\"Kill orphaned llama-server processes started by studio.\n\n        Only kills processes whose binary lives under ~/.unsloth/llama.cpp/\n        to avoid terminating unrelated llama-server instances on the machine.\n        \"\"\"\n        import os\n        import signal\n\n        try:\n            # Use pgrep with full command match to identify studio-managed servers\n            result = subprocess.run(\n                [\"pgrep\", \"-a\", \"-f\", \"llama-server\"],\n                capture_output = True,\n                text = True,\n                timeout = 5,\n            )\n            if result.returncode != 0:\n                return\n            for line in result.stdout.strip().splitlines():\n                parts = line.strip().split(None, 1)\n                if len(parts) < 2:\n                    continue\n           
     pid = int(parts[0])\n                cmdline = parts[1]\n                if pid == os.getpid():\n                    continue\n                # Only kill if it's a studio-managed server (lives under .unsloth/)\n                if \".unsloth/\" not in cmdline and \"unsloth\" not in cmdline.lower():\n                    continue\n                try:\n                    os.kill(pid, signal.SIGKILL)\n                    logger.info(f\"Killed orphaned llama-server process (pid={pid})\")\n                except ProcessLookupError:\n                    pass\n                except PermissionError:\n                    pass\n        except Exception:\n            pass\n\n    def _cleanup(self):\n        \"\"\"atexit handler to ensure llama-server is terminated.\"\"\"\n        self._kill_process()\n\n    def _wait_for_health(self, timeout: float = 120.0, interval: float = 0.5) -> bool:\n        \"\"\"\n        Poll llama-server's /health endpoint until it responds 200.\n\n        Also monitors subprocess for early exit/crash.\n        \"\"\"\n        deadline = time.monotonic() + timeout\n        url = f\"http://127.0.0.1:{self._port}/health\"\n\n        while time.monotonic() < deadline:\n            # Check if process crashed\n            if self._process.poll() is not None:\n                # Give the drain thread a moment to collect final output\n                if self._stdout_thread is not None:\n                    self._stdout_thread.join(timeout = 2)\n                output = \"\\n\".join(self._stdout_lines[-50:])\n                logger.error(\n                    f\"llama-server exited with code {self._process.returncode}. \"\n                    f\"Output: {output[:2000]}\"\n                )\n                return False\n\n            try:\n                resp = httpx.get(url, timeout = 2.0)\n                if resp.status_code == 200:\n                    return True\n            except (httpx.ConnectError, httpx.TimeoutException):\n                pass\n\n            time.sleep(interval)\n\n        logger.error(f\"llama-server health check timed out after {timeout}s\")\n        return False\n\n    # ── Message building (OpenAI format) ──────────────────────────\n\n    @staticmethod\n    def _parse_tool_calls_from_text(content: str) -> list[dict]:\n        \"\"\"\n        Parse tool calls from XML markup in content text.\n\n        Handles formats like:\n          <tool_call>{\"name\":\"web_search\",\"arguments\":{\"query\":\"...\"}}</tool_call>\n          <tool_call><function=web_search><parameter=query>...</parameter></function></tool_call>\n        Closing tags (</tool_call>, </function>, </parameter>) are all optional\n        since models frequently omit them.\n        \"\"\"\n        import re\n\n        tool_calls = []\n\n        # Pattern 1: JSON inside <tool_call> tags.\n        # Use balanced-brace extraction that skips braces inside JSON strings.\n        for m in re.finditer(r\"<tool_call>\\s*\\{\", content):\n            brace_start = m.end() - 1  # position of the opening {\n            depth, i = 0, brace_start\n            in_string = False\n            while i < len(content):\n                ch = content[i]\n                if in_string:\n                    if ch == \"\\\\\" and i + 1 < len(content):\n                        i += 2  # skip escaped character\n                        continue\n                    if ch == '\"':\n                        in_string = False\n                elif ch == '\"':\n                    in_string = True\n                
elif ch == \"{\":\n                    depth += 1\n                elif ch == \"}\":\n                    depth -= 1\n                    if depth == 0:\n                        break\n                i += 1\n            if depth == 0:\n                json_str = content[brace_start : i + 1]\n                try:\n                    obj = json.loads(json_str)\n                    tc = {\n                        \"id\": f\"call_{len(tool_calls)}\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": obj.get(\"name\", \"\"),\n                            \"arguments\": obj.get(\"arguments\", {}),\n                        },\n                    }\n                    if isinstance(tc[\"function\"][\"arguments\"], dict):\n                        tc[\"function\"][\"arguments\"] = json.dumps(\n                            tc[\"function\"][\"arguments\"]\n                        )\n                    tool_calls.append(tc)\n                except (json.JSONDecodeError, ValueError):\n                    pass\n\n        # Pattern 2: XML-style <function=name><parameter=key>value</parameter></function>\n        # All closing tags optional -- models frequently omit </parameter>,\n        # </function>, and/or </tool_call>.\n        if not tool_calls:\n            # Step 1: Find all <function=name> positions and extract their bodies.\n            # Body boundary: use only </tool_call> or next <function= as hard\n            # boundaries.  We avoid using </function> as a boundary because\n            # code parameter values can contain that literal string.\n            # After extracting, we trim a trailing </function> if present.\n            func_starts = list(re.finditer(r\"<function=(\\w+)>\\s*\", content))\n            for idx, fm in enumerate(func_starts):\n                func_name = fm.group(1)\n                body_start = fm.end()\n                # Hard boundaries: next <function= tag or </tool_call>\n                next_func = (\n                    func_starts[idx + 1].start()\n                    if idx + 1 < len(func_starts)\n                    else len(content)\n                )\n                end_tag = re.search(r\"</tool_call>\", content[body_start:])\n                if end_tag:\n                    body_end = body_start + end_tag.start()\n                else:\n                    body_end = len(content)\n                body_end = min(body_end, next_func)\n                body = content[body_start:body_end]\n                # Trim trailing </function> if present (it's the real closing tag)\n                body = re.sub(r\"\\s*</function>\\s*$\", \"\", body)\n\n                # Step 2: Extract parameters from body.\n                # For single-parameter functions (the common case: code, command,\n                # query), use body end as the only boundary to avoid false matches\n                # on </parameter> inside code strings.\n                arguments = {}\n                param_starts = list(re.finditer(r\"<parameter=(\\w+)>\\s*\", body))\n                if len(param_starts) == 1:\n                    # Single parameter: value is everything from after the tag\n                    # to end of body, trimming any trailing </parameter>.\n                    pm = param_starts[0]\n                    val = body[pm.end() :]\n                    val = re.sub(r\"\\s*</parameter>\\s*$\", \"\", val)\n                    arguments[pm.group(1)] = val.strip()\n                else:\n                    
for pidx, pm in enumerate(param_starts):\n                        param_name = pm.group(1)\n                        val_start = pm.end()\n                        # Value ends at next <parameter= or end of body\n                        next_param = (\n                            param_starts[pidx + 1].start()\n                            if pidx + 1 < len(param_starts)\n                            else len(body)\n                        )\n                        val = body[val_start:next_param]\n                        # Trim trailing </parameter> if present\n                        val = re.sub(r\"\\s*</parameter>\\s*$\", \"\", val)\n                        arguments[param_name] = val.strip()\n\n                tc = {\n                    \"id\": f\"call_{len(tool_calls)}\",\n                    \"type\": \"function\",\n                    \"function\": {\n                        \"name\": func_name,\n                        \"arguments\": json.dumps(arguments),\n                    },\n                }\n                tool_calls.append(tc)\n\n        return tool_calls\n\n    @staticmethod\n    def _build_openai_messages(\n        messages: list[dict],\n        image_b64: Optional[str] = None,\n    ) -> list[dict]:\n        \"\"\"\n        Build OpenAI-format messages, optionally injecting an image_url\n        content part into the last user message for vision models.\n\n        If no image is provided, returns messages as-is.\n        \"\"\"\n        if not image_b64:\n            return messages\n\n        # Find the last user message and convert to multimodal content parts\n        result = [msg.copy() for msg in messages]\n        last_user_idx = None\n        for i, msg in enumerate(result):\n            if msg[\"role\"] == \"user\":\n                last_user_idx = i\n\n        if last_user_idx is not None:\n            text_content = result[last_user_idx].get(\"content\", \"\")\n            result[last_user_idx][\"content\"] = [\n                {\"type\": \"text\", \"text\": text_content},\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": f\"data:image/png;base64,{image_b64}\",\n                    },\n                },\n            ]\n\n        return result\n\n    # ── Generation (proxy to llama-server) ────────────────────────\n\n    @staticmethod\n    def _iter_text_cancellable(\n        response: \"httpx.Response\",\n        cancel_event: Optional[threading.Event] = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Iterate over an httpx streaming response with cancel support.\n\n        Checks cancel_event between chunks and on ReadTimeout.  
The\n        cancel watcher in _stream_with_retry also calls response.close()\n        on cancel, which unblocks iter_text() once the response exists.\n        During normal streaming llama-server sends tokens frequently,\n        so the cancel check between chunks is the primary mechanism.\n        \"\"\"\n        text_iter = response.iter_text()\n        while True:\n            if cancel_event is not None and cancel_event.is_set():\n                response.close()\n                return\n            try:\n                chunk = next(text_iter)\n                yield chunk\n            except StopIteration:\n                return\n            except httpx.ReadTimeout:\n                # No data within the timeout window -- just loop back\n                # and re-check cancel_event.\n                continue\n\n    @staticmethod\n    @contextlib.contextmanager\n    def _stream_with_retry(\n        client: \"httpx.Client\",\n        url: str,\n        payload: dict,\n        cancel_event: Optional[threading.Event] = None,\n    ):\n        \"\"\"Open an httpx streaming POST with cancel support.\n\n        Sends the request once with a long read timeout (120 s) so\n        prompt processing (prefill) can finish without triggering a\n        retry storm.  The previous 0.5 s timeout caused duplicate POST\n        requests every half second, forcing llama-server to restart\n        processing each time.\n\n        A background watcher thread provides cancel by closing the\n        response when cancel_event is set.  Limitation: httpx does not\n        allow interrupting a blocked read from another thread before\n        the response object exists, so cancel during the initial\n        header wait (prefill phase) only takes effect once headers\n        arrive.  After that, response.close() unblocks reads promptly.\n        In practice llama-server prefill is 1-5 s for typical prompts,\n        during which cancel is deferred -- still much better than the\n        old retry storm which made prefill slower.\n        \"\"\"\n        if cancel_event is not None and cancel_event.is_set():\n            raise GeneratorExit\n\n        # Background watcher: close the response if cancel is requested.\n        # Only effective after response headers arrive (httpx limitation).\n        _cancel_closed = threading.Event()\n        _response_ref: list = [None]\n\n        def _cancel_watcher():\n            while not _cancel_closed.is_set():\n                if cancel_event.wait(timeout = 0.3):\n                    # Cancel requested. 
Keep polling until the response object\n                    # exists so we can close it, or until the main thread\n                    # finishes on its own (_cancel_closed is set in finally).\n                    while not _cancel_closed.is_set():\n                        r = _response_ref[0]\n                        if r is not None:\n                            try:\n                                r.close()\n                                return\n                            except Exception as e:\n                                logger.debug(\n                                    f\"Error closing response in cancel watcher: {e}\"\n                                )\n                        # Response not created yet -- wait briefly and retry\n                        _cancel_closed.wait(timeout = 0.1)\n                    return\n\n        watcher = None\n        if cancel_event is not None:\n            watcher = threading.Thread(\n                target = _cancel_watcher, daemon = True, name = \"prefill-cancel\"\n            )\n            watcher.start()\n\n        try:\n            # Long read timeout so prefill (prompt processing) can finish\n            # without triggering a retry storm.  Cancel during both\n            # prefill and streaming is handled by the watcher thread\n            # which closes the response, unblocking any httpx read.\n            prefill_timeout = httpx.Timeout(\n                connect = 30,\n                read = 120.0,\n                write = 10,\n                pool = 10,\n            )\n            with client.stream(\n                \"POST\", url, json = payload, timeout = prefill_timeout\n            ) as response:\n                _response_ref[0] = response\n                if cancel_event is not None and cancel_event.is_set():\n                    raise GeneratorExit\n                yield response\n                return\n        except (httpx.ReadError, httpx.RemoteProtocolError, httpx.CloseError):\n            # Response was closed by the cancel watcher\n            if cancel_event is not None and cancel_event.is_set():\n                raise GeneratorExit\n            raise\n        finally:\n            _cancel_closed.set()\n\n    def generate_chat_completion(\n        self,\n        messages: list[dict],\n        image_b64: Optional[str] = None,\n        temperature: float = 0.6,\n        top_p: float = 0.95,\n        top_k: int = 20,\n        min_p: float = 0.01,\n        max_tokens: Optional[int] = None,\n        repetition_penalty: float = 1.0,\n        presence_penalty: float = 0.0,\n        stop: Optional[list[str]] = None,\n        cancel_event: Optional[threading.Event] = None,\n        enable_thinking: Optional[bool] = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"\n        Send a chat completion request to llama-server and stream tokens back.\n\n        Uses /v1/chat/completions — llama-server handles chat template\n        application and vision (multimodal image_url parts) natively.\n\n        Yields cumulative text (matching InferenceBackend's convention).\n        \"\"\"\n        if not self.is_loaded:\n            raise RuntimeError(\"llama-server is not loaded\")\n\n        openai_messages = self._build_openai_messages(messages, image_b64)\n\n        payload = {\n            \"messages\": openai_messages,\n            \"stream\": True,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k if top_k >= 0 else 0,\n            \"min_p\": min_p,\n            
\"repeat_penalty\": repetition_penalty,\n            \"presence_penalty\": presence_penalty,\n        }\n        # Pass enable_thinking per-request for reasoning models\n        if self._supports_reasoning and enable_thinking is not None:\n            payload[\"chat_template_kwargs\"] = {\"enable_thinking\": enable_thinking}\n        if max_tokens is not None:\n            payload[\"max_tokens\"] = max_tokens\n        if stop:\n            payload[\"stop\"] = stop\n\n        url = f\"{self.base_url}/v1/chat/completions\"\n        cumulative = \"\"\n        in_thinking = False\n\n        try:\n            # _stream_with_retry uses a 120 s read timeout so prefill\n            # can finish.  Cancel during streaming is handled by the\n            # watcher thread (closes the response on cancel_event).\n            stream_timeout = httpx.Timeout(connect = 10, read = 0.5, write = 10, pool = 10)\n            with httpx.Client(timeout = stream_timeout) as client:\n                with self._stream_with_retry(\n                    client, url, payload, cancel_event\n                ) as response:\n                    if response.status_code != 200:\n                        error_body = response.read().decode()\n                        raise RuntimeError(\n                            f\"llama-server returned {response.status_code}: {error_body}\"\n                        )\n\n                    buffer = \"\"\n                    has_content_tokens = False\n                    reasoning_text = \"\"\n                    for raw_chunk in self._iter_text_cancellable(\n                        response, cancel_event\n                    ):\n                        buffer += raw_chunk\n                        while \"\\n\" in buffer:\n                            line, buffer = buffer.split(\"\\n\", 1)\n                            line = line.strip()\n\n                            if not line:\n                                continue\n                            if line == \"data: [DONE]\":\n                                if in_thinking:\n                                    if has_content_tokens:\n                                        # Real thinking + content: close the tag\n                                        cumulative += \"</think>\"\n                                        yield cumulative\n                                    else:\n                                        # Only reasoning_content, no content tokens:\n                                        # the model put its entire reply in reasoning\n                                        # (e.g. Qwen3 always-think mode). 
Show it\n                                        # as the main response, not as a thinking block.\n                                        cumulative = reasoning_text\n                                        yield cumulative\n                                return\n                            if not line.startswith(\"data: \"):\n                                continue\n\n                            try:\n                                data = json.loads(line[6:])\n                                choices = data.get(\"choices\", [])\n                                if choices:\n                                    delta = choices[0].get(\"delta\", {})\n\n                                    # Handle reasoning/thinking tokens\n                                    # llama-server sends these as \"reasoning_content\"\n                                    # Wrap in <think> tags for the frontend parser\n                                    reasoning = delta.get(\"reasoning_content\", \"\")\n                                    if reasoning:\n                                        reasoning_text += reasoning\n                                        if not in_thinking:\n                                            cumulative += \"<think>\"\n                                            in_thinking = True\n                                        cumulative += reasoning\n                                        yield cumulative\n\n                                    token = delta.get(\"content\", \"\")\n                                    if token:\n                                        has_content_tokens = True\n                                        if in_thinking:\n                                            cumulative += \"</think>\"\n                                            in_thinking = False\n                                        cumulative += token\n                                        yield cumulative\n                            except json.JSONDecodeError:\n                                logger.debug(\n                                    f\"Skipping malformed SSE line: {line[:100]}\"\n                                )\n\n        except httpx.ConnectError:\n            raise RuntimeError(\"Lost connection to llama-server\")\n        except Exception as e:\n            if cancel_event is not None and cancel_event.is_set():\n                return\n            raise\n\n    # ── Tool-calling agentic loop ──────────────────────────────\n\n    def generate_chat_completion_with_tools(\n        self,\n        messages: list[dict],\n        tools: list[dict],\n        temperature: float = 0.6,\n        top_p: float = 0.95,\n        top_k: int = 20,\n        min_p: float = 0.01,\n        max_tokens: Optional[int] = None,\n        repetition_penalty: float = 1.0,\n        presence_penalty: float = 0.0,\n        stop: Optional[list[str]] = None,\n        cancel_event: Optional[threading.Event] = None,\n        enable_thinking: Optional[bool] = None,\n        max_tool_iterations: int = 10,\n        auto_heal_tool_calls: bool = True,\n        tool_call_timeout: int = 300,\n        session_id: Optional[str] = None,\n    ) -> Generator[dict, None, None]:\n        \"\"\"\n        Agentic loop: let the model call tools, execute them, and continue.\n\n        Yields dicts with:\n          {\"type\": \"status\", \"text\": \"Searching: ...\"}   -- tool status updates\n          {\"type\": \"content\", \"text\": \"token\"}            -- streamed content tokens (cumulative)\n          {\"type\": 
\"reasoning\", \"text\": \"token\"}          -- streamed reasoning tokens (cumulative)\n        \"\"\"\n        from core.inference.tools import execute_tool\n\n        if not self.is_loaded:\n            raise RuntimeError(\"llama-server is not loaded\")\n\n        conversation = list(messages)\n        url = f\"{self.base_url}/v1/chat/completions\"\n\n        for iteration in range(max_tool_iterations):\n            if cancel_event is not None and cancel_event.is_set():\n                return\n\n            # Build payload for non-streaming tool detection pass\n            payload = {\n                \"messages\": conversation,\n                \"stream\": False,\n                \"temperature\": temperature,\n                \"top_p\": top_p,\n                \"top_k\": top_k if top_k >= 0 else 0,\n                \"min_p\": min_p,\n                \"repeat_penalty\": repetition_penalty,\n                \"presence_penalty\": presence_penalty,\n                \"tools\": tools,\n                \"tool_choice\": \"auto\",\n            }\n            if self._supports_reasoning and enable_thinking is not None:\n                payload[\"chat_template_kwargs\"] = {\"enable_thinking\": enable_thinking}\n            if max_tokens is not None:\n                payload[\"max_tokens\"] = max_tokens\n            if stop:\n                payload[\"stop\"] = stop\n\n            try:\n                with httpx.Client(timeout = None) as client:\n                    resp = client.post(url, json = payload)\n                    if resp.status_code != 200:\n                        raise RuntimeError(\n                            f\"llama-server returned {resp.status_code}: {resp.text}\"\n                        )\n                    data = resp.json()\n            except httpx.ConnectError:\n                raise RuntimeError(\"Lost connection to llama-server\")\n\n            choices = data.get(\"choices\", [])\n            if not choices:\n                return\n\n            choice = choices[0]\n            finish_reason = choice.get(\"finish_reason\", \"\")\n            message = choice.get(\"message\", {})\n\n            # If model wants to call tools\n            tool_calls = message.get(\"tool_calls\")\n\n            # Fallback: detect tool calls embedded as XML/text in content\n            # Some models output <tool_call> XML instead of structured tool_calls,\n            # or bare <function=...> tags without <tool_call> wrapper.\n            content_text = message.get(\"content\", \"\") or \"\"\n            if (\n                auto_heal_tool_calls\n                and not tool_calls\n                and (\"<tool_call>\" in content_text or \"<function=\" in content_text)\n            ):\n                tool_calls = self._parse_tool_calls_from_text(content_text)\n                if tool_calls:\n                    # Strip the tool call markup from content.\n                    # Use greedy match within <tool_call> blocks since they\n                    # can contain arbitrary content including code.\n                    import re\n\n                    # Strip <tool_call>...</tool_call> blocks (greedy inside)\n                    content_text = re.sub(\n                        r\"<tool_call>.*?</tool_call>\",\n                        \"\",\n                        content_text,\n                        flags = re.DOTALL,\n                    )\n                    # Strip unterminated <tool_call>... 
to end\n                    content_text = re.sub(\n                        r\"<tool_call>.*$\",\n                        \"\",\n                        content_text,\n                        flags = re.DOTALL,\n                    )\n                    # Strip bare <function=...>...</function> blocks\n                    content_text = re.sub(\n                        r\"<function=\\w+>.*?</function>\",\n                        \"\",\n                        content_text,\n                        flags = re.DOTALL,\n                    )\n                    # Strip unterminated bare <function=...> to end\n                    content_text = re.sub(\n                        r\"<function=\\w+>.*$\",\n                        \"\",\n                        content_text,\n                        flags = re.DOTALL,\n                    ).strip()\n                    logger.info(\n                        f\"Parsed {len(tool_calls)} tool call(s) from content text\"\n                    )\n\n            if finish_reason == \"tool_calls\" or (tool_calls and len(tool_calls) > 0):\n                # Append the assistant message with tool_calls to conversation\n                assistant_msg = {\"role\": \"assistant\", \"content\": content_text}\n                if tool_calls:\n                    assistant_msg[\"tool_calls\"] = tool_calls\n                conversation.append(assistant_msg)\n\n                # Execute each tool call\n                for tc in tool_calls or []:\n                    func = tc.get(\"function\", {})\n                    tool_name = func.get(\"name\", \"\")\n                    raw_args = func.get(\"arguments\", {})\n\n                    # Handle arguments as either string or dict\n                    if isinstance(raw_args, str):\n                        try:\n                            arguments = json.loads(raw_args)\n                        except (json.JSONDecodeError, ValueError):\n                            if auto_heal_tool_calls:\n                                arguments = {\"query\": raw_args}\n                            else:\n                                arguments = {\"raw\": raw_args}\n                    else:\n                        arguments = raw_args\n\n                    # Yield status update\n                    if tool_name == \"web_search\":\n                        status_text = f\"Searching: {arguments.get('query', '')}\"\n                    elif tool_name == \"python\":\n                        preview = (\n                            (arguments.get(\"code\") or \"\").strip().split(\"\\n\")[0][:60]\n                        )\n                        status_text = (\n                            f\"Running Python: {preview}\"\n                            if preview\n                            else \"Running Python...\"\n                        )\n                    elif tool_name == \"terminal\":\n                        cmd_preview = (arguments.get(\"command\") or \"\")[:60]\n                        status_text = (\n                            f\"Running: {cmd_preview}\"\n                            if cmd_preview\n                            else \"Running command...\"\n                        )\n                    else:\n                        status_text = f\"Calling: {tool_name}\"\n                    yield {\"type\": \"status\", \"text\": status_text}\n\n                    # Emit tool_start so the frontend can record inputs\n                    yield {\n                        \"type\": \"tool_start\",\n                        
\"tool_name\": tool_name,\n                        \"tool_call_id\": tc.get(\"id\", \"\"),\n                        \"arguments\": arguments,\n                    }\n\n                    # Execute the tool\n                    _effective_timeout = (\n                        None if tool_call_timeout >= 9999 else tool_call_timeout\n                    )\n                    result = execute_tool(\n                        tool_name,\n                        arguments,\n                        cancel_event = cancel_event,\n                        timeout = _effective_timeout,\n                        session_id = session_id,\n                    )\n\n                    # Emit tool_end so the frontend can record outputs\n                    yield {\n                        \"type\": \"tool_end\",\n                        \"tool_name\": tool_name,\n                        \"tool_call_id\": tc.get(\"id\", \"\"),\n                        \"result\": result,\n                    }\n\n                    # Append tool result to conversation\n                    tool_msg = {\n                        \"role\": \"tool\",\n                        \"name\": tool_name,\n                        \"content\": result,\n                    }\n                    tool_call_id = tc.get(\"id\")\n                    if tool_call_id:\n                        tool_msg[\"tool_call_id\"] = tool_call_id\n                    conversation.append(tool_msg)\n\n                # Continue the loop to let model respond with context\n                continue\n\n            # No tool calls -- model answered directly.\n            # If no tools were executed at all, just yield the content\n            # from this response instead of making a redundant second request.\n            if iteration == 0 and content_text:\n                yield {\"type\": \"status\", \"text\": \"\"}\n                yield {\"type\": \"content\", \"text\": content_text}\n                return\n\n            # Tools were called in previous iterations; do a final\n            # streaming pass so the model can synthesize a response\n            # incorporating the tool results.\n            break\n\n        # Clear status\n        yield {\"type\": \"status\", \"text\": \"\"}\n\n        # Final streaming pass with the full conversation context\n        stream_payload = {\n            \"messages\": conversation,\n            \"stream\": True,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k if top_k >= 0 else 0,\n            \"min_p\": min_p,\n            \"repeat_penalty\": repetition_penalty,\n            \"presence_penalty\": presence_penalty,\n        }\n        if self._supports_reasoning and enable_thinking is not None:\n            stream_payload[\"chat_template_kwargs\"] = {\n                \"enable_thinking\": enable_thinking\n            }\n        if max_tokens is not None:\n            stream_payload[\"max_tokens\"] = max_tokens\n        if stop:\n            stream_payload[\"stop\"] = stop\n\n        import re as _re_final\n\n        # Closed blocks only -- safe to strip mid-stream without shrinking later.\n        _TOOL_CLOSED_PATTERNS = [\n            _re_final.compile(r\"<tool_call>.*?</tool_call>\", _re_final.DOTALL),\n            _re_final.compile(r\"<function=\\w+>.*?</function>\", _re_final.DOTALL),\n        ]\n        # Open-ended patterns strip from an opening tag to end-of-string.\n        # Only applied on the final flush to avoid non-monotonic shrinking.\n        
_TOOL_ALL_PATTERNS = _TOOL_CLOSED_PATTERNS + [\n            _re_final.compile(r\"<tool_call>.*$\", _re_final.DOTALL),\n            _re_final.compile(r\"<function=\\w+>.*$\", _re_final.DOTALL),\n        ]\n\n        def _strip_tool_markup(text: str, *, final: bool = False) -> str:\n            if not auto_heal_tool_calls:\n                return text\n            patterns = _TOOL_ALL_PATTERNS if final else _TOOL_CLOSED_PATTERNS\n            for pat in patterns:\n                text = pat.sub(\"\", text)\n            return text.strip() if final else text\n\n        cumulative = \"\"\n        _last_emitted = \"\"\n        in_thinking = False\n        has_content_tokens = False\n        reasoning_text = \"\"\n\n        try:\n            stream_timeout = httpx.Timeout(connect = 10, read = 0.5, write = 10, pool = 10)\n            with httpx.Client(timeout = stream_timeout) as client:\n                with self._stream_with_retry(\n                    client, url, stream_payload, cancel_event\n                ) as response:\n                    if response.status_code != 200:\n                        error_body = response.read().decode()\n                        raise RuntimeError(\n                            f\"llama-server returned {response.status_code}: {error_body}\"\n                        )\n\n                    buffer = \"\"\n                    for raw_chunk in self._iter_text_cancellable(\n                        response, cancel_event\n                    ):\n                        buffer += raw_chunk\n                        while \"\\n\" in buffer:\n                            line, buffer = buffer.split(\"\\n\", 1)\n                            line = line.strip()\n\n                            if not line:\n                                continue\n                            if line == \"data: [DONE]\":\n                                if in_thinking:\n                                    if has_content_tokens:\n                                        cumulative += \"</think>\"\n                                        yield {\n                                            \"type\": \"content\",\n                                            \"text\": _strip_tool_markup(\n                                                cumulative, final = True\n                                            ),\n                                        }\n                                    else:\n                                        cumulative = reasoning_text\n                                        yield {\"type\": \"content\", \"text\": cumulative}\n                                return\n                            if not line.startswith(\"data: \"):\n                                continue\n\n                            try:\n                                chunk_data = json.loads(line[6:])\n                                choices = chunk_data.get(\"choices\", [])\n                                if choices:\n                                    delta = choices[0].get(\"delta\", {})\n\n                                    reasoning = delta.get(\"reasoning_content\", \"\")\n                                    if reasoning:\n                                        reasoning_text += reasoning\n                                        if not in_thinking:\n                                            cumulative += \"<think>\"\n                                            in_thinking = True\n                                        cumulative += reasoning\n                                        yield 
{\"type\": \"content\", \"text\": cumulative}\n\n                                    token = delta.get(\"content\", \"\")\n                                    if token:\n                                        has_content_tokens = True\n                                        if in_thinking:\n                                            cumulative += \"</think>\"\n                                            in_thinking = False\n                                        cumulative += token\n                                        cleaned = _strip_tool_markup(cumulative)\n                                        # Only emit when cleaned text grows (monotonic).\n                                        if len(cleaned) > len(_last_emitted):\n                                            _last_emitted = cleaned\n                                            yield {\"type\": \"content\", \"text\": cleaned}\n                            except json.JSONDecodeError:\n                                logger.debug(\n                                    f\"Skipping malformed SSE line: {line[:100]}\"\n                                )\n\n        except httpx.ConnectError:\n            raise RuntimeError(\"Lost connection to llama-server\")\n        except Exception as e:\n            if cancel_event is not None and cancel_event.is_set():\n                return\n            raise\n\n    # ── TTS support ────────────────────────────────────────────\n\n    def detect_audio_type(self) -> Optional[str]:\n        \"\"\"Detect audio/TTS codec by probing the loaded model's vocabulary.\"\"\"\n        if not self.is_loaded:\n            return None\n        try:\n            with httpx.Client(timeout = 10) as client:\n\n                def _detok(tid: int) -> str:\n                    r = client.post(\n                        f\"{self.base_url}/detokenize\", json = {\"tokens\": [tid]}\n                    )\n                    return r.json().get(\"content\", \"\") if r.status_code == 200 else \"\"\n\n                def _tok(text: str) -> list[int]:\n                    r = client.post(\n                        f\"{self.base_url}/tokenize\",\n                        json = {\"content\": text, \"add_special\": False},\n                    )\n                    return r.json().get(\"tokens\", []) if r.status_code == 200 else []\n\n                # Check codec-specific tokens (not generic ones that may exist in non-audio models)\n                if \"<custom_token_\" in _detok(128258) and \"<custom_token_\" in _detok(\n                    128259\n                ):\n                    return \"snac\"\n                if len(_tok(\"<|AUDIO|>\")) == 1 and len(_tok(\"<|audio_eos|>\")) == 1:\n                    return \"csm\"\n                if len(_tok(\"<|startoftranscript|>\")) == 1:\n                    return \"whisper\"\n                if (\n                    len(_tok(\"<|bicodec_semantic_0|>\")) == 1\n                    and len(_tok(\"<|bicodec_global_0|>\")) == 1\n                ):\n                    return \"bicodec\"\n                if len(_tok(\"<|c1_0|>\")) == 1 and len(_tok(\"<|c2_0|>\")) == 1:\n                    return \"dac\"\n        except Exception as e:\n            logger.debug(f\"Audio type detection failed: {e}\")\n        return None\n\n    # Prompt format per codec: (template, stop_tokens, needs_token_ids)\n    # Matches prompts in InferenceBackend._generate_snac/bicodec/dac\n    _TTS_PROMPTS = {\n        \"snac\": (\n            \"<custom_token_3>{text}<|eot_id|><custom_token_4>\",\n     
       [\"<custom_token_2>\"],\n            True,\n        ),\n        \"bicodec\": (\n            \"<|task_tts|><|start_content|>{text}<|end_content|><|start_global_token|>\",\n            [\"<|im_end|>\", \"</s>\"],\n            False,\n        ),\n        \"dac\": (\n            \"<|im_start|>\\n<|text_start|>{text}<|text_end|>\\n<|audio_start|><|global_features_start|>\\n\",\n            [\"<|im_end|>\", \"<|audio_end|>\"],\n            False,\n        ),\n    }\n\n    _codec_mgr = None  # Shared AudioCodecManager instance\n\n    def init_audio_codec(self, audio_type: str) -> None:\n        \"\"\"Load the audio codec at model load time (mirrors non-GGUF path).\"\"\"\n        import torch\n        from core.inference.audio_codecs import AudioCodecManager\n\n        if LlamaCppBackend._codec_mgr is None:\n            LlamaCppBackend._codec_mgr = AudioCodecManager()\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        model_repo_path = None\n\n        # BiCodec needs a repo with BiCodec/ weights — download canonical SparkTTS\n        if audio_type == \"bicodec\":\n            from huggingface_hub import snapshot_download\n            import os\n\n            repo_path = snapshot_download(\n                \"unsloth/Spark-TTS-0.5B\", local_dir = \"Spark-TTS-0.5B\"\n            )\n            model_repo_path = os.path.abspath(repo_path)\n\n        LlamaCppBackend._codec_mgr.load_codec(\n            audio_type, device, model_repo_path = model_repo_path\n        )\n        logger.info(f\"Loaded audio codec for GGUF TTS: {audio_type}\")\n\n    def generate_audio_response(\n        self,\n        text: str,\n        audio_type: str,\n        temperature: float = 0.6,\n        top_p: float = 0.95,\n        top_k: int = 50,\n        min_p: float = 0.0,\n        max_new_tokens: int = 2048,\n        repetition_penalty: float = 1.1,\n    ) -> tuple:\n        \"\"\"\n        Generate TTS audio via llama-server /completion + codec decoding.\n        Returns (wav_bytes, sample_rate).\n        \"\"\"\n        if audio_type not in self._TTS_PROMPTS:\n            raise RuntimeError(f\"GGUF TTS does not support '{audio_type}' codec.\")\n\n        tpl, stop, need_ids = self._TTS_PROMPTS[audio_type]\n\n        payload: dict = {\n            \"prompt\": tpl.format(text = text),\n            \"stream\": False,\n            \"n_predict\": max_new_tokens,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k if top_k >= 0 else 0,\n            \"min_p\": min_p,\n            \"repeat_penalty\": repetition_penalty,\n        }\n        if stop:\n            payload[\"stop\"] = stop\n        if need_ids:\n            payload[\"n_probs\"] = 1\n\n        with httpx.Client(timeout = httpx.Timeout(300, connect = 10)) as client:\n            resp = client.post(f\"{self.base_url}/completion\", json = payload)\n            if resp.status_code != 200:\n                raise RuntimeError(\n                    f\"llama-server returned {resp.status_code}: {resp.text}\"\n                )\n\n        data = resp.json()\n        token_ids = (\n            [p[\"id\"] for p in data.get(\"completion_probabilities\", []) if \"id\" in p]\n            if need_ids\n            else None\n        )\n\n        import torch\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        return LlamaCppBackend._codec_mgr.decode(\n            audio_type, device, token_ids = token_ids, text = data.get(\"content\", \"\")\n        )\n"
  },
  {
    "path": "studio/backend/core/inference/orchestrator.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference orchestrator — subprocess-based.\n\nProvides the same API as InferenceBackend, but delegates all ML work\nto a persistent subprocess. The subprocess is spawned on first model load\nand stays alive for subsequent requests.\n\nWhen switching between models that need different transformers versions\n(e.g. GLM-4.7-Flash needs 5.x, Qwen needs 4.57.x), the old subprocess\nis killed and a new one is spawned with the correct version.\n\nPattern follows core/training/training.py.\n\"\"\"\n\nimport atexit\nimport base64\nimport structlog\nfrom loggers import get_logger\nimport multiprocessing as mp\nimport queue\nimport threading\nimport time\nimport uuid\nfrom io import BytesIO\nfrom pathlib import Path\nfrom typing import Any, Generator, Optional, Tuple, Union\n\nlogger = get_logger(__name__)\n\n_CTX = mp.get_context(\"spawn\")\n\n# Dispatcher timeout constants (seconds)\n_DISPATCH_READ_TIMEOUT = 30.0\n_DISPATCH_POLL_INTERVAL = 0.5\n_DISPATCH_STOP_TIMEOUT = 5.0\n_DISPATCH_IDLE_TIMEOUT = 30.0\n_DISPATCH_DRAIN_TIMEOUT = 5.0\n\n\nclass InferenceOrchestrator:\n    \"\"\"\n    Inference backend orchestrator — subprocess-based.\n\n    Exposes the same API surface as InferenceBackend so routes/inference.py\n    needs minimal changes. Internally, all heavy ML operations happen in\n    a persistent subprocess.\n    \"\"\"\n\n    def __init__(self):\n        # Subprocess state\n        self._proc: Optional[mp.Process] = None\n        self._cmd_queue: Any = None\n        self._resp_queue: Any = None\n        self._cancel_event: Any = None  # mp.Event — set to cancel generation instantly\n        self._lock = threading.Lock()\n        self._gen_lock = (\n            threading.Lock()\n        )  # Serializes generation — one request at a time\n\n        # Dispatcher state — for compare mode (adapter-controlled requests).\n        # Instead of serializing via _gen_lock, adapter-controlled requests\n        # send commands directly to the subprocess and read from per-request\n        # mailboxes. 
A dispatcher thread routes resp_queue events by request_id.\n        self._mailboxes: dict[str, queue.Queue] = {}\n        self._mailbox_lock = threading.Lock()  # Protects _mailboxes dict\n        self._dispatcher_thread: Optional[threading.Thread] = None\n        self._dispatcher_stop = threading.Event()\n\n        # Local state mirrors (updated from subprocess responses)\n        self.active_model_name: Optional[str] = None\n        self.models: dict = {}\n        self.loading_models: set = set()\n        self.loaded_local_models: list = []\n        from core.inference.defaults import get_default_models\n\n        self._static_models = get_default_models()\n        self._top_gguf_cache: Optional[list[str]] = None\n        self._top_hub_cache: Optional[list[str]] = None\n        self._top_models_ready = threading.Event()\n\n        # Version tracking for subprocess reuse\n        self._current_transformers_major: Optional[str] = None  # \"4\" or \"5\"\n\n        atexit.register(self._cleanup)\n        logger.info(\"InferenceOrchestrator initialized (subprocess mode)\")\n\n        # Kick off background fetch of top models from HF\n        threading.Thread(\n            target = self._fetch_top_models, daemon = True, name = \"top-models\"\n        ).start()\n\n    # ------------------------------------------------------------------\n    # Default models (top GGUFs fetched dynamically from HF)\n    # ------------------------------------------------------------------\n\n    @property\n    def default_models(self) -> list[str]:\n        # Wait up to 5s for background HF fetch to finish\n        self._top_models_ready.wait(timeout = 5)\n        top_gguf = self._top_gguf_cache or []\n        top_hub = self._top_hub_cache or []\n        # GGUFs first, then hub models, then static fallbacks.\n        # Send extras so the frontend still has 4 per category\n        # after removing already-downloaded models.\n        result: list[str] = []\n        seen: set[str] = set()\n        for m in top_gguf + top_hub + self._static_models:\n            if m not in seen:\n                result.append(m)\n                seen.add(m)\n        return result\n\n    def _fetch_top_models(self) -> None:\n        \"\"\"Fetch top GGUF and non-GGUF repos from unsloth by downloads.\"\"\"\n        try:\n            import httpx\n\n            resp = httpx.get(\n                \"https://huggingface.co/api/models\",\n                params = {\n                    \"author\": \"unsloth\",\n                    \"sort\": \"downloads\",\n                    \"direction\": \"-1\",\n                    \"limit\": \"80\",\n                },\n                timeout = 15,\n            )\n            if resp.status_code == 200:\n                models = resp.json()\n                # Top 40 GGUFs - frontend pages through them on-demand via\n                # infinite scroll, so we send a deep pool.\n                gguf_ids = [\n                    m[\"id\"] for m in models if m.get(\"id\", \"\").upper().endswith(\"-GGUF\")\n                ][:40]\n                # Top 40 non-GGUF hub models\n                hub_ids = [\n                    m[\"id\"]\n                    for m in models\n                    if not m.get(\"id\", \"\").upper().endswith(\"-GGUF\")\n                ][:40]\n                if gguf_ids:\n                    self._top_gguf_cache = gguf_ids\n                    logger.info(\"Top GGUF models: %s\", gguf_ids)\n                if hub_ids:\n                    self._top_hub_cache = hub_ids\n              
      logger.info(\"Top hub models: %s\", hub_ids)\n        except Exception as e:\n            logger.warning(\"Failed to fetch top models: %s\", e)\n        finally:\n            self._top_models_ready.set()\n\n    # ------------------------------------------------------------------\n    # Subprocess lifecycle\n    # ------------------------------------------------------------------\n\n    def _spawn_subprocess(self, config: dict) -> None:\n        \"\"\"Spawn a new inference subprocess.\"\"\"\n        from .worker import run_inference_process\n\n        self._cmd_queue = _CTX.Queue()\n        self._resp_queue = _CTX.Queue()\n        self._cancel_event = _CTX.Event()\n\n        self._proc = _CTX.Process(\n            target = run_inference_process,\n            kwargs = {\n                \"cmd_queue\": self._cmd_queue,\n                \"resp_queue\": self._resp_queue,\n                \"cancel_event\": self._cancel_event,\n                \"config\": config,\n            },\n            daemon = True,\n        )\n        self._proc.start()\n        logger.info(\"Inference subprocess started (pid=%s)\", self._proc.pid)\n\n    def _cancel_generation(self) -> None:\n        \"\"\"Cancel any ongoing generation in the subprocess (instant).\"\"\"\n        if self._cancel_event is not None:\n            self._cancel_event.set()\n\n    def _shutdown_subprocess(self, timeout: float = 10.0) -> None:\n        \"\"\"Gracefully shut down the inference subprocess.\"\"\"\n        self._stop_dispatcher()  # Stop dispatcher before killing subprocess\n        if self._proc is None or not self._proc.is_alive():\n            self._proc = None\n            return\n\n        # 1. Cancel any ongoing generation first (instant via mp.Event)\n        self._cancel_generation()\n        time.sleep(0.5)  # Brief wait for generation to stop\n\n        # 2. Drain stale responses from queue\n        self._drain_queue()\n\n        # 3. Send shutdown command\n        try:\n            self._cmd_queue.put({\"type\": \"shutdown\"})\n        except (OSError, ValueError):\n            pass\n\n        # 4. Wait for graceful shutdown\n        try:\n            self._proc.join(timeout = timeout)\n        except Exception:\n            pass\n\n        # 5. 
Force kill if still alive\n        if self._proc is not None and self._proc.is_alive():\n            logger.warning(\"Inference subprocess did not exit gracefully, terminating\")\n            try:\n                self._proc.terminate()\n                self._proc.join(timeout = 5)\n            except Exception:\n                pass\n            if self._proc is not None and self._proc.is_alive():\n                logger.warning(\"Subprocess still alive after terminate, killing\")\n                try:\n                    self._proc.kill()\n                    self._proc.join(timeout = 3)\n                except Exception:\n                    pass\n\n        self._proc = None\n        self._cmd_queue = None\n        self._resp_queue = None\n        self._cancel_event = None\n        logger.info(\"Inference subprocess shut down\")\n\n    def _cleanup(self):\n        \"\"\"atexit handler.\"\"\"\n        self._shutdown_subprocess(timeout = 5.0)\n\n    def _ensure_subprocess_alive(self) -> bool:\n        \"\"\"Check if subprocess is alive.\"\"\"\n        return self._proc is not None and self._proc.is_alive()\n\n    # ------------------------------------------------------------------\n    # Queue helpers\n    # ------------------------------------------------------------------\n\n    def _send_cmd(self, cmd: dict) -> None:\n        \"\"\"Send a command to the subprocess.\"\"\"\n        if self._cmd_queue is None:\n            raise RuntimeError(\"No inference subprocess running\")\n        try:\n            self._cmd_queue.put(cmd)\n        except (OSError, ValueError) as exc:\n            raise RuntimeError(f\"Failed to send command to subprocess: {exc}\")\n\n    def _read_resp(self, timeout: float = 1.0) -> Optional[dict]:\n        \"\"\"Read a response from the subprocess (non-blocking with timeout).\"\"\"\n        if self._resp_queue is None:\n            return None\n        try:\n            return self._resp_queue.get(timeout = timeout)\n        except queue.Empty:\n            return None\n        except (EOFError, OSError, ValueError):\n            return None\n\n    def _wait_response(self, expected_type: str, timeout: float = 120.0) -> dict:\n        \"\"\"Block until a response of the expected type arrives.\n\n        Also handles 'status' and 'error' events during the wait.\n        Returns the matching response dict.\n        Raises RuntimeError on timeout or subprocess crash.\n        \"\"\"\n        deadline = time.monotonic() + timeout\n\n        while time.monotonic() < deadline:\n            remaining = max(0.1, deadline - time.monotonic())\n            resp = self._read_resp(timeout = min(remaining, 1.0))\n\n            if resp is None:\n                # Check subprocess health\n                if not self._ensure_subprocess_alive():\n                    raise RuntimeError(\"Inference subprocess crashed during wait\")\n                continue\n\n            rtype = resp.get(\"type\", \"\")\n\n            if rtype == expected_type:\n                return resp\n\n            if rtype == \"error\":\n                error_msg = resp.get(\"error\", \"Unknown error\")\n                raise RuntimeError(f\"Subprocess error: {error_msg}\")\n\n            if rtype == \"status\":\n                logger.info(\"Subprocess status: %s\", resp.get(\"message\", \"\"))\n                continue\n\n            # Other response types during wait — skip\n            logger.debug(\n                \"Skipping response type '%s' while waiting for '%s'\",\n                rtype,\n          
      expected_type,\n            )\n\n        raise RuntimeError(\n            f\"Timeout waiting for '{expected_type}' response after {timeout}s\"\n        )\n\n    def _drain_queue(self) -> list:\n        \"\"\"Drain all pending responses.\"\"\"\n        events = []\n        if self._resp_queue is None:\n            return events\n        while True:\n            try:\n                events.append(self._resp_queue.get_nowait())\n            except queue.Empty:\n                return events\n            except (EOFError, OSError, ValueError):\n                return events\n\n    def _drain_until_gen_done(self, timeout: float = 5.0) -> None:\n        \"\"\"Consume resp_queue events until gen_done/gen_error, discarding them.\n\n        Called after cancel to ensure stale tokens from the cancelled\n        generation don't leak into the next request.\n        \"\"\"\n        deadline = time.monotonic() + timeout\n        while time.monotonic() < deadline:\n            resp = self._read_resp(timeout = min(0.5, deadline - time.monotonic()))\n            if resp is None:\n                if not self._ensure_subprocess_alive():\n                    return\n                continue\n            rtype = resp.get(\"type\", \"\")\n            if rtype in (\"gen_done\", \"gen_error\"):\n                return\n        logger.warning(\"Timed out waiting for gen_done after cancel\")\n\n    # ------------------------------------------------------------------\n    # Dispatcher — per-request mailbox routing for compare mode\n    # ------------------------------------------------------------------\n\n    def _start_dispatcher(self) -> None:\n        \"\"\"Start the dispatcher thread if not already running.\n\n        The dispatcher reads from the shared resp_queue and routes\n        responses to per-request mailbox queues. 
This allows multiple\n        adapter-controlled (compare) requests to be in-flight without\n        holding _gen_lock.\n        \"\"\"\n        if self._dispatcher_thread is not None and self._dispatcher_thread.is_alive():\n            return\n\n        self._dispatcher_stop.clear()\n        self._dispatcher_thread = threading.Thread(\n            target = self._dispatcher_loop,\n            daemon = True,\n            name = \"inference-dispatcher\",\n        )\n        self._dispatcher_thread.start()\n        logger.debug(\"Dispatcher thread started\")\n\n    def _stop_dispatcher(self) -> None:\n        \"\"\"Signal the dispatcher to stop and wait for it.\"\"\"\n        if self._dispatcher_thread is None:\n            return\n        self._dispatcher_stop.set()\n        self._dispatcher_thread.join(timeout = _DISPATCH_STOP_TIMEOUT)\n        self._dispatcher_thread = None\n        logger.debug(\"Dispatcher thread stopped\")\n\n    def _dispatcher_loop(self) -> None:\n        \"\"\"Background loop: read resp_queue → route to mailboxes by request_id.\"\"\"\n        while not self._dispatcher_stop.is_set():\n            if self._resp_queue is None:\n                break\n\n            try:\n                resp = self._resp_queue.get(timeout = _DISPATCH_POLL_INTERVAL)\n            except queue.Empty:\n                continue\n            except (EOFError, OSError, ValueError):\n                break\n\n            rid = resp.get(\"request_id\")\n            rtype = resp.get(\"type\", \"\")\n\n            # Status messages — log and skip\n            if rtype == \"status\":\n                logger.info(\"Subprocess status: %s\", resp.get(\"message\", \"\"))\n                continue\n\n            # Route to mailbox if a matching request_id exists\n            if rid:\n                with self._mailbox_lock:\n                    mbox = self._mailboxes.get(rid)\n                if mbox is not None:\n                    mbox.put(resp)\n                    continue\n\n            # No matching mailbox — might be for a _gen_lock reader or orphaned\n            # Push it back so _read_resp can pick it up. But we can't un-get\n            # from mp.Queue, so log a warning.\n            if rtype not in (\"status\",):\n                logger.debug(\n                    \"Dispatcher: no mailbox for request_id=%s type=%s, dropping\",\n                    rid,\n                    rtype,\n                )\n\n    def _generate_dispatched(\n        self,\n        messages: list = None,\n        system_prompt: str = \"\",\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n        use_adapter = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Dispatched generation — sends command without holding _gen_lock.\n\n        Uses a per-request mailbox to receive tokens. 
This allows two\n        compare-mode requests to be queued in the subprocess simultaneously,\n        eliminating the inter-generation round-trip overhead.\n\n        The subprocess processes commands sequentially from its cmd_queue,\n        so generation is still serialized at the GPU level — we just avoid\n        the orchestrator-level lock contention.\n        \"\"\"\n        if not self._ensure_subprocess_alive():\n            yield \"Error: Inference subprocess is not running\"\n            return\n\n        if not self.active_model_name:\n            yield \"Error: No active model\"\n            return\n\n        # Ensure dispatcher is running\n        self._start_dispatcher()\n\n        request_id = str(uuid.uuid4())\n\n        # Convert PIL Image to base64 if needed\n        image_b64 = None\n        if image is not None:\n            image_b64 = self._pil_to_base64(image)\n\n        cmd = {\n            \"type\": \"generate\",\n            \"request_id\": request_id,\n            \"messages\": messages or [],\n            \"system_prompt\": system_prompt,\n            \"image_base64\": image_b64,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k,\n            \"min_p\": min_p,\n            \"max_new_tokens\": max_new_tokens,\n            \"repetition_penalty\": repetition_penalty,\n        }\n\n        if use_adapter is not None:\n            cmd[\"use_adapter\"] = use_adapter\n\n        # Create mailbox BEFORE sending command\n        mailbox: queue.Queue = queue.Queue()\n        with self._mailbox_lock:\n            self._mailboxes[request_id] = mailbox\n\n        try:\n            self._send_cmd(cmd)\n        except RuntimeError as exc:\n            with self._mailbox_lock:\n                self._mailboxes.pop(request_id, None)\n            yield f\"Error: {exc}\"\n            return\n\n        # Read tokens from our private mailbox\n        try:\n            while True:\n                try:\n                    resp = mailbox.get(timeout = _DISPATCH_READ_TIMEOUT)\n                except queue.Empty:\n                    # Timeout — check subprocess health\n                    if not self._ensure_subprocess_alive():\n                        yield \"Error: Inference subprocess crashed during generation\"\n                        return\n                    continue\n\n                rtype = resp.get(\"type\", \"\")\n\n                if rtype == \"token\":\n                    # Check cancel from route (e.g. 
SSE connection closed)\n                    if cancel_event is not None and cancel_event.is_set():\n                        self._cancel_generation()\n                        # Drain remaining events for this request\n                        self._drain_mailbox(mailbox, timeout = 5.0)\n                        return\n                    yield resp.get(\"text\", \"\")\n\n                elif rtype == \"gen_done\":\n                    return\n\n                elif rtype == \"gen_error\":\n                    yield f\"Error: {resp.get('error', 'Unknown error')}\"\n                    return\n        finally:\n            with self._mailbox_lock:\n                self._mailboxes.pop(request_id, None)\n\n    def _drain_mailbox(self, mailbox: queue.Queue, timeout: float = 5.0) -> None:\n        \"\"\"Drain a mailbox until gen_done/gen_error, discarding tokens.\"\"\"\n        deadline = time.monotonic() + timeout\n        while time.monotonic() < deadline:\n            try:\n                resp = mailbox.get(\n                    timeout = min(_DISPATCH_POLL_INTERVAL, deadline - time.monotonic())\n                )\n            except queue.Empty:\n                continue\n            rtype = resp.get(\"type\", \"\")\n            if rtype in (\"gen_done\", \"gen_error\"):\n                return\n        logger.warning(\"Timed out draining mailbox after cancel\")\n\n    def _wait_dispatcher_idle(self) -> None:\n        \"\"\"Wait for all dispatched requests to complete, then stop dispatcher.\n\n        Called by _generate_inner before using the _gen_lock path, to ensure\n        the dispatcher thread isn't competing for resp_queue reads.\n        \"\"\"\n        if self._dispatcher_thread is None or not self._dispatcher_thread.is_alive():\n            return\n\n        # Wait for all mailboxes to be emptied (dispatched requests complete)\n        deadline = time.monotonic() + _DISPATCH_IDLE_TIMEOUT\n        while time.monotonic() < deadline:\n            with self._mailbox_lock:\n                if not self._mailboxes:\n                    break\n            time.sleep(0.1)\n\n        # Only stop dispatcher if all mailboxes drained.  If compare\n        # requests are still active, leave the dispatcher running so\n        # their token routing isn't killed mid-stream.\n        with self._mailbox_lock:\n            still_active = bool(self._mailboxes)\n        if still_active:\n            logger.warning(\n                \"Dispatcher still has %d active mailbox(es); \"\n                \"leaving dispatcher running for compare requests\",\n                len(self._mailboxes),\n            )\n        else:\n            self._stop_dispatcher()\n\n    # ------------------------------------------------------------------\n    # Public API — same interface as InferenceBackend\n    # ------------------------------------------------------------------\n\n    def load_model(\n        self,\n        config,  # ModelConfig\n        max_seq_length: int = 2048,\n        dtype = None,\n        load_in_4bit: bool = True,\n        hf_token: Optional[str] = None,\n        trust_remote_code: bool = False,\n    ) -> bool:\n        \"\"\"Load a model for inference.\n\n        Always spawns a fresh subprocess for each model load. 
This ensures\n        a clean Python interpreter — no stale unsloth patches, torch.compile\n        caches, or inspect.getsource() failures from a previous model.\n        \"\"\"\n        from utils.transformers_version import needs_transformers_5\n\n        model_name = config.identifier\n        self.loading_models.add(model_name)\n\n        try:\n            needed_major = \"5\" if needs_transformers_5(model_name) else \"4\"\n\n            # Build config dict for subprocess\n            sub_config = {\n                \"model_name\": model_name,\n                \"max_seq_length\": max_seq_length,\n                \"load_in_4bit\": load_in_4bit,\n                \"hf_token\": hf_token or \"\",\n                \"gguf_variant\": getattr(config, \"gguf_variant\", None),\n                \"trust_remote_code\": trust_remote_code,\n            }\n\n            # Always kill existing subprocess and spawn fresh.\n            # Reusing a subprocess after unsloth patches torch internals\n            # causes inspect.getsource() failures on the next model load.\n            if self._ensure_subprocess_alive():\n                self._cancel_generation()\n                time.sleep(0.3)\n                self._shutdown_subprocess()\n\n            elif self._proc is not None:\n                # Dead subprocess — clean up\n                self._shutdown_subprocess(timeout = 2)\n\n            logger.info(\n                \"Spawning fresh inference subprocess for '%s' (transformers %s.x)\",\n                model_name,\n                needed_major,\n            )\n            self._spawn_subprocess(sub_config)\n            resp = self._wait_response(\"loaded\", timeout = 180)\n\n            # Update local state from response\n            if resp.get(\"success\"):\n                self._current_transformers_major = needed_major\n                model_info = resp.get(\"model_info\", {})\n                self.active_model_name = model_info.get(\"identifier\", model_name)\n                self.models[self.active_model_name] = {\n                    \"is_vision\": model_info.get(\"is_vision\", False),\n                    \"is_lora\": model_info.get(\"is_lora\", False),\n                    \"display_name\": model_info.get(\"display_name\", model_name),\n                    \"is_audio\": model_info.get(\"is_audio\", False),\n                    \"audio_type\": model_info.get(\"audio_type\"),\n                    \"has_audio_input\": model_info.get(\"has_audio_input\", False),\n                }\n                self.loading_models.discard(model_name)\n                logger.info(\"Model '%s' loaded successfully in subprocess\", model_name)\n                return True\n            else:\n                error = resp.get(\"error\", \"Failed to load model\")\n                self.loading_models.discard(model_name)\n                self.active_model_name = None\n                self.models.clear()\n                raise Exception(error)\n\n        except Exception:\n            self.loading_models.discard(model_name)\n            self.active_model_name = None\n            self.models.clear()\n            raise\n\n    def unload_model(self, model_name: str) -> bool:\n        \"\"\"Unload a model from the subprocess.\"\"\"\n        if not self._ensure_subprocess_alive():\n            # No subprocess — just clear local state\n            self.models.pop(model_name, None)\n            if self.active_model_name == model_name:\n                self.active_model_name = None\n            return True\n\n        try:\n  
          self._send_cmd(\n                {\n                    \"type\": \"unload\",\n                    \"model_name\": model_name,\n                }\n            )\n            resp = self._wait_response(\"unloaded\", timeout = 30)\n\n            # Update local state\n            self.models.pop(model_name, None)\n            if self.active_model_name == model_name:\n                self.active_model_name = None\n\n            logger.info(\"Model '%s' unloaded from subprocess\", model_name)\n            return True\n\n        except Exception as exc:\n            logger.error(\"Error unloading model '%s': %s\", model_name, exc)\n            # Clear local state anyway\n            self.models.pop(model_name, None)\n            if self.active_model_name == model_name:\n                self.active_model_name = None\n            return False\n\n    def generate_chat_response(\n        self,\n        messages: list,\n        system_prompt: str = \"\",\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Generate response, streaming tokens from subprocess.\"\"\"\n        yield from self._generate_inner(\n            messages = messages,\n            system_prompt = system_prompt,\n            image = image,\n            temperature = temperature,\n            top_p = top_p,\n            top_k = top_k,\n            min_p = min_p,\n            max_new_tokens = max_new_tokens,\n            repetition_penalty = repetition_penalty,\n            cancel_event = cancel_event,\n            use_adapter = None,\n        )\n\n    def generate_with_adapter_control(\n        self,\n        use_adapter: Optional[Union[bool, str]] = None,\n        cancel_event = None,\n        **gen_kwargs,\n    ) -> Generator[str, None, None]:\n        \"\"\"Generate with adapter control, streaming tokens from subprocess.\n\n        Uses the dispatcher path (no _gen_lock) so that compare-mode\n        requests don't block each other. 
The subprocess naturally\n        serializes them via its sequential command loop.\n        \"\"\"\n        yield from self._generate_dispatched(\n            use_adapter = use_adapter,\n            cancel_event = cancel_event,\n            **gen_kwargs,\n        )\n\n    def _generate_inner(\n        self,\n        messages: list = None,\n        system_prompt: str = \"\",\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n        use_adapter = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Inner generation logic — sends command to subprocess, yields tokens.\n\n        Serialized by _gen_lock: only one generation runs at a time.\n        This prevents concurrent readers from consuming each other's\n        tokens off the shared resp_queue.\n        \"\"\"\n        if not self._ensure_subprocess_alive():\n            yield \"Error: Inference subprocess is not running\"\n            return\n\n        if not self.active_model_name:\n            yield \"Error: No active model\"\n            return\n\n        # If the dispatcher is running (from a previous compare-mode request),\n        # wait for all dispatched requests to finish, then stop the dispatcher\n        # so we can safely read from resp_queue directly.\n        self._wait_dispatcher_idle()\n\n        # Serialize generation — single GPU, one generation at a time.\n        # Without this lock, two concurrent readers on the same resp_queue\n        # can consume and drop each other's token events.\n        with self._gen_lock:\n            yield from self._generate_locked(\n                messages = messages,\n                system_prompt = system_prompt,\n                image = image,\n                temperature = temperature,\n                top_p = top_p,\n                top_k = top_k,\n                min_p = min_p,\n                max_new_tokens = max_new_tokens,\n                repetition_penalty = repetition_penalty,\n                cancel_event = cancel_event,\n                use_adapter = use_adapter,\n            )\n\n    def _generate_locked(\n        self,\n        messages: list = None,\n        system_prompt: str = \"\",\n        image = None,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 256,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n        use_adapter = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Actual generation logic — must be called under _gen_lock.\"\"\"\n        request_id = str(uuid.uuid4())\n\n        # Convert PIL Image to base64 if needed\n        image_b64 = None\n        if image is not None:\n            image_b64 = self._pil_to_base64(image)\n\n        cmd = {\n            \"type\": \"generate\",\n            \"request_id\": request_id,\n            \"messages\": messages or [],\n            \"system_prompt\": system_prompt,\n            \"image_base64\": image_b64,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k,\n            \"min_p\": min_p,\n            \"max_new_tokens\": max_new_tokens,\n            \"repetition_penalty\": repetition_penalty,\n        }\n\n        if use_adapter is not None:\n            cmd[\"use_adapter\"] = use_adapter\n\n        try:\n            
self._send_cmd(cmd)\n        except RuntimeError as exc:\n            yield f\"Error: {exc}\"\n            return\n\n        # Yield tokens from response queue — we are the only reader\n        # because _gen_lock is held.\n        while True:\n            resp = self._read_resp(timeout = 30.0)\n\n            if resp is None:\n                # Check subprocess health\n                if not self._ensure_subprocess_alive():\n                    yield \"Error: Inference subprocess crashed during generation\"\n                    return\n                continue\n\n            rtype = resp.get(\"type\", \"\")\n\n            # Status messages — skip\n            if rtype == \"status\":\n                continue\n\n            # Error without request_id = subprocess-level error\n            resp_rid = resp.get(\"request_id\")\n            if rtype == \"error\" and not resp_rid:\n                yield f\"Error: {resp.get('error', 'Unknown error')}\"\n                return\n\n            if rtype == \"token\":\n                # Check cancel from route (e.g. SSE connection closed)\n                if cancel_event is not None and cancel_event.is_set():\n                    self._cancel_generation()\n                    # Wait for the subprocess to acknowledge cancellation\n                    # (gen_done/gen_error) so stale events don't leak into\n                    # the next generation request.\n                    self._drain_until_gen_done(timeout = 5.0)\n                    return\n                yield resp.get(\"text\", \"\")\n\n            elif rtype == \"gen_done\":\n                return\n\n            elif rtype == \"gen_error\":\n                yield f\"Error: {resp.get('error', 'Unknown error')}\"\n                return\n\n    def reset_generation_state(self):\n        \"\"\"Cancel any ongoing generation and reset state.\"\"\"\n        self._cancel_generation()\n        if not self._ensure_subprocess_alive():\n            return\n        try:\n            self._send_cmd({\"type\": \"reset\"})\n        except RuntimeError:\n            pass\n\n    # ------------------------------------------------------------------\n    # Audio generation — TTS, ASR, audio input\n    # ------------------------------------------------------------------\n\n    def generate_audio_response(\n        self,\n        text: str,\n        temperature: float = 0.6,\n        top_p: float = 0.95,\n        top_k: int = 50,\n        min_p: float = 0.0,\n        max_new_tokens: int = 2048,\n        repetition_penalty: float = 1.0,\n        use_adapter: Optional[Union[bool, str]] = None,\n    ) -> Tuple[bytes, int]:\n        \"\"\"Generate TTS audio. 
Returns (wav_bytes, sample_rate).\n\n        Blocking — sends command and waits for the complete audio response.\n        \"\"\"\n        if not self._ensure_subprocess_alive():\n            raise RuntimeError(\"Inference subprocess is not running\")\n        if not self.active_model_name:\n            raise RuntimeError(\"No active model\")\n\n        import uuid\n\n        request_id = str(uuid.uuid4())\n\n        cmd = {\n            \"type\": \"generate_audio\",\n            \"request_id\": request_id,\n            \"text\": text,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"top_k\": top_k,\n            \"min_p\": min_p,\n            \"max_new_tokens\": max_new_tokens,\n            \"repetition_penalty\": repetition_penalty,\n        }\n        if use_adapter is not None:\n            cmd[\"use_adapter\"] = use_adapter\n\n        self._send_cmd(cmd)\n\n        # Wait for audio_done or audio_error\n        deadline = time.monotonic() + 120.0\n        while time.monotonic() < deadline:\n            remaining = max(0.1, deadline - time.monotonic())\n            resp = self._read_resp(timeout = min(remaining, 1.0))\n\n            if resp is None:\n                if not self._ensure_subprocess_alive():\n                    raise RuntimeError(\n                        \"Inference subprocess crashed during audio generation\"\n                    )\n                continue\n\n            rtype = resp.get(\"type\", \"\")\n\n            if rtype == \"audio_done\":\n                wav_bytes = base64.b64decode(resp[\"wav_base64\"])\n                sample_rate = resp[\"sample_rate\"]\n                return wav_bytes, sample_rate\n\n            if rtype == \"audio_error\":\n                raise RuntimeError(resp.get(\"error\", \"Audio generation failed\"))\n\n            if rtype == \"error\":\n                raise RuntimeError(resp.get(\"error\", \"Unknown error\"))\n\n            if rtype == \"status\":\n                continue\n\n        raise RuntimeError(\"Timeout waiting for audio generation (120s)\")\n\n    def generate_whisper_response(\n        self,\n        audio_array,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Whisper ASR — sends audio to subprocess, yields text.\"\"\"\n        yield from self._generate_audio_input_inner(\n            audio_array = audio_array,\n            audio_type = \"whisper\",\n            messages = [],\n            system_prompt = \"\",\n            cancel_event = cancel_event,\n        )\n\n    def generate_audio_input_response(\n        self,\n        messages,\n        system_prompt,\n        audio_array,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 512,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Audio input generation (e.g. 
Gemma 3n) — streams text tokens.\"\"\"\n        yield from self._generate_audio_input_inner(\n            audio_array = audio_array,\n            audio_type = None,  # worker will use generate_audio_input_response\n            messages = messages,\n            system_prompt = system_prompt,\n            temperature = temperature,\n            top_p = top_p,\n            top_k = top_k,\n            min_p = min_p,\n            max_new_tokens = max_new_tokens,\n            repetition_penalty = repetition_penalty,\n            cancel_event = cancel_event,\n        )\n\n    def _generate_audio_input_inner(\n        self,\n        audio_array,\n        audio_type: Optional[str] = None,\n        messages: list = None,\n        system_prompt: str = \"\",\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = 40,\n        min_p: float = 0.0,\n        max_new_tokens: int = 512,\n        repetition_penalty: float = 1.0,\n        cancel_event = None,\n    ) -> Generator[str, None, None]:\n        \"\"\"Shared inner logic for audio input generation (Whisper + ASR).\"\"\"\n        if not self._ensure_subprocess_alive():\n            yield \"Error: Inference subprocess is not running\"\n            return\n        if not self.active_model_name:\n            yield \"Error: No active model\"\n            return\n\n        with self._gen_lock:\n            import uuid\n\n            request_id = str(uuid.uuid4())\n\n            # Convert numpy array to list for mp.Queue serialization\n            audio_data = (\n                audio_array.tolist()\n                if hasattr(audio_array, \"tolist\")\n                else list(audio_array)\n            )\n\n            cmd = {\n                \"type\": \"generate_audio_input\",\n                \"request_id\": request_id,\n                \"audio_data\": audio_data,\n                \"audio_type\": audio_type,\n                \"messages\": messages or [],\n                \"system_prompt\": system_prompt,\n                \"temperature\": temperature,\n                \"top_p\": top_p,\n                \"top_k\": top_k,\n                \"min_p\": min_p,\n                \"max_new_tokens\": max_new_tokens,\n                \"repetition_penalty\": repetition_penalty,\n            }\n\n            try:\n                self._send_cmd(cmd)\n            except RuntimeError as exc:\n                yield f\"Error: {exc}\"\n                return\n\n            # Yield tokens — same pattern as _generate_locked\n            while True:\n                resp = self._read_resp(timeout = 30.0)\n\n                if resp is None:\n                    if not self._ensure_subprocess_alive():\n                        yield \"Error: Inference subprocess crashed during audio input generation\"\n                        return\n                    continue\n\n                rtype = resp.get(\"type\", \"\")\n\n                if rtype == \"status\":\n                    continue\n\n                if rtype == \"error\" and not resp.get(\"request_id\"):\n                    yield f\"Error: {resp.get('error', 'Unknown error')}\"\n                    return\n\n                if rtype == \"token\":\n                    if cancel_event is not None and cancel_event.is_set():\n                        self._cancel_generation()\n                        self._drain_until_gen_done(timeout = 5.0)\n                        return\n                    yield resp.get(\"text\", \"\")\n\n                elif rtype == \"gen_done\":\n                    
return\n\n                elif rtype == \"gen_error\":\n                    yield f\"Error: {resp.get('error', 'Unknown error')}\"\n                    return\n\n    # ------------------------------------------------------------------\n    # Local helpers (no subprocess needed)\n    # ------------------------------------------------------------------\n\n    def resize_image(self, img, max_size: int = 800):\n        \"\"\"Resize image while maintaining aspect ratio.\n        No ML imports needed — runs locally in parent process.\n        \"\"\"\n        if img is None:\n            return None\n        if img.size[0] > max_size or img.size[1] > max_size:\n            from PIL import Image\n\n            ratio = min(max_size / img.size[0], max_size / img.size[1])\n            new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))\n            return img.resize(new_size, Image.Resampling.LANCZOS)\n        return img\n\n    @staticmethod\n    def _pil_to_base64(img) -> str:\n        \"\"\"Convert a PIL Image to base64 string for IPC.\"\"\"\n        buf = BytesIO()\n        img.save(buf, format = \"PNG\")\n        return base64.b64encode(buf.getvalue()).decode(\"ascii\")\n\n    def get_current_model(self) -> Optional[str]:\n        \"\"\"Get currently active model name.\"\"\"\n        return self.active_model_name\n\n    def is_model_loading(self) -> bool:\n        \"\"\"Check if any model is currently loading.\"\"\"\n        return len(self.loading_models) > 0\n\n    def get_loading_model(self) -> Optional[str]:\n        \"\"\"Get name of currently loading model.\"\"\"\n        return next(iter(self.loading_models)) if self.loading_models else None\n\n    def check_vision_model_compatibility(self) -> bool:\n        \"\"\"Check if current model supports vision.\"\"\"\n        if self.active_model_name and self.active_model_name in self.models:\n            return self.models[self.active_model_name].get(\"is_vision\", False)\n        return False\n\n\n# ========== GLOBAL INSTANCE ==========\n_inference_backend = None\n\n\ndef get_inference_backend() -> InferenceOrchestrator:\n    \"\"\"Get global inference backend instance (orchestrator).\"\"\"\n    global _inference_backend\n    if _inference_backend is None:\n        _inference_backend = InferenceOrchestrator()\n    return _inference_backend\n"
  },
  {
    "path": "studio/backend/core/inference/tools.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTool definitions and executors for LLM tool calling.\n\nSupports web search (DuckDuckGo), Python code execution, and terminal commands.\n\"\"\"\n\nimport ast\nimport os\n\nos.environ[\"UNSLOTH_IS_PRESENT\"] = \"1\"\n\nimport subprocess\nimport sys\nimport tempfile\nimport threading\n\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n_EXEC_TIMEOUT = 300  # 5 minutes\n_MAX_OUTPUT_CHARS = 8000  # truncate long output\n_BASH_BLOCKED_WORDS = {\"rm\", \"sudo\", \"dd\", \"chmod\", \"mkfs\", \"shutdown\", \"reboot\"}\n\n# Per-session working directories so each chat thread gets its own sandbox.\n# Falls back to a shared ~/studio_sandbox/ for API callers without a session_id.\n_workdirs: dict[str, str] = {}\n\n\ndef _get_workdir(session_id: str | None = None) -> str:\n    \"\"\"Return (and lazily create) a persistent working directory for tool execution.\"\"\"\n    global _workdirs\n    key = session_id or \"_default\"\n    if key not in _workdirs or not os.path.isdir(_workdirs[key]):\n        home = os.path.expanduser(\"~\")\n        sandbox_root = os.path.join(home, \"studio_sandbox\")\n        if session_id:\n            # Sanitize: strip path separators and parent-dir references\n            safe_id = os.path.basename(session_id.replace(\"..\", \"\"))\n            if not safe_id:\n                safe_id = \"_invalid\"\n            workdir = os.path.join(sandbox_root, safe_id)\n            # Verify resolved path stays under sandbox root\n            if not os.path.realpath(workdir).startswith(os.path.realpath(sandbox_root)):\n                workdir = os.path.join(sandbox_root, \"_invalid\")\n        else:\n            workdir = sandbox_root\n        os.makedirs(workdir, exist_ok = True)\n        _workdirs[key] = workdir\n    return _workdirs[key]\n\n\nWEB_SEARCH_TOOL = {\n    \"type\": \"function\",\n    \"function\": {\n        \"name\": \"web_search\",\n        \"description\": \"Search the web for current information, recent events, or facts you are uncertain about.\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"query\": {\n                    \"type\": \"string\",\n                    \"description\": \"The search query\",\n                }\n            },\n            \"required\": [\"query\"],\n        },\n    },\n}\n\nPYTHON_TOOL = {\n    \"type\": \"function\",\n    \"function\": {\n        \"name\": \"python\",\n        \"description\": \"Execute Python code in a sandbox and return stdout/stderr.\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"code\": {\n                    \"type\": \"string\",\n                    \"description\": \"The Python code to run\",\n                }\n            },\n            \"required\": [\"code\"],\n        },\n    },\n}\n\nTERMINAL_TOOL = {\n    \"type\": \"function\",\n    \"function\": {\n        \"name\": \"terminal\",\n        \"description\": \"Execute a terminal command and return stdout/stderr.\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"command\": {\n                    \"type\": \"string\",\n                    \"description\": \"The command to run\",\n                }\n            },\n            \"required\": [\"command\"],\n        },\n    
},\n}\n\nALL_TOOLS = [WEB_SEARCH_TOOL, PYTHON_TOOL, TERMINAL_TOOL]\n\n\n_TIMEOUT_UNSET = object()\n\n\ndef execute_tool(\n    name: str,\n    arguments: dict,\n    cancel_event = None,\n    timeout: int | None = _TIMEOUT_UNSET,\n    session_id: str | None = None,\n) -> str:\n    \"\"\"Execute a tool by name with the given arguments. Returns result as a string.\n\n    ``timeout``: int sets per-call limit in seconds, ``None`` means no limit,\n    unset (default) uses ``_EXEC_TIMEOUT`` (300 s).\n    ``session_id``: optional thread/session ID for per-conversation sandbox isolation.\n    \"\"\"\n    logger.info(\n        f\"execute_tool: name={name}, session_id={session_id}, timeout={timeout}\"\n    )\n    effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout\n    if name == \"web_search\":\n        return _web_search(arguments.get(\"query\", \"\"), timeout = effective_timeout)\n    if name == \"python\":\n        return _python_exec(\n            arguments.get(\"code\", \"\"), cancel_event, effective_timeout, session_id\n        )\n    if name == \"terminal\":\n        return _bash_exec(\n            arguments.get(\"command\", \"\"), cancel_event, effective_timeout, session_id\n        )\n    return f\"Unknown tool: {name}\"\n\n\ndef _web_search(query: str, max_results: int = 5, timeout: int = _EXEC_TIMEOUT) -> str:\n    \"\"\"Search the web using DuckDuckGo and return formatted results.\"\"\"\n    if not query.strip():\n        return \"No query provided.\"\n    try:\n        from ddgs import DDGS\n\n        results = DDGS(timeout = timeout).text(query, max_results = max_results)\n        if not results:\n            return \"No results found.\"\n        parts = []\n        for r in results:\n            parts.append(\n                f\"Title: {r.get('title', '')}\\n\"\n                f\"URL: {r.get('href', '')}\\n\"\n                f\"Snippet: {r.get('body', '')}\"\n            )\n        return \"\\n\\n---\\n\\n\".join(parts)\n    except Exception as e:\n        return f\"Search failed: {e}\"\n\n\ndef _check_signal_escape_patterns(code: str):\n    \"\"\"\n    Check if code contains patterns that could escape signal-based timeouts.\n\n    Vendored from unsloth_zoo.rl_environments to avoid importing unsloth_zoo\n    (which requires GPU drivers and fails on Mac/Apple Silicon).\n\n    Returns (safe: bool, details: dict)\n    \"\"\"\n    try:\n        tree = ast.parse(code)\n    except SyntaxError as e:\n        return False, {\n            \"error\": f\"SyntaxError: {e}\",\n            \"signal_tampering\": [],\n            \"exception_catching\": [],\n            \"warnings\": [],\n        }\n\n    signal_tampering = []\n    exception_catching = []\n    warnings = []\n\n    def _ast_name_matches(node, names):\n        if isinstance(node, ast.Name):\n            return node.id in names\n        elif isinstance(node, ast.Attribute):\n            full_name = []\n            current = node\n            while isinstance(current, ast.Attribute):\n                full_name.append(current.attr)\n                current = current.value\n            if isinstance(current, ast.Name):\n                full_name.append(current.id)\n            full_name = \".\".join(reversed(full_name))\n            return full_name in names\n        return False\n\n    class SignalEscapeVisitor(ast.NodeVisitor):\n        def __init__(self):\n            self.imports_signal = False\n            self.signal_aliases = {\"signal\"}\n            self.loop_depth = 0\n\n        def visit_Import(self, 
node):\n            for alias in node.names:\n                if alias.name == \"signal\":\n                    self.imports_signal = True\n                    if alias.asname:\n                        self.signal_aliases.add(alias.asname)\n            self.generic_visit(node)\n\n        def visit_ImportFrom(self, node):\n            if node.module == \"signal\":\n                self.imports_signal = True\n                for alias in node.names:\n                    if alias.name in (\n                        \"signal\",\n                        \"SIGALRM\",\n                        \"SIG_IGN\",\n                        \"setitimer\",\n                        \"ITIMER_REAL\",\n                        \"pthread_sigmask\",\n                        \"SIG_BLOCK\",\n                        \"alarm\",\n                    ):\n                        self.signal_aliases.add(alias.asname or alias.name)\n            self.generic_visit(node)\n\n        def visit_While(self, node):\n            self.loop_depth += 1\n            self.generic_visit(node)\n            self.loop_depth -= 1\n\n        def visit_For(self, node):\n            self.loop_depth += 1\n            self.generic_visit(node)\n            self.loop_depth -= 1\n\n        def visit_Call(self, node):\n            func = node.func\n            func_name = None\n            if isinstance(func, ast.Attribute):\n                if isinstance(func.value, ast.Name):\n                    if func.value.id in self.signal_aliases:\n                        func_name = f\"signal.{func.attr}\"\n            elif isinstance(func, ast.Name):\n                if func.id in (\"signal\", \"setitimer\", \"alarm\", \"pthread_sigmask\"):\n                    func_name = func.id\n\n            if func_name:\n                if func_name in (\"signal.signal\", \"signal\"):\n                    if len(node.args) >= 1:\n                        if _ast_name_matches(\n                            node.args[0], (\"SIGALRM\", \"signal.SIGALRM\")\n                        ):\n                            signal_tampering.append(\n                                {\n                                    \"type\": \"signal_handler_override\",\n                                    \"line\": node.lineno,\n                                    \"description\": \"Overrides SIGALRM handler\",\n                                }\n                            )\n                elif func_name in (\"signal.setitimer\", \"setitimer\"):\n                    if len(node.args) >= 1:\n                        if _ast_name_matches(\n                            node.args[0], (\"ITIMER_REAL\", \"signal.ITIMER_REAL\")\n                        ):\n                            signal_tampering.append(\n                                {\n                                    \"type\": \"timer_manipulation\",\n                                    \"line\": node.lineno,\n                                    \"description\": \"Manipulates ITIMER_REAL timer\",\n                                }\n                            )\n                elif func_name in (\"signal.alarm\", \"alarm\"):\n                    signal_tampering.append(\n                        {\n                            \"type\": \"alarm_manipulation\",\n                            \"line\": node.lineno,\n                            \"description\": \"Manipulates alarm timer\",\n                        }\n                    )\n                elif func_name in (\"signal.pthread_sigmask\", \"pthread_sigmask\"):\n                    
signal_tampering.append(\n                        {\n                            \"type\": \"signal_mask\",\n                            \"line\": node.lineno,\n                            \"description\": \"Modifies signal mask (may block SIGALRM)\",\n                        }\n                    )\n            self.generic_visit(node)\n\n        def visit_ExceptHandler(self, node):\n            if self.loop_depth == 0:\n                self.generic_visit(node)\n                return\n            if node.type is None:\n                exception_catching.append(\n                    {\n                        \"type\": \"bare_except_in_loop\",\n                        \"line\": node.lineno,\n                        \"description\": \"Bare except in loop catches TimeoutError and continues looping\",\n                    }\n                )\n            elif isinstance(node.type, ast.Name):\n                if node.type.id in (\"TimeoutError\", \"BaseException\", \"Exception\"):\n                    exception_catching.append(\n                        {\n                            \"type\": f\"catches_{node.type.id}_in_loop\",\n                            \"line\": node.lineno,\n                            \"description\": f\"Catches {node.type.id} in loop - may suppress timeout and continue\",\n                        }\n                    )\n            elif isinstance(node.type, ast.Tuple):\n                for elt in node.type.elts:\n                    if isinstance(elt, ast.Name):\n                        if elt.id in (\"TimeoutError\", \"BaseException\", \"Exception\"):\n                            exception_catching.append(\n                                {\n                                    \"type\": f\"catches_{elt.id}_in_loop\",\n                                    \"line\": node.lineno,\n                                    \"description\": f\"Catches {elt.id} in loop - may suppress timeout and continue\",\n                                }\n                            )\n            self.generic_visit(node)\n\n    visitor = SignalEscapeVisitor()\n    visitor.visit(tree)\n\n    if visitor.imports_signal and not signal_tampering:\n        warnings.append(\"Code imports 'signal' module - review manually for safety\")\n\n    is_safe = len(signal_tampering) == 0 and len(exception_catching) == 0\n    return is_safe, {\n        \"signal_tampering\": signal_tampering,\n        \"exception_catching\": exception_catching,\n        \"warnings\": warnings,\n    }\n\n\ndef _check_code_safety(code: str) -> str | None:\n    \"\"\"Validate code safety via static analysis.\n\n    Returns an error message string if the code is unsafe, or None if OK.\n    \"\"\"\n    safe, info = _check_signal_escape_patterns(code)\n    if not safe:\n        findings = info.get(\"signal_tampering\", []) + info.get(\"exception_catching\", [])\n        reasons = [item.get(\"description\", \"\") for item in findings]\n        return (\n            f\"Error: unsafe code detected ({'; '.join(reasons)}). \"\n            f\"Please remove signal manipulation or overly broad exception handling from your code.\"\n        )\n\n    return None\n\n\ndef _cancel_watcher(proc, cancel_event, poll_interval = 0.2):\n    \"\"\"Daemon thread that kills a process when cancel_event is set.\"\"\"\n    if cancel_event is None:\n        return  # Nothing to watch for; callers only start this thread with an event\n    while proc.poll() is None:\n        if cancel_event.is_set():\n            proc.kill()\n            return\n        cancel_event.wait(poll_interval)\n\n\ndef _truncate(text: str, limit: int = _MAX_OUTPUT_CHARS) -> str:\n    if len(text) > limit:\n        return text[:limit] + f\"\\n\\n... (truncated, {len(text)} chars total)\"\n    return text\n\n\ndef _python_exec(\n    code: str,\n    cancel_event = None,\n    timeout: int = _EXEC_TIMEOUT,\n    session_id: str | None = None,\n) -> str:\n    \"\"\"Execute Python code in a subprocess sandbox.\"\"\"\n    if not code or not code.strip():\n        return \"No code provided.\"\n\n    # Static safety check: reject signal-escape patterns before running\n    error = _check_code_safety(code)\n    if error:\n        return error\n\n    tmp_path = None\n    workdir = _get_workdir(session_id)\n    try:\n        fd, tmp_path = tempfile.mkstemp(\n            suffix = \".py\", prefix = \"studio_exec_\", dir = workdir\n        )\n        with os.fdopen(fd, \"w\") as f:\n            f.write(code)\n\n        proc = subprocess.Popen(\n            [sys.executable, tmp_path],\n            stdout = subprocess.PIPE,\n            stderr = subprocess.STDOUT,\n            text = True,\n            cwd = workdir,\n        )\n\n        # Spawn cancel watcher if we have a cancel event\n        if cancel_event is not None:\n            watcher = threading.Thread(\n                target = _cancel_watcher, args = (proc, cancel_event), daemon = True\n            )\n            watcher.start()\n\n        try:\n            output, _ = proc.communicate(timeout = timeout)\n        except subprocess.TimeoutExpired:\n            proc.kill()\n            proc.communicate()\n            return _truncate(f\"Execution timed out after {timeout} seconds.\")\n\n        if cancel_event is not None and cancel_event.is_set():\n            return \"Execution cancelled.\"\n\n        result = output or \"\"\n        if proc.returncode != 0:\n            result = f\"Exit code {proc.returncode}:\\n{result}\"\n        return _truncate(result) if result.strip() else \"(no output)\"\n\n    except Exception as e:\n        return f\"Execution error: {e}\"\n    finally:\n        if tmp_path and os.path.exists(tmp_path):\n            try:\n                os.unlink(tmp_path)\n            except OSError:\n                pass\n\n\ndef _bash_exec(\n    command: str,\n    cancel_event = None,\n    timeout: int = _EXEC_TIMEOUT,\n    session_id: str | None = None,\n) -> str:\n    \"\"\"Execute a bash command in a subprocess sandbox.\"\"\"\n    if not command or not command.strip():\n        return \"No command provided.\"\n\n    # Block dangerous commands\n    tokens = set(command.lower().split())\n    blocked = tokens & _BASH_BLOCKED_WORDS\n    if blocked:\n        return f\"Blocked command(s) for safety: {', '.join(sorted(blocked))}\"\n\n    try:\n        workdir = _get_workdir(session_id)\n        proc = subprocess.Popen(\n            [\"bash\", \"-c\", command],\n            stdout = subprocess.PIPE,\n            stderr = subprocess.STDOUT,\n            text = True,\n            cwd = workdir,\n        )\n\n        if cancel_event is not None:\n            watcher = threading.Thread(\n                target = _cancel_watcher, args 
= (proc, cancel_event), daemon = True\n            )\n            watcher.start()\n\n        try:\n            output, _ = proc.communicate(timeout = timeout)\n        except subprocess.TimeoutExpired:\n            proc.kill()\n            proc.communicate()\n            return _truncate(f\"Execution timed out after {timeout} seconds.\")\n\n        if cancel_event is not None and cancel_event.is_set():\n            return \"Execution cancelled.\"\n\n        result = output or \"\"\n        if proc.returncode != 0:\n            result = f\"Exit code {proc.returncode}:\\n{result}\"\n        return _truncate(result) if result.strip() else \"(no output)\"\n\n    except Exception as e:\n        return f\"Execution error: {e}\"\n"
  },
  {
    "path": "studio/backend/core/inference/worker.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference subprocess entry point.\n\nEach inference session runs in a persistent subprocess (mp.get_context(\"spawn\")).\nThis gives us a clean Python interpreter with no stale module state —\nsolving the transformers version-switching problem completely.\n\nThe subprocess stays alive while a model is loaded, accepting commands\n(generate, load, unload) via mp.Queue. It exits on shutdown or unload.\n\nPattern follows core/training/worker.py.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport base64\nimport structlog\nfrom loggers import get_logger\nimport os\nimport queue as _queue\nimport sys\nimport time\nimport traceback\nfrom io import BytesIO\nfrom pathlib import Path\nfrom typing import Any\n\nlogger = get_logger(__name__)\n\n\ndef _activate_transformers_version(model_name: str) -> None:\n    \"\"\"Activate the correct transformers version BEFORE any ML imports.\n\n    If the model needs transformers 5.x, prepend the pre-installed .venv_t5/\n    directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).\n    \"\"\"\n    # Ensure backend is on path for utils imports\n    backend_path = str(Path(__file__).resolve().parent.parent.parent)\n    if backend_path not in sys.path:\n        sys.path.insert(0, backend_path)\n\n    from utils.transformers_version import (\n        needs_transformers_5,\n        _resolve_base_model,\n        _ensure_venv_t5_exists,\n        _VENV_T5_DIR,\n    )\n\n    resolved = _resolve_base_model(model_name)\n    if needs_transformers_5(resolved):\n        if not _ensure_venv_t5_exists():\n            raise RuntimeError(\n                f\"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}\"\n            )\n        if _VENV_T5_DIR not in sys.path:\n            sys.path.insert(0, _VENV_T5_DIR)\n        logger.info(\"Activated transformers 5.x from %s\", _VENV_T5_DIR)\n        # Propagate to child subprocesses (e.g. 
GGUF converter)\n        _pp = os.environ.get(\"PYTHONPATH\", \"\")\n        os.environ[\"PYTHONPATH\"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else \"\")\n    else:\n        logger.info(\"Using default transformers (4.57.x) for %s\", model_name)\n\n\ndef _decode_image(image_base64: str):\n    \"\"\"Decode base64 string to PIL.Image.\"\"\"\n    from PIL import Image\n\n    image_data = base64.b64decode(image_base64)\n    return Image.open(BytesIO(image_data))\n\n\ndef _resize_image(img, max_size: int = 800):\n    \"\"\"Resize image while maintaining aspect ratio.\"\"\"\n    if img is None:\n        return None\n    if img.size[0] > max_size or img.size[1] > max_size:\n        from PIL import Image\n\n        ratio = min(max_size / img.size[0], max_size / img.size[1])\n        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))\n        return img.resize(new_size, Image.Resampling.LANCZOS)\n    return img\n\n\ndef _send_response(resp_queue: Any, response: dict) -> None:\n    \"\"\"Send a response to the parent process.\"\"\"\n    try:\n        resp_queue.put(response)\n    except (OSError, ValueError) as exc:\n        logger.error(\"Failed to send response: %s\", exc)\n\n\ndef _build_model_config(config: dict):\n    \"\"\"Build a ModelConfig from the config dict.\"\"\"\n    from utils.models import ModelConfig\n\n    model_name = config[\"model_name\"]\n    hf_token = config.get(\"hf_token\")\n    hf_token = hf_token if hf_token and hf_token.strip() else None\n    gguf_variant = config.get(\"gguf_variant\")\n\n    mc = ModelConfig.from_identifier(\n        model_id = model_name,\n        hf_token = hf_token,\n        gguf_variant = gguf_variant,\n    )\n    if not mc:\n        raise ValueError(f\"Invalid model identifier: {model_name}\")\n    return mc\n\n\ndef _handle_load(backend, config: dict, resp_queue: Any) -> None:\n    \"\"\"Handle a load command: load a model into the backend.\"\"\"\n    try:\n        mc = _build_model_config(config)\n\n        hf_token = config.get(\"hf_token\")\n        hf_token = hf_token if hf_token and hf_token.strip() else None\n\n        # Auto-detect quantization for LoRA adapters\n        load_in_4bit = config.get(\"load_in_4bit\", True)\n        if mc.is_lora and mc.path:\n            import json\n            from pathlib import Path\n\n            adapter_cfg_path = Path(mc.path) / \"adapter_config.json\"\n            if adapter_cfg_path.exists():\n                try:\n                    with open(adapter_cfg_path) as f:\n                        adapter_cfg = json.load(f)\n                    training_method = adapter_cfg.get(\"unsloth_training_method\")\n                    if training_method == \"lora\" and load_in_4bit:\n                        logger.info(\n                            \"adapter_config.json says lora — setting load_in_4bit=False\"\n                        )\n                        load_in_4bit = False\n                    elif training_method == \"qlora\" and not load_in_4bit:\n                        logger.info(\n                            \"adapter_config.json says qlora — setting load_in_4bit=True\"\n                        )\n                        load_in_4bit = True\n                    elif not training_method:\n                        if (\n                            mc.base_model\n                            and \"-bnb-4bit\" not in mc.base_model.lower()\n                            and load_in_4bit\n                        ):\n                            logger.info(\n                                \"No 
training method, base model has no -bnb-4bit — setting load_in_4bit=False\"\n                            )\n                            load_in_4bit = False\n                except Exception as e:\n                    logger.warning(\"Could not read adapter_config.json: %s\", e)\n\n        success = backend.load_model(\n            config = mc,\n            max_seq_length = config.get(\"max_seq_length\", 2048),\n            load_in_4bit = load_in_4bit,\n            hf_token = hf_token,\n            trust_remote_code = config.get(\"trust_remote_code\", False),\n        )\n\n        if success:\n            # Build model_info for the parent to mirror\n            model_info = {\n                \"identifier\": mc.identifier,\n                \"display_name\": mc.display_name,\n                \"is_vision\": mc.is_vision,\n                \"is_lora\": mc.is_lora,\n                \"is_gguf\": False,\n                \"is_audio\": getattr(mc, \"is_audio\", False),\n                \"audio_type\": getattr(mc, \"audio_type\", None),\n                \"has_audio_input\": getattr(mc, \"has_audio_input\", False),\n            }\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"loaded\",\n                    \"success\": True,\n                    \"model_info\": model_info,\n                    \"ts\": time.time(),\n                },\n            )\n        else:\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"loaded\",\n                    \"success\": False,\n                    \"error\": \"Failed to load model\",\n                    \"ts\": time.time(),\n                },\n            )\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"loaded\",\n                \"success\": False,\n                \"error\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_generate(\n    backend,\n    cmd: dict,\n    resp_queue: Any,\n    cancel_event,\n) -> None:\n    \"\"\"Handle a generate command: stream tokens back via resp_queue.\n\n    cancel_event is an mp.Event shared with the parent process.\n    The parent can set it at any time (e.g. 
user stops generation,\n    or user loads a new model while generating) and generation\n    stops within 1-2 tokens.\n    \"\"\"\n    request_id = cmd.get(\"request_id\", \"\")\n\n    try:\n        # Decode image if provided\n        image = None\n        image_b64 = cmd.get(\"image_base64\")\n        if image_b64:\n            image = _decode_image(image_b64)\n            image = _resize_image(image)\n\n        # Build generation kwargs\n        gen_kwargs = {\n            \"messages\": cmd[\"messages\"],\n            \"system_prompt\": cmd.get(\"system_prompt\", \"\"),\n            \"image\": image,\n            \"temperature\": cmd.get(\"temperature\", 0.7),\n            \"top_p\": cmd.get(\"top_p\", 0.9),\n            \"top_k\": cmd.get(\"top_k\", 40),\n            \"min_p\": cmd.get(\"min_p\", 0.0),\n            \"max_new_tokens\": cmd.get(\"max_new_tokens\", 256),\n            \"repetition_penalty\": cmd.get(\"repetition_penalty\", 1.0),\n            \"cancel_event\": cancel_event,\n        }\n\n        # Choose generation path\n        use_adapter = cmd.get(\"use_adapter\")\n        if use_adapter is not None:\n            generator = backend.generate_with_adapter_control(\n                use_adapter = use_adapter,\n                **gen_kwargs,\n            )\n        else:\n            generator = backend.generate_chat_response(**gen_kwargs)\n\n        logger.info(\"Starting text generation for request_id=%s\", request_id)\n\n        for cumulative_text in generator:\n            # cancel_event is an mp.Event — checked instantly, no queue polling\n            if cancel_event.is_set():\n                logger.info(\"Generation cancelled for request %s\", request_id)\n                break\n\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"token\",\n                    \"request_id\": request_id,\n                    \"text\": cumulative_text,\n                    \"ts\": time.time(),\n                },\n            )\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"gen_done\",\n                \"request_id\": request_id,\n                \"ts\": time.time(),\n            },\n        )\n        logger.info(\"Finished text generation for request_id=%s\", request_id)\n\n    except Exception as exc:\n        logger.error(\"Generation error: %s\", exc, exc_info = True)\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"gen_error\",\n                \"request_id\": request_id,\n                \"error\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_generate_audio(\n    backend,\n    cmd: dict,\n    resp_queue: Any,\n) -> None:\n    \"\"\"Handle TTS audio generation — returns WAV bytes + sample_rate.\"\"\"\n    request_id = cmd.get(\"request_id\", \"\")\n    try:\n        logger.info(\"Starting audio generation for request_id=%s\", request_id)\n        wav_bytes, sample_rate = backend.generate_audio_response(\n            text = cmd[\"text\"],\n            temperature = cmd.get(\"temperature\", 0.6),\n            top_p = cmd.get(\"top_p\", 0.95),\n            top_k = cmd.get(\"top_k\", 50),\n            min_p = cmd.get(\"min_p\", 0.0),\n            max_new_tokens = cmd.get(\"max_new_tokens\", 2048),\n            repetition_penalty = cmd.get(\"repetition_penalty\", 1.0),\n            use_adapter = cmd.get(\"use_adapter\"),\n        
)\n\n        # Send WAV bytes as base64 (bytes can't go through mp.Queue directly)\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"audio_done\",\n                \"request_id\": request_id,\n                \"wav_base64\": base64.b64encode(wav_bytes).decode(\"ascii\"),\n                \"sample_rate\": sample_rate,\n                \"ts\": time.time(),\n            },\n        )\n        logger.info(\"Finished audio generation for request_id=%s\", request_id)\n\n    except Exception as exc:\n        logger.error(\"Audio generation error: %s\", exc, exc_info = True)\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"audio_error\",\n                \"request_id\": request_id,\n                \"error\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_generate_audio_input(\n    backend,\n    cmd: dict,\n    resp_queue: Any,\n    cancel_event,\n) -> None:\n    \"\"\"Handle audio input generation (ASR/Whisper) — streams text tokens back.\"\"\"\n    request_id = cmd.get(\"request_id\", \"\")\n\n    try:\n        import numpy as np\n\n        # Decode audio array from list (numpy arrays can't go through mp.Queue)\n        audio_array = np.array(cmd[\"audio_data\"], dtype = np.float32)\n\n        audio_type = cmd.get(\"audio_type\")\n\n        if audio_type == \"whisper\":\n            generator = backend.generate_whisper_response(\n                audio_array = audio_array,\n                cancel_event = cancel_event,\n            )\n        else:\n            generator = backend.generate_audio_input_response(\n                messages = cmd.get(\"messages\", []),\n                system_prompt = cmd.get(\"system_prompt\", \"\"),\n                audio_array = audio_array,\n                temperature = cmd.get(\"temperature\", 0.7),\n                top_p = cmd.get(\"top_p\", 0.9),\n                top_k = cmd.get(\"top_k\", 40),\n                min_p = cmd.get(\"min_p\", 0.0),\n                max_new_tokens = cmd.get(\"max_new_tokens\", 512),\n                repetition_penalty = cmd.get(\"repetition_penalty\", 1.0),\n                cancel_event = cancel_event,\n            )\n\n        logger.info(\"Starting audio input generation for request_id=%s\", request_id)\n\n        for text_chunk in generator:\n            if cancel_event.is_set():\n                logger.info(\n                    \"Audio input generation cancelled for request %s\", request_id\n                )\n                break\n\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"token\",\n                    \"request_id\": request_id,\n                    \"text\": text_chunk,\n                    \"ts\": time.time(),\n                },\n            )\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"gen_done\",\n                \"request_id\": request_id,\n                \"ts\": time.time(),\n            },\n        )\n        logger.info(\"Finished audio input generation for request_id=%s\", request_id)\n\n    except Exception as exc:\n        logger.error(\"Audio input generation error: %s\", exc, exc_info = True)\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"gen_error\",\n                \"request_id\": request_id,\n                \"error\": str(exc),\n                
\"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef _handle_unload(backend, cmd: dict, resp_queue: Any) -> None:\n    \"\"\"Handle an unload command.\"\"\"\n    model_name = cmd.get(\"model_name\", \"\")\n    try:\n        if model_name and model_name in backend.models:\n            backend.unload_model(model_name)\n        elif backend.active_model_name:\n            backend.unload_model(backend.active_model_name)\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"unloaded\",\n                \"model_name\": model_name,\n                \"ts\": time.time(),\n            },\n        )\n    except Exception as exc:\n        logger.error(\"Unload error: %s\", exc)\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"unloaded\",\n                \"model_name\": model_name,\n                \"error\": str(exc),\n                \"ts\": time.time(),\n            },\n        )\n\n\ndef run_inference_process(\n    *,\n    cmd_queue: Any,\n    resp_queue: Any,\n    cancel_event,\n    config: dict,\n) -> None:\n    \"\"\"Subprocess entrypoint. Persistent — runs command loop until shutdown.\n\n    Args:\n        cmd_queue: mp.Queue for receiving commands from parent.\n        resp_queue: mp.Queue for sending responses to parent.\n        cancel_event: mp.Event shared with parent — set by parent to cancel generation.\n        config: Initial configuration dict with model info.\n    \"\"\"\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    os.environ[\"PYTHONWARNINGS\"] = (\n        \"ignore\"  # Suppress warnings at C-level before imports\n    )\n\n    import warnings\n    from loggers.config import LogConfig\n\n    if os.getenv(\"ENVIRONMENT_TYPE\", \"production\") == \"production\":\n        warnings.filterwarnings(\"ignore\")\n\n    LogConfig.setup_logging(\n        service_name = \"unsloth-studio-inference-worker\",\n        env = os.getenv(\"ENVIRONMENT_TYPE\", \"production\"),\n    )\n\n    model_name = config[\"model_name\"]\n\n    # ── 1. Activate correct transformers version BEFORE any ML imports ──\n    try:\n        _activate_transformers_version(model_name)\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to activate transformers version: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 1b. On Windows, check Triton availability (must be before import torch) ──\n    if sys.platform == \"win32\":\n        try:\n            import triton  # noqa: F401\n\n            logger.info(\"Triton available — torch.compile enabled\")\n        except ImportError:\n            os.environ[\"TORCHDYNAMO_DISABLE\"] = \"1\"\n            logger.warning(\n                \"Triton not found on Windows — torch.compile disabled. \"\n                'Install for better performance: pip install \"triton-windows<3.7\"'\n            )\n\n    # ── 2. 
Import ML libraries (fresh in this clean process) ──\n    try:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"status\",\n                \"message\": \"Importing Unsloth...\",\n                \"ts\": time.time(),\n            },\n        )\n\n        backend_path = str(Path(__file__).resolve().parent.parent.parent)\n        if backend_path not in sys.path:\n            sys.path.insert(0, backend_path)\n\n        from core.inference.inference import InferenceBackend\n\n        import transformers\n\n        logger.info(\"Subprocess loaded transformers %s\", transformers.__version__)\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to import ML libraries: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 3. Create inference backend and load initial model ──\n    try:\n        backend = InferenceBackend()\n\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"status\",\n                \"message\": \"Loading model...\",\n                \"ts\": time.time(),\n            },\n        )\n\n        _handle_load(backend, config, resp_queue)\n\n    except Exception as exc:\n        _send_response(\n            resp_queue,\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to initialize inference backend: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            },\n        )\n        return\n\n    # ── 4. Command loop — process commands until shutdown ──\n    # cancel_event is an mp.Event shared with parent — parent can set it\n    # at any time to cancel generation instantly (no queue polling needed).\n    logger.info(\"Inference subprocess ready, entering command loop\")\n\n    while True:\n        try:\n            cmd = cmd_queue.get(timeout = 1.0)\n        except _queue.Empty:\n            continue\n        except (EOFError, OSError):\n            logger.info(\"Command queue closed, shutting down\")\n            return\n\n        if cmd is None:\n            continue\n\n        cmd_type = cmd.get(\"type\", \"\")\n        logger.info(\"Received command: %s\", cmd_type)\n\n        try:\n            if cmd_type == \"generate\":\n                cancel_event.clear()\n                _handle_generate(backend, cmd, resp_queue, cancel_event)\n\n            elif cmd_type == \"load\":\n                # Load a new model (reusing this subprocess)\n                # First unload current model\n                if backend.active_model_name:\n                    backend.unload_model(backend.active_model_name)\n                _handle_load(backend, cmd, resp_queue)\n\n            elif cmd_type == \"generate_audio\":\n                cancel_event.clear()\n                _handle_generate_audio(backend, cmd, resp_queue)\n\n            elif cmd_type == \"generate_audio_input\":\n                cancel_event.clear()\n                _handle_generate_audio_input(backend, cmd, resp_queue, cancel_event)\n\n            elif cmd_type == \"unload\":\n                _handle_unload(backend, cmd, resp_queue)\n\n            elif cmd_type == \"cancel\":\n                # Redundant with mp.Event but handle gracefully\n                cancel_event.set()\n                logger.info(\"Cancel command 
received\")\n\n            elif cmd_type == \"reset\":\n                cancel_event.set()\n                backend.reset_generation_state()\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"reset_ack\",\n                        \"ts\": time.time(),\n                    },\n                )\n\n            elif cmd_type == \"status\":\n                # Return current status\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"status_response\",\n                        \"active_model\": backend.active_model_name,\n                        \"models\": {\n                            name: {\n                                \"is_vision\": info.get(\"is_vision\", False),\n                                \"is_lora\": info.get(\"is_lora\", False),\n                            }\n                            for name, info in backend.models.items()\n                        },\n                        \"loading\": list(backend.loading_models),\n                        \"ts\": time.time(),\n                    },\n                )\n\n            elif cmd_type == \"shutdown\":\n                logger.info(\"Shutdown command received, exiting\")\n                # Unload all models\n                for model_name in list(backend.models.keys()):\n                    try:\n                        backend.unload_model(model_name)\n                    except Exception:\n                        pass\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"shutdown_ack\",\n                        \"ts\": time.time(),\n                    },\n                )\n                return\n\n            else:\n                logger.warning(\"Unknown command type: %s\", cmd_type)\n                _send_response(\n                    resp_queue,\n                    {\n                        \"type\": \"error\",\n                        \"error\": f\"Unknown command type: {cmd_type}\",\n                        \"ts\": time.time(),\n                    },\n                )\n\n        except Exception as exc:\n            logger.error(\n                \"Error handling command '%s': %s\", cmd_type, exc, exc_info = True\n            )\n            _send_response(\n                resp_queue,\n                {\n                    \"type\": \"error\",\n                    \"error\": f\"Command '{cmd_type}' failed: {exc}\",\n                    \"stack\": traceback.format_exc(limit = 20),\n                    \"ts\": time.time(),\n                },\n            )\n"
  },
  {
    "path": "studio/backend/core/training/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTraining submodule - Training backends and trainer classes\n\"\"\"\n\nfrom .training import TrainingBackend, TrainingProgress, get_training_backend\n\n__all__ = [\n    \"TrainingProgress\",\n    \"TrainingBackend\",\n    \"get_training_backend\",\n]\n"
  },
  {
    "path": "studio/backend/core/training/trainer.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nUnsloth Training Backend\nIntegrates Unsloth training capabilities with the FastAPI backend\n\"\"\"\n\nimport os\nimport sys\n\n# Prevent tokenizer parallelism deadlocks when datasets uses multiprocessing fork\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\nimport torch\nfrom utils.hardware import clear_gpu_cache, safe_num_proc\n\ntorch._dynamo.config.recompile_limit = 64\nfrom unsloth import FastLanguageModel, FastVisionModel, is_bfloat16_supported\nfrom unsloth.chat_templates import get_chat_template\n\nimport json\nimport threading\nimport math\nimport structlog\nfrom loggers import get_logger\nimport time\nfrom pathlib import Path\nfrom typing import Optional, Callable\nfrom dataclasses import dataclass\nimport pandas as pd\nfrom datasets import Dataset, load_dataset\n\nfrom utils.models import is_vision_model, detect_audio_type\nfrom utils.datasets import format_and_template_dataset\nfrom utils.datasets import MODEL_TO_TEMPLATE_MAPPER, TEMPLATE_TO_RESPONSES_MAPPER\nfrom utils.paths import (\n    ensure_dir,\n    resolve_dataset_path,\n    resolve_output_dir,\n    resolve_tensorboard_dir,\n)\nfrom trl import SFTTrainer, SFTConfig\n\nlogger = get_logger(__name__)\n\n\ndef _build_report_targets(training_args) -> list[str] | str:\n    report_to: list[str] = []\n    if training_args.get(\"enable_wandb\", False):\n        report_to.append(\"wandb\")\n    if training_args.get(\"enable_tensorboard\", False):\n        report_to.append(\"tensorboard\")\n    return report_to or \"none\"\n\n\n@dataclass\nclass TrainingProgress:\n    \"\"\"Training progress tracking\"\"\"\n\n    epoch: float = 0\n    step: int = 0\n    total_steps: int = 0\n    loss: float = 0.0\n    learning_rate: float = 0.0\n    is_training: bool = False\n    is_completed: bool = False\n    error: Optional[str] = None\n    status_message: str = \"Ready to train\"  # Current stage message\n    elapsed_seconds: Optional[float] = None\n    eta_seconds: Optional[float] = None\n    grad_norm: Optional[float] = None\n    num_tokens: Optional[int] = None\n    eval_loss: Optional[float] = None\n\n\nclass UnslothTrainer:\n    \"\"\"\n    Unsloth Training Backend\n    \"\"\"\n\n    def __init__(self):\n        self.model = None\n        self.tokenizer = None\n        self.trainer = None\n        self.training_thread = None\n        self.training_progress = TrainingProgress()\n        self.progress_callbacks = []\n        self.is_training = False\n        self.should_stop = False\n        self.save_on_stop = True\n        self.load_in_4bit = True  # Track quantization mode for metadata\n\n        # Model state tracking\n        self.is_vlm = False\n        self.is_audio = False\n        self.is_audio_vlm = (\n            False  # Multimodal model (e.g. 
Gemma 3N) trained on audio data\n        )\n        self._audio_type = None  # 'csm', 'whisper', 'snac', 'bicodec', 'dac'\n        self._cuda_audio_used = (\n            False  # Set once after audio CUDA preprocessing; never cleared\n        )\n        self._spark_tts_repo_dir = (\n            None  # Path to downloaded Spark-TTS repo (for BiCodecTokenizer)\n        )\n        self.model_name = None\n\n        # Training metrics tracking\n        self.training_start_time: Optional[float] = None\n        self.batch_size: Optional[int] = None\n        self.max_seq_length: Optional[int] = None\n        self.gradient_accumulation_steps: Optional[int] = None\n\n        # Thread safety\n        self._lock = threading.Lock()\n\n        # Store training context for later transfer\n        self.training_context = {\n            \"base_model_name\": None,\n            \"output_dir\": None,\n            \"is_lora\": True,  # Default to LoRA\n        }\n\n    def pre_detect_and_load_tokenizer(\n        self,\n        model_name: str,\n        max_seq_length: int = 2048,\n        hf_token: Optional[str] = None,\n        is_dataset_image: bool = False,\n        is_dataset_audio: bool = False,\n        trust_remote_code: bool = False,\n    ) -> None:\n        \"\"\"Lightweight detection and tokenizer load — no model weights, no VRAM.\n\n        Sets is_vlm, _audio_type, is_audio_vlm, model_name and loads a\n        lightweight tokenizer for dataset formatting.  Call this before\n        load_and_format_dataset() when you want to process the dataset\n        BEFORE loading the training model (avoids VRAM contention with\n        the LLM-assisted detection helper).\n\n        load_model() may be called afterwards — it will re-detect and load\n        the full model + tokenizer, overwriting the lightweight one set here.\n        \"\"\"\n        self.model_name = model_name\n        self.max_seq_length = max_seq_length\n        self.trust_remote_code = trust_remote_code\n\n        if hf_token:\n            os.environ[\"HF_TOKEN\"] = hf_token\n\n        # --- Detect audio type (reads config.json only, no VRAM) ---\n        self._audio_type = detect_audio_type(model_name, hf_token)\n        if self._audio_type == \"audio_vlm\":\n            self.is_audio = False\n            self.is_audio_vlm = is_dataset_audio\n            self._audio_type = None\n        else:\n            self.is_audio = self._audio_type is not None\n            self.is_audio_vlm = False\n\n        if not self.is_audio and not self.is_audio_vlm:\n            self._cuda_audio_used = False\n\n        # --- Detect VLM ---\n        vision = is_vision_model(model_name) if not self.is_audio else False\n        self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image\n\n        logger.info(\n            \"pre_detect: audio_type=%s, is_audio=%s, is_audio_vlm=%s, is_vlm=%s\",\n            self._audio_type,\n            self.is_audio,\n            self.is_audio_vlm,\n            self.is_vlm,\n        )\n\n        # --- Load lightweight tokenizer/processor (CPU only, no VRAM) ---\n        # Whisper needs AutoProcessor (has feature_extractor + tokenizer).\n        # All others work with AutoTokenizer (CSM loads its own processor inline).\n        if self._audio_type == \"whisper\":\n            from transformers import AutoProcessor\n\n            self.tokenizer = AutoProcessor.from_pretrained(\n                model_name,\n                trust_remote_code = trust_remote_code,\n                token = hf_token,\n            )\n        
else:\n            from transformers import AutoTokenizer\n\n            self.tokenizer = AutoTokenizer.from_pretrained(\n                model_name,\n                trust_remote_code = trust_remote_code,\n                token = hf_token,\n            )\n\n        logger.info(\"Pre-loaded tokenizer for %s\", model_name)\n\n    def add_progress_callback(self, callback: Callable[[TrainingProgress], None]):\n        \"\"\"Add callback for training progress updates\"\"\"\n        self.progress_callbacks.append(callback)\n\n    def _update_progress(self, **kwargs):\n        \"\"\"Update training progress and notify callbacks\"\"\"\n        with self._lock:\n            for key, value in kwargs.items():\n                if hasattr(self.training_progress, key):\n                    setattr(self.training_progress, key, value)\n\n            # Notify all callbacks\n            for callback in self.progress_callbacks:\n                try:\n                    callback(self.training_progress)\n                except Exception as e:\n                    logger.error(f\"Error in progress callback: {e}\")\n\n    def _create_progress_callback(self):\n        \"\"\"Create a TrainerCallback for progress tracking. Reused by all training branches.\"\"\"\n        from transformers import TrainerCallback\n\n        trainer_ref = self\n\n        class _ProgressCallback(TrainerCallback):\n            def on_log(self, args, state, control, logs = None, **kwargs):\n                if not logs:\n                    return\n                loss_value = logs.get(\"loss\", logs.get(\"train_loss\", 0.0))\n                current_step = state.global_step\n                grad_norm = logs.get(\"grad_norm\", None)\n\n                elapsed_seconds = None\n                if trainer_ref.training_start_time is not None:\n                    elapsed_seconds = time.time() - trainer_ref.training_start_time\n\n                eta_seconds = None\n                if elapsed_seconds is not None and current_step > 0:\n                    total_steps = trainer_ref.training_progress.total_steps\n                    if total_steps > 0:\n                        steps_remaining = total_steps - current_step\n                        if steps_remaining > 0:\n                            eta_seconds = (\n                                elapsed_seconds / current_step\n                            ) * steps_remaining\n\n                num_tokens = getattr(state, \"num_input_tokens_seen\", None)\n\n                trainer_ref._update_progress(\n                    step = current_step,\n                    epoch = round(state.epoch, 2) if state.epoch else 0,\n                    loss = loss_value,\n                    learning_rate = logs.get(\"learning_rate\", 0.0),\n                    elapsed_seconds = elapsed_seconds,\n                    eta_seconds = eta_seconds,\n                    grad_norm = grad_norm,\n                    num_tokens = num_tokens,\n                    eval_loss = logs.get(\"eval_loss\", None),\n                    status_message = \"\",\n                )\n\n            def on_epoch_end(self, args, state, control, **kwargs):\n                trainer_ref._update_progress(epoch = state.epoch, step = state.global_step)\n\n            def on_step_end(self, args, state, control, **kwargs):\n                if trainer_ref.should_stop:\n                    logger.info(f\"Stop detected at step {state.global_step}\\n\")\n                    control.should_training_stop = True\n                    return control\n\n        
return _ProgressCallback()\n\n    def _calculate_total_steps(\n        self, num_samples, batch_size, grad_accum, num_epochs, max_steps\n    ):\n        \"\"\"Calculate total training steps from dataset size and training params.\"\"\"\n        if max_steps and max_steps > 0:\n            return max_steps\n        len_dataloader = math.ceil(num_samples / batch_size)\n        steps_per_epoch = max(\n            len_dataloader // grad_accum + int(len_dataloader % grad_accum > 0), 1\n        )\n        return steps_per_epoch * num_epochs\n\n    def _build_audio_training_args(self, training_args, output_dir, *, extra_args = None):\n        \"\"\"Build training args dict for audio branches.\n\n        Constructs the common config (batch size, lr, warmup, fp16/bf16, etc.)\n        and applies per-branch overrides via extra_args.\n        \"\"\"\n        batch_size = training_args.get(\"batch_size\", 2)\n        gradient_accumulation_steps = training_args.get(\n            \"gradient_accumulation_steps\", 4\n        )\n        warmup_steps_val = training_args.get(\"warmup_steps\", 5)\n        max_steps_val = training_args.get(\"max_steps\", 0)\n        learning_rate = training_args.get(\"learning_rate\", 2e-4)\n        weight_decay = training_args.get(\"weight_decay\", 0.001)\n        lr_scheduler_type = training_args.get(\"lr_scheduler_type\", \"linear\")\n        random_seed = training_args.get(\"random_seed\", 3407)\n        optim_value = training_args.get(\"optim\", \"adamw_8bit\")\n\n        config = {\n            \"per_device_train_batch_size\": batch_size,\n            \"gradient_accumulation_steps\": gradient_accumulation_steps,\n            \"warmup_steps\": warmup_steps_val if warmup_steps_val is not None else 5,\n            \"learning_rate\": learning_rate,\n            \"fp16\": not is_bfloat16_supported(),\n            \"bf16\": is_bfloat16_supported(),\n            \"logging_steps\": 1,\n            \"optim\": optim_value,\n            \"weight_decay\": weight_decay,\n            \"lr_scheduler_type\": lr_scheduler_type,\n            \"seed\": random_seed,\n            \"output_dir\": output_dir,\n            \"report_to\": _build_report_targets(training_args),\n        }\n\n        if training_args.get(\"enable_tensorboard\", False):\n            config[\"logging_dir\"] = str(\n                resolve_tensorboard_dir(training_args.get(\"tensorboard_dir\"))\n            )\n\n        # max_steps vs epochs\n        if max_steps_val and max_steps_val > 0:\n            config[\"max_steps\"] = max_steps_val\n        else:\n            config[\"num_train_epochs\"] = training_args.get(\"num_epochs\", 3)\n\n        # save_steps\n        save_steps_val = training_args.get(\"save_steps\", 0)\n        if save_steps_val and save_steps_val > 0:\n            config[\"save_steps\"] = save_steps_val\n            config[\"save_strategy\"] = \"steps\"\n\n        # Apply per-branch overrides\n        if extra_args:\n            config.update(extra_args)\n\n        return config\n\n    def _finalize_training(self, output_dir, label = \"\"):\n        \"\"\"Save model after training and update progress. Used by all training branches.\"\"\"\n        if self.should_stop and self.save_on_stop:\n            self.trainer.save_model()\n            self.tokenizer.save_pretrained(output_dir)\n            self._patch_adapter_config(output_dir)\n            msg = f\"{label} training stopped\" if label else \"Training stopped\"\n            logger.info(f\"\\n{msg}. 
Model saved to {output_dir}\\n\")\n            self._update_progress(\n                is_training = False,\n                status_message = f\"Training stopped. Model saved to {output_dir}\",\n            )\n        elif self.should_stop:\n            msg = f\"{label} training cancelled\" if label else \"Training cancelled\"\n            logger.info(f\"\\n{msg}.\\n\")\n            self._update_progress(\n                is_training = False, status_message = \"Training cancelled.\"\n            )\n        else:\n            self.trainer.save_model()\n            self.tokenizer.save_pretrained(output_dir)\n            self._patch_adapter_config(output_dir)\n            msg = f\"{label} training completed\" if label else \"Training completed\"\n            logger.info(f\"\\n{msg}! Model saved to {output_dir}\\n\")\n            self._update_progress(\n                is_training = False,\n                is_completed = True,\n                status_message = f\"Training completed! Model saved to {output_dir}\",\n            )\n\n    def _cleanup_audio_artifacts(self):\n        \"\"\"Remove sys.path entries and sys.modules from previous audio preprocessing.\n\n        After audio training, cloned repo dirs (OuteTTS, Spark-TTS) remain on\n        sys.path and heavy audio modules (snac, whisper, sparktts, outetts) stay\n        in sys.modules. When the next training run calls dataset.map(num_proc=N),\n        forked child processes inherit this stale state and deadlock.\n        \"\"\"\n        import sys as _sys\n\n        # Remove cloned audio repo paths from sys.path\n        base_dir = os.path.dirname(os.path.abspath(__file__))\n        audio_paths = [\n            os.path.join(base_dir, \"inference\", \"OuteTTS\"),  # DAC/OuteTTS\n        ]\n        # Spark-TTS path is relative to the downloaded repo\n        if self._spark_tts_repo_dir:\n            spark_code_dir = os.path.join(\n                os.path.dirname(self._spark_tts_repo_dir), \"Spark-TTS\"\n            )\n            audio_paths.append(spark_code_dir)\n\n        removed_paths = []\n        for path in audio_paths:\n            if path in _sys.path:\n                _sys.path.remove(path)\n                removed_paths.append(path)\n\n        # Remove stale audio modules from sys.modules\n        prefixes = (\"snac\", \"whisper\", \"sparktts\", \"outetts\")\n        removed_modules = [key for key in _sys.modules if key.startswith(prefixes)]\n        for key in removed_modules:\n            del _sys.modules[key]\n\n        if removed_paths or removed_modules:\n            logger.info(\n                f\"Cleaned up audio artifacts: {len(removed_paths)} paths, \"\n                f\"{len(removed_modules)} modules\\n\"\n            )\n\n    def _resolve_audio_columns(self, dataset, custom_format_mapping: dict = None):\n        \"\"\"Resolve audio, text, and speaker columns from user mapping or hardcoded fallback.\n\n        Returns:\n            dict with keys: audio_col, text_col, speaker_col (speaker_col may be None)\n        \"\"\"\n        cols = dataset.column_names\n\n        if custom_format_mapping:\n            audio_col = None\n            text_col = None\n            speaker_col = None\n            for col, role in custom_format_mapping.items():\n                if role == \"audio\":\n                    audio_col = col\n                elif role == \"text\":\n                    text_col = col\n                elif role == \"speaker_id\":\n                    speaker_col = col\n            # Use mapping if both 
required columns exist in the dataset\n            if audio_col and audio_col in cols and text_col and text_col in cols:\n                return {\n                    \"audio_col\": audio_col,\n                    \"text_col\": text_col,\n                    \"speaker_col\": speaker_col,\n                }\n\n        # Hardcoded fallback (existing behavior)\n        audio_col = next((c for c in cols if c.lower() in (\"audio\", \"speech\")), None)\n        text_col = next(\n            (\n                c\n                for c in cols\n                if c.lower() in (\"text\", \"sentence\", \"transcript\", \"transcription\")\n            ),\n            None,\n        )\n\n        speaker_col = None\n        if \"source\" in cols:\n            speaker_col = \"source\"\n        elif \"speaker_id\" in cols:\n            speaker_col = \"speaker_id\"\n\n        return {\n            \"audio_col\": audio_col,\n            \"text_col\": text_col,\n            \"speaker_col\": speaker_col,\n        }\n\n    def load_model(\n        self,\n        model_name: str,\n        max_seq_length: int = 2048,\n        load_in_4bit: bool = True,\n        hf_token: Optional[str] = None,\n        is_dataset_image: bool = False,\n        is_dataset_audio: bool = False,\n        trust_remote_code: bool = False,\n        full_finetuning: bool = False,\n    ) -> bool:\n        \"\"\"Load model for training (supports both text and vision models)\"\"\"\n        self.load_in_4bit = load_in_4bit  # Store for training_meta.json\n        self.trust_remote_code = (\n            trust_remote_code  # For AutoProcessor etc. used during training\n        )\n        try:\n            if self.model is not None:\n                del self.model\n            if self.tokenizer is not None:\n                del self.tokenizer\n\n            if self.trainer is not None:\n                del self.trainer\n\n            logger.info(\"\\nClearing GPU memory before training...\")\n            clear_gpu_cache()\n\n            # Clean up sys.path and sys.modules from previous audio preprocessing\n            # to prevent deadlocks when forking worker processes in dataset.map()\n            self._cleanup_audio_artifacts()\n\n            # Reload Unsloth-patched transformers modeling modules before clearing\n            # the compiled cache. unsloth_compile_transformers() sets __UNSLOTH_PATCHED__\n            # on each modeling module and replaces methods with exec'd code.\n            # clear_unsloth_compiled_cache() deletes the disk cache, but the flag\n            # prevents re-compilation — leaving missing cache files. 
Reloading\n            # restores original class definitions so Unsloth can re-compile cleanly.\n            import sys as _sys\n            import importlib\n\n            for _key, _mod in list(_sys.modules.items()):\n                if \"transformers.models.\" in _key and \".modeling_\" in _key:\n                    if hasattr(_mod, \"__UNSLOTH_PATCHED__\"):\n                        try:\n                            importlib.reload(_mod)\n                        except Exception:\n                            pass  # Non-critical — Unsloth will handle stale modules\n\n            # Remove stale compiled cache so the new model gets a fresh one\n            from utils.cache_cleanup import clear_unsloth_compiled_cache\n\n            clear_unsloth_compiled_cache()\n            # Detect audio model type dynamically (config.json + tokenizer)\n            self._audio_type = detect_audio_type(model_name, hf_token)\n            # audio_vlm is detected as an audio_type now, handle it separately\n            if self._audio_type == \"audio_vlm\":\n                self.is_audio = False\n                self.is_audio_vlm = (\n                    is_dataset_audio  # Only use audio VLM path if dataset has audio\n                )\n                self._audio_type = None\n            else:\n                self.is_audio = self._audio_type is not None\n                self.is_audio_vlm = False\n\n            if not self.is_audio and not self.is_audio_vlm:\n                self._cuda_audio_used = False\n\n            # VLM: vision model with image dataset (mutually exclusive with audio paths)\n            vision = is_vision_model(model_name) if not self.is_audio else False\n            self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image\n            self.model_name = model_name\n            self.max_seq_length = max_seq_length\n\n            logger.info(\n                f\"Audio type: {self._audio_type}, is_audio: {self.is_audio}, is_audio_vlm: {self.is_audio_vlm}\"\n            )\n            logger.info(\n                f\"Dataset has images: {is_dataset_image}, audio: {is_dataset_audio}\"\n            )\n            logger.info(f\"Using VLM path: {self.is_vlm}\")\n\n            # Reset training state for new run\n            self._update_progress(\n                is_training = True,\n                is_completed = False,\n                error = None,\n                step = 0,\n                loss = 0.0,\n                epoch = 0,\n            )\n\n            # Update UI immediately with loading message\n            model_display = (\n                model_name.split(\"/\")[-1] if \"/\" in model_name else model_name\n            )\n            model_type_label = (\n                \"audio\" if self.is_audio else (\"vision\" if self.is_vlm else \"text\")\n            )\n            self._update_progress(\n                status_message = f\"Loading {model_type_label} model... 
{model_display}\"\n            )\n\n            logger.info(f\"\\nLoading {model_type_label} model: {model_name}\")\n\n            # Set HF token if provided\n            if hf_token:\n                os.environ[\"HF_TOKEN\"] = hf_token\n\n            # Proactive gated-model check: verify access BEFORE from_pretrained.\n            # Catches ALL gated/private models (text, vision, audio) globally.\n            if \"/\" in model_name:  # Only check HF repo IDs, not local paths\n                try:\n                    from huggingface_hub import model_info as hf_model_info\n\n                    info = hf_model_info(model_name, token = hf_token or None)\n                    # model_info succeeds even for gated repos (metadata is public),\n                    # but info.gated tells us if files require acceptance/token.\n                    if info.gated and not hf_token:\n                        friendly = (\n                            f\"Access denied for '{model_name}'. This model is gated. \"\n                            f\"Please add a Hugging Face token with access and try again.\"\n                        )\n                        logger.error(\n                            f\"Model '{model_name}' is gated (gated={info.gated}) and no HF token provided\"\n                        )\n                        self._update_progress(error = friendly, is_training = False)\n                        return False\n                except Exception as gate_err:\n                    from huggingface_hub.utils import (\n                        GatedRepoError,\n                        RepositoryNotFoundError,\n                    )\n\n                    if isinstance(gate_err, (GatedRepoError, RepositoryNotFoundError)):\n                        friendly = (\n                            f\"Access denied for '{model_name}'. This model is gated or private. 
\"\n                            f\"Please add a Hugging Face token with access and try again.\"\n                        )\n                        logger.error(f\"Gated model check failed: {gate_err}\")\n                        self._update_progress(error = friendly, is_training = False)\n                        return False\n\n            # Branch based on model type\n            if self._audio_type == \"csm\":\n                # CSM: FastModel + auto_model=CsmForConditionalGeneration + load_in_4bit=False\n                from unsloth import FastModel\n                from transformers import CsmForConditionalGeneration\n\n                self.model, self.tokenizer = FastModel.from_pretrained(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    auto_model = CsmForConditionalGeneration,\n                    load_in_4bit = False,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded CSM audio model\")\n\n            elif self._audio_type == \"whisper\":\n                # Whisper: FastModel + auto_model=WhisperForConditionalGeneration + load_in_4bit=False\n                from unsloth import FastModel\n                from transformers import WhisperForConditionalGeneration\n\n                self.model, self.tokenizer = FastModel.from_pretrained(\n                    model_name = model_name,\n                    dtype = None,\n                    load_in_4bit = False,\n                    full_finetuning = full_finetuning,\n                    auto_model = WhisperForConditionalGeneration,\n                    whisper_language = \"English\",\n                    whisper_task = \"transcribe\",\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                # Configure generation settings (notebook lines 100-105)\n                self.model.generation_config.language = \"<|en|>\"\n                self.model.generation_config.task = \"transcribe\"\n                self.model.config.suppress_tokens = []\n                self.model.generation_config.forced_decoder_ids = None\n                logger.info(\"Loaded Whisper audio model (FastModel)\")\n\n            elif self._audio_type == \"snac\":\n                # Orpheus: language model with audio codec tokens\n                self.model, self.tokenizer = FastLanguageModel.from_pretrained(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    load_in_4bit = load_in_4bit,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\n                    f\"Loaded {self._audio_type} audio model (FastLanguageModel)\"\n                )\n\n            elif self._audio_type == \"bicodec\":\n                # Spark-TTS: download full repo (contains sparktts package + BiCodec weights),\n                # then load only the LLM subfolder with FastModel.\n                # model_name may be:\n                #   \"Spark-TTS-0.5B/LLM\"       (local-style, from YAML mapping)\n                #   \"unsloth/Spark-TTS-0.5B\"    (HF repo ID)\n                from unsloth import FastModel\n    
            from huggingface_hub import snapshot_download\n\n                if model_name.endswith(\"/LLM\"):\n                    # \"Spark-TTS-0.5B/LLM\" → parent=\"Spark-TTS-0.5B\"\n                    local_dir = model_name.rsplit(\"/\", 1)[0]\n                    hf_repo = f\"unsloth/{local_dir}\"\n                    llm_path = model_name\n                else:\n                    # \"unsloth/Spark-TTS-0.5B\" → local_dir=\"Spark-TTS-0.5B\"\n                    hf_repo = model_name\n                    local_dir = model_name.split(\"/\")[-1]\n                    llm_path = f\"{local_dir}/LLM\"\n\n                repo_path = snapshot_download(hf_repo, local_dir = local_dir)\n                self._spark_tts_repo_dir = os.path.abspath(\n                    repo_path\n                )  # Absolute path for sys.path\n                llm_path = os.path.join(self._spark_tts_repo_dir, \"LLM\")\n\n                self.model, self.tokenizer = FastModel.from_pretrained(\n                    model_name = llm_path,\n                    max_seq_length = max_seq_length,\n                    dtype = torch.float32,  # Spark-TTS requires float32\n                    load_in_4bit = False,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded Spark-TTS (bicodec) model\")\n\n            elif self._audio_type == \"dac\":\n                # OuteTTS: uses FastModel (not FastLanguageModel) with load_in_4bit=False\n                from unsloth import FastModel\n\n                self.model, self.tokenizer = FastModel.from_pretrained(\n                    model_name,\n                    max_seq_length = max_seq_length,\n                    load_in_4bit = False,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded OuteTTS (dac) model (FastModel)\")\n\n            elif self.is_audio_vlm:\n                # Audio VLM: multimodal model trained on audio (e.g. 
Gemma 3N)\n                # Uses FastModel (general loader) — returns (model, processor)\n                from unsloth import FastModel\n\n                self.model, self.tokenizer = FastModel.from_pretrained(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    dtype = None,\n                    load_in_4bit = load_in_4bit,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded audio VLM model (FastModel)\")\n\n            elif self.is_vlm:\n                # Load vision model - returns (model, tokenizer)\n                self.model, self.tokenizer = FastVisionModel.from_pretrained(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    dtype = None,  # Auto-detect\n                    load_in_4bit = load_in_4bit,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded vision model\")\n\n                # Diagnostic: check if FastVisionModel returned a real Processor or a raw tokenizer\n                from transformers import ProcessorMixin\n\n                tok = self.tokenizer\n                has_image_proc = isinstance(tok, ProcessorMixin) or hasattr(\n                    tok, \"image_processor\"\n                )\n                logger.info(\n                    f\"\\n[VLM Diagnostic] FastVisionModel returned: {type(tok).__name__}\"\n                )\n                logger.info(\n                    f\"[VLM Diagnostic] Is ProcessorMixin: {isinstance(tok, ProcessorMixin)}\"\n                )\n                logger.info(\n                    f\"[VLM Diagnostic] Has image_processor: {hasattr(tok, 'image_processor')}\"\n                )\n                logger.info(\n                    f\"[VLM Diagnostic] Usable as vision processor: {has_image_proc}\\n\"\n                )\n            else:\n                # Load text model - returns (model, tokenizer)\n                self.model, self.tokenizer = FastLanguageModel.from_pretrained(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    dtype = None,  # Auto-detect\n                    load_in_4bit = load_in_4bit,\n                    full_finetuning = full_finetuning,\n                    token = hf_token,\n                    trust_remote_code = trust_remote_code,\n                )\n                logger.info(\"Loaded text model\")\n\n            if self.should_stop:\n                return False\n\n            if full_finetuning:\n                # Enable training mode for full fine-tuning\n                # This ensures all model parameters are trainable; otherwise, they might be frozen.\n                self.model.for_training()\n\n            self._update_progress(status_message = \"Model loaded successfully\")\n            logger.info(\"Model loaded successfully\")\n            return True\n\n        except OSError as e:\n            if \"could not get source code\" in str(e) and not getattr(\n                self, \"_source_code_retried\", False\n            ):\n                # Unsloth's patching can leave stale state that makes\n                # inspect.getsource() fail when switching model families\n        
        # (e.g. gemma3 → gemma3n). The load always succeeds on a\n                # second attempt because the failed first call's partial\n                # imports clean up the stale state as a side effect.\n                self._source_code_retried = True\n                logger.info(f\"\\n'could not get source code' — retrying once...\\n\")\n                return self.load_model(\n                    model_name = model_name,\n                    max_seq_length = max_seq_length,\n                    load_in_4bit = load_in_4bit,\n                    hf_token = hf_token,\n                    is_dataset_image = is_dataset_image,\n                    is_dataset_audio = is_dataset_audio,\n                    trust_remote_code = trust_remote_code,\n                    full_finetuning = full_finetuning,\n                )\n            error_msg = str(e)\n            error_lower = error_msg.lower()\n            if any(\n                k in error_lower\n                for k in (\n                    \"gated repo\",\n                    \"access to it at\",\n                    \"401\",\n                    \"403\",\n                    \"unauthorized\",\n                    \"forbidden\",\n                )\n            ):\n                error_msg = (\n                    f\"Access denied for '{model_name}'. This model is gated or private. \"\n                    f\"Please add a Hugging Face token with access and try again.\"\n                )\n            logger.error(f\"Error loading model: {e}\")\n            self._update_progress(error = error_msg, is_training = False)\n            return False\n        except Exception as e:\n            error_msg = str(e)\n            # Catch gated/auth errors and surface a friendly message\n            error_lower = error_msg.lower()\n            if any(\n                k in error_lower\n                for k in (\n                    \"gated repo\",\n                    \"access to it at\",\n                    \"401\",\n                    \"403\",\n                    \"unauthorized\",\n                    \"forbidden\",\n                )\n            ):\n                error_msg = (\n                    f\"Access denied for '{model_name}'. This model is gated or private. \"\n                    f\"Please add a Hugging Face token with access and try again.\"\n                )\n            logger.error(f\"Error loading model: {e}\")\n            self._update_progress(error = error_msg, is_training = False)\n            return False\n        finally:\n            self._source_code_retried = False\n\n    def prepare_model_for_training(\n        self,\n        use_lora: bool = True,\n        # Vision-specific LoRA parameters (only used if is_vlm=True)\n        finetune_vision_layers: bool = True,\n        finetune_language_layers: bool = True,\n        finetune_attention_modules: bool = True,\n        finetune_mlp_modules: bool = True,\n        # Standard LoRA parameters\n        target_modules: list = None,\n        lora_r: int = 16,\n        lora_alpha: int = 16,\n        lora_dropout: float = 0.0,\n        use_gradient_checkpointing: str = \"unsloth\",\n        use_rslora: bool = False,\n        use_loftq: bool = False,\n    ) -> bool:\n        \"\"\"\n        Prepare model for training (with optional LoRA).\n        \"\"\"\n        try:\n            if self.model is None:\n                raise ValueError(\"Model not loaded. 
Call load_model() first.\")\n\n            # Full finetuning mode - skip PEFT entirely\n            if not use_lora:\n                self._update_progress(\n                    status_message = \"Full finetuning mode - no LoRA adapters\"\n                )\n                logger.info(\"Full finetuning mode - training all parameters\\n\")\n                return True\n\n            # LoRA/QLoRA mode - apply PEFT\n            # \"all-linear\" is a PEFT keyword that targets every linear layer\n            if isinstance(target_modules, list) and \"all-linear\" in target_modules:\n                if len(target_modules) == 1:\n                    target_modules = \"all-linear\"\n                else:\n                    target_modules = [m for m in target_modules if m != \"all-linear\"]\n            elif target_modules is None or (\n                isinstance(target_modules, list) and len(target_modules) == 0\n            ):\n                target_modules = [\n                    \"q_proj\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"o_proj\",\n                    \"gate_proj\",\n                    \"up_proj\",\n                    \"down_proj\",\n                ]\n\n            # Validate and normalize gradient_checkpointing\n            # Must be one of: True, False, or \"unsloth\"\n            if isinstance(use_gradient_checkpointing, str):\n                use_gradient_checkpointing = use_gradient_checkpointing.strip().lower()\n                if (\n                    use_gradient_checkpointing == \"\"\n                    or use_gradient_checkpointing == \"unsloth\"\n                ):\n                    use_gradient_checkpointing = \"unsloth\"\n                elif use_gradient_checkpointing in (\"true\", \"1\", \"yes\"):\n                    use_gradient_checkpointing = True\n                elif use_gradient_checkpointing in (\"false\", \"0\", \"no\"):\n                    use_gradient_checkpointing = False\n                else:\n                    # Invalid value, default to \"unsloth\"\n                    logger.warning(\n                        f\"Invalid gradient_checkpointing value: {use_gradient_checkpointing}, defaulting to 'unsloth'\"\n                    )\n                    use_gradient_checkpointing = \"unsloth\"\n            elif use_gradient_checkpointing not in (True, False, \"unsloth\"):\n                # Invalid type or value, default to \"unsloth\"\n                logger.warning(\n                    f\"Invalid gradient_checkpointing type/value: {use_gradient_checkpointing}, defaulting to 'unsloth'\"\n                )\n                use_gradient_checkpointing = \"unsloth\"\n\n            # Verify model is loaded\n            if self.model is None:\n                error_msg = \"Model is None - model was not loaded properly\"\n                logger.error(error_msg)\n                self._update_progress(error = error_msg)\n                return False\n\n            # Check if model has the expected attributes\n            if not hasattr(self.model, \"config\"):\n                error_msg = \"Model does not have config attribute - model may not be loaded correctly\"\n                logger.error(error_msg)\n                self._update_progress(error = error_msg)\n                return False\n\n            logger.info(\n                f\"Configuring LoRA adapters (r={lora_r}, alpha={lora_alpha})...\\n\"\n            )\n            logger.info(\n                f\"Gradient checkpointing: 
{use_gradient_checkpointing} (type: {type(use_gradient_checkpointing).__name__})\\n\"\n            )\n\n            # Branch based on model type: audio, audio_vlm, vision, or text\n            if self._audio_type in (\"csm\", \"bicodec\", \"dac\") or self.is_audio_vlm:\n                # Models using FastModel.get_peft_model (codec audio + audio VLM)\n                from unsloth import FastModel\n\n                label = self._audio_type or \"audio_vlm\"\n                logger.info(f\"{label} LoRA configuration:\")\n                logger.info(f\"  - Target modules: {target_modules}\")\n                if self.is_audio_vlm:\n                    logger.info(f\"  - Finetune vision layers: {finetune_vision_layers}\")\n                    logger.info(\n                        f\"  - Finetune language layers: {finetune_language_layers}\"\n                    )\n                    logger.info(\n                        f\"  - Finetune attention modules: {finetune_attention_modules}\"\n                    )\n                    logger.info(f\"  - Finetune MLP modules: {finetune_mlp_modules}\")\n                logger.info(\"\")\n\n                peft_kwargs = dict(\n                    r = lora_r,\n                    target_modules = target_modules,\n                    lora_alpha = lora_alpha,\n                    lora_dropout = lora_dropout,\n                    bias = \"none\",\n                    use_gradient_checkpointing = use_gradient_checkpointing,\n                    random_state = 3407,\n                    use_rslora = use_rslora,\n                    loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                    if use_loftq\n                    else None,\n                )\n                # Audio VLM models support VLM-style layer selection\n                if self.is_audio_vlm:\n                    peft_kwargs.update(\n                        finetune_vision_layers = finetune_vision_layers,\n                        finetune_language_layers = finetune_language_layers,\n                        finetune_attention_modules = finetune_attention_modules,\n                        finetune_mlp_modules = finetune_mlp_modules,\n                    )\n\n                self.model = FastModel.get_peft_model(self.model, **peft_kwargs)\n\n            elif self._audio_type == \"whisper\":\n                # Phase 2: Whisper uses FastModel.get_peft_model with task_type=None\n                from unsloth import FastModel\n\n                logger.info(f\"Audio model (whisper) LoRA configuration:\")\n                logger.info(f\"  - Target modules: {target_modules}\\n\")\n\n                self.model = FastModel.get_peft_model(\n                    self.model,\n                    r = lora_r,\n                    target_modules = target_modules,\n                    lora_alpha = lora_alpha,\n                    lora_dropout = lora_dropout,\n                    bias = \"none\",\n                    use_gradient_checkpointing = use_gradient_checkpointing,\n                    random_state = 3407,\n                    use_rslora = use_rslora,\n                    loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                    if use_loftq\n                    else None,\n                    task_type = None,\n                )\n\n            elif self._audio_type == \"snac\":\n                # Orpheus uses FastLanguageModel.get_peft_model\n                logger.info(f\"Audio model ({self._audio_type}) LoRA configuration:\")\n                logger.info(f\"  - Target 
modules: {target_modules}\\n\")\n\n                self.model = FastLanguageModel.get_peft_model(\n                    self.model,\n                    r = lora_r,\n                    target_modules = target_modules,\n                    lora_alpha = lora_alpha,\n                    lora_dropout = lora_dropout,\n                    bias = \"none\",\n                    use_gradient_checkpointing = use_gradient_checkpointing,\n                    random_state = 3407,\n                    use_rslora = use_rslora,\n                    loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                    if use_loftq\n                    else None,\n                )\n\n            elif self.is_vlm:\n                # Vision model LoRA\n                logger.info(f\"Vision model LoRA configuration:\")\n                logger.info(f\"  - Finetune vision layers: {finetune_vision_layers}\")\n                logger.info(f\"  - Finetune language layers: {finetune_language_layers}\")\n                logger.info(\n                    f\"  - Finetune attention modules: {finetune_attention_modules}\"\n                )\n                logger.info(f\"  - Finetune MLP modules: {finetune_mlp_modules}\\n\")\n\n                self.model = FastVisionModel.get_peft_model(\n                    self.model,\n                    finetune_vision_layers = finetune_vision_layers,\n                    finetune_language_layers = finetune_language_layers,\n                    finetune_attention_modules = finetune_attention_modules,\n                    finetune_mlp_modules = finetune_mlp_modules,\n                    r = lora_r,\n                    target_modules = target_modules,\n                    lora_alpha = lora_alpha,\n                    lora_dropout = lora_dropout,\n                    bias = \"none\",\n                    use_gradient_checkpointing = use_gradient_checkpointing,\n                    random_state = 3407,\n                    use_rslora = use_rslora,\n                    loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                    if use_loftq\n                    else None,\n                )\n            else:\n                # Text model LoRA\n                logger.info(f\"Text model LoRA configuration:\")\n                logger.info(f\"  - Target modules: {target_modules}\\n\")\n\n                self.model = FastLanguageModel.get_peft_model(\n                    self.model,\n                    r = lora_r,\n                    target_modules = target_modules,\n                    lora_alpha = lora_alpha,\n                    lora_dropout = lora_dropout,\n                    bias = \"none\",\n                    use_gradient_checkpointing = use_gradient_checkpointing,\n                    random_state = 3407,\n                    use_rslora = use_rslora,\n                    loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                    if use_loftq\n                    else None,\n                )\n\n            # Check if stopped during LoRA preparation\n            if self.should_stop:\n                logger.info(\"Stopped during LoRA configuration\\n\")\n                return False\n\n            self._update_progress(status_message = \"LoRA adapters configured\")\n            logger.info(\"LoRA adapters configured successfully\\n\")\n            return True\n\n        except Exception as e:\n            import traceback\n            import sys\n\n            error_details = (\n                f\"{type(e).__name__}: {str(e)}\"\n                
if str(e)\n                else f\"{type(e).__name__} (no message)\"\n            )\n            full_traceback = traceback.format_exc()\n            logger.error(f\"Error preparing model: {error_details}\")\n            logger.error(f\"Full traceback:\\n{full_traceback}\")\n            logger.info(f\"\\n[ERROR] Error preparing model: {error_details}\")\n            logger.info(f\"[ERROR] Full traceback:\\n{full_traceback}\")\n            self._update_progress(error = error_details)\n            return False\n\n    def _apply_csm_forward_fix(self):\n        \"\"\"Monkey-patch CsmForConditionalGeneration.forward to fix depth decoder kwargs.\n\n        The original transformers forward passes raw **kwargs (num_items_in_batch,\n        causal_mask, etc.) from the Trainer/PEFT through to the depth decoder,\n        causing depth_decoder_loss=None and 'Tensor + NoneType' crash.\n\n        We patch at both instance AND class level for maximum reliability,\n        and strip non-TransformersKwargs params that Unsloth/PEFT inject.\n        \"\"\"\n        import types\n        import torch\n        import torch.nn as nn\n        from transformers.models.csm.modeling_csm import (\n            CsmForConditionalGeneration,\n            CsmOutputWithPast,\n        )\n\n        base_csm = self.model.base_model.model  # CsmForConditionalGeneration\n\n        # Save original forward (the @can_return_tuple wrapped version)\n        _original_forward = CsmForConditionalGeneration.forward\n\n        # Keys that the depth decoder and its sub-layers actually understand\n        _TRANSFORMERS_KWARGS = {\n            \"num_items_in_batch\",\n            \"output_hidden_states\",\n            \"output_attentions\",\n            \"output_router_logits\",\n            \"cu_seq_lens_q\",\n            \"cu_seq_lens_k\",\n            \"max_length_q\",\n            \"max_length_k\",\n        }\n\n        def _fixed_csm_forward(\n            self,\n            input_ids = None,\n            input_values = None,\n            attention_mask = None,\n            input_values_cutoffs = None,\n            position_ids = None,\n            past_key_values = None,\n            inputs_embeds = None,\n            labels = None,\n            use_cache = None,\n            cache_position = None,\n            logits_to_keep = 0,\n            **kwargs,\n        ):\n            # Strip non-standard kwargs injected by Unsloth/PEFT (causal_mask,\n            # num_logits_to_keep, task_ids, return_dict, etc.)\n            output_attentions = kwargs.pop(\"output_attentions\", None)\n            output_hidden_states = kwargs.pop(\"output_hidden_states\", None)\n            kwargs.pop(\"return_dict\", None)\n            kwargs.pop(\"causal_mask\", None)\n            kwargs.pop(\"num_logits_to_keep\", None)\n            kwargs.pop(\"task_ids\", None)\n\n            # Only keep recognized TransformersKwargs\n            clean_kwargs = {\n                k: v for k, v in kwargs.items() if k in _TRANSFORMERS_KWARGS\n            }\n\n            if input_ids is not None and input_ids.ndim == 2:\n                merged = self._merge_input_ids_with_input_values(\n                    input_ids, input_values, input_values_cutoffs, labels\n                )\n                inputs_embeds = merged[\"inputs_embeds\"]\n                labels = merged[\"labels\"]\n                input_ids = None\n\n            backbone_outputs = self.backbone_model(\n                input_ids = input_ids,\n                attention_mask = attention_mask,\n             
   position_ids = position_ids,\n                past_key_values = past_key_values,\n                inputs_embeds = inputs_embeds,\n                use_cache = use_cache,\n                cache_position = cache_position,\n                output_attentions = output_attentions,\n                output_hidden_states = output_hidden_states,\n                **clean_kwargs,\n            )\n\n            backbone_hidden_states = backbone_outputs[0]\n            slice_indices = (\n                slice(-logits_to_keep, None)\n                if isinstance(logits_to_keep, int)\n                else logits_to_keep\n            )\n            backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])\n\n            loss = None\n            backbone_loss = None\n            depth_decoder_loss = None\n            depth_decoder_outputs = None\n            if labels is not None:\n                backbone_labels = labels[:, :, 0]\n                backbone_loss = self.loss_function(\n                    logits = backbone_logits,\n                    labels = backbone_labels,\n                    vocab_size = self.config.vocab_size,\n                    **clean_kwargs,\n                )\n\n                train_mask = ~(labels[:, :, 1:] == -100).all(dim = -1)\n                depth_decoder_input_ids = labels[train_mask][\n                    ..., : self.config.num_codebooks - 1\n                ]\n                depth_decoder_input_ids = nn.functional.pad(\n                    depth_decoder_input_ids, (1, 0), value = 0\n                )\n\n                train_idxs = train_mask.nonzero(as_tuple = True)\n                backbone_last_hidden_states = backbone_hidden_states[\n                    train_idxs[0], train_idxs[1] - 1, :\n                ]\n                depth_decoder_labels = labels[train_mask]\n\n                # Build clean kwargs for depth decoder\n                dd_kwargs = clean_kwargs.copy()\n                # Scale num_items_in_batch for depth decoder (31 codebooks)\n                if \"num_items_in_batch\" in dd_kwargs:\n                    dd_kwargs[\"num_items_in_batch\"] = dd_kwargs[\n                        \"num_items_in_batch\"\n                    ] * (self.config.num_codebooks - 1)\n\n                depth_decoder_outputs = self.depth_decoder(\n                    input_ids = depth_decoder_input_ids,\n                    backbone_last_hidden_state = backbone_last_hidden_states,\n                    use_cache = False,\n                    return_dict = True,\n                    labels = depth_decoder_labels,\n                    output_attentions = output_attentions,\n                    output_hidden_states = output_hidden_states,\n                    **dd_kwargs,\n                )\n\n                depth_decoder_loss = depth_decoder_outputs.loss\n                if depth_decoder_loss is None:\n                    logger.warning(\n                        \"CSM depth_decoder_loss is None! 
\"\n                        f\"labels shape={depth_decoder_labels.shape}, \"\n                        f\"train_mask sum={train_mask.sum().item()}\"\n                    )\n                    # Fallback: use only backbone loss instead of crashing\n                    loss = backbone_loss\n                else:\n                    loss = backbone_loss + depth_decoder_loss\n\n            return CsmOutputWithPast(\n                loss = loss,\n                backbone_loss = backbone_loss,\n                depth_decoder_loss = depth_decoder_loss,\n                logits = backbone_logits,\n                past_key_values = backbone_outputs.past_key_values,\n                hidden_states = backbone_outputs.hidden_states,\n                attentions = backbone_outputs.attentions,\n                depth_decoder_logits = (\n                    depth_decoder_outputs.logits if depth_decoder_outputs else None\n                ),\n                depth_decoder_past_key_values = (\n                    depth_decoder_outputs.past_key_values\n                    if depth_decoder_outputs\n                    else None\n                ),\n                depth_decoder_hidden_states = (\n                    depth_decoder_outputs.hidden_states\n                    if depth_decoder_outputs\n                    else None\n                ),\n                depth_decoder_attentions = (\n                    depth_decoder_outputs.attentions if depth_decoder_outputs else None\n                ),\n            )\n\n        # Patch at BOTH instance and class level for maximum reliability.\n        # Instance-level: catches calls via BaseTuner.forward -> self.model.forward()\n        base_csm.forward = types.MethodType(_fixed_csm_forward, base_csm)\n        # Class-level: catches any path that resolves through the class dict\n        CsmForConditionalGeneration.forward = _fixed_csm_forward\n        logger.info(\"Applied CSM forward fix (class + instance level)\\n\")\n\n    def _preprocess_csm_dataset(self, dataset, custom_format_mapping = None):\n        \"\"\"Preprocess dataset for CSM TTS training (exact notebook copy).\"\"\"\n        from transformers import AutoProcessor\n        from datasets import Audio\n        import torch\n\n        processor = AutoProcessor.from_pretrained(\n            self.model_name,\n            trust_remote_code = getattr(self, \"trust_remote_code\", False),\n        )\n\n        # Strip pad_to_multiple_of from tokenizer init_kwargs — fine-tuned models\n        # (e.g. keanteng/sesame-csm-elise) save it in tokenizer_config.json, and\n        # _merge_kwargs leaks it into audio_kwargs where EncodecFeatureExtractor rejects it.\n        processor.tokenizer.init_kwargs.pop(\"pad_to_multiple_of\", None)\n\n        # Resolve columns from user mapping or hardcoded fallback\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        speaker_key = resolved[\"speaker_col\"]\n\n        if audio_col is None:\n            raise ValueError(\n                f\"No audio column found in dataset. Columns: {dataset.column_names}\"\n            )\n        if text_col is None:\n            raise ValueError(\n                f\"No text column found in dataset. 
Columns: {dataset.column_names}\"\n            )\n        if speaker_key is None:\n            logger.info(\n                \"No speaker found, adding default 'source' of 0 for all examples\\n\"\n            )\n            dataset = dataset.add_column(\"source\", [\"0\"] * len(dataset))\n            speaker_key = \"source\"\n\n        logger.info(\n            f\"CSM preprocessing: audio_col='{audio_col}', text_col='{text_col}', speaker_key='{speaker_key}'\\n\"\n        )\n\n        dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 24000))\n\n        required_keys = [\n            \"input_ids\",\n            \"attention_mask\",\n            \"labels\",\n            \"input_values\",\n            \"input_values_cutoffs\",\n        ]\n\n        self._update_progress(status_message = \"Preprocessing CSM dataset...\")\n        processed_examples = []\n        skipped = 0\n        for idx in range(len(dataset)):\n            if self.should_stop:\n                logger.info(\"Stopped during CSM preprocessing\\n\")\n                break\n\n            example = dataset[idx]\n            try:\n                conversation = [\n                    {\n                        \"role\": str(example[speaker_key]),\n                        \"content\": [\n                            {\"type\": \"text\", \"text\": example.get(text_col, \"\")},\n                            {\"type\": \"audio\", \"path\": example[audio_col][\"array\"]},\n                        ],\n                    }\n                ]\n                # NOTE: pad_to_multiple_of intentionally omitted from text_kwargs —\n                # CsmProcessor._merge_kwargs leaks it to EncodecFeatureExtractor which rejects it.\n                model_inputs = processor.apply_chat_template(\n                    conversation,\n                    tokenize = True,\n                    return_dict = True,\n                    output_labels = True,\n                    text_kwargs = {\n                        \"padding\": \"max_length\",\n                        \"max_length\": 256,\n                        \"padding_side\": \"right\",\n                    },\n                    audio_kwargs = {\n                        \"sampling_rate\": 24_000,\n                        \"max_length\": 240001,\n                        \"padding\": \"max_length\",\n                    },\n                    common_kwargs = {\"return_tensors\": \"pt\"},\n                )\n\n                out = {}\n                for k in required_keys:\n                    if k not in model_inputs:\n                        raise KeyError(f\"Missing required key '{k}' in model outputs\")\n                    out[k] = model_inputs[k][0]\n\n                if not all(isinstance(out[k], torch.Tensor) for k in out):\n                    skipped += 1\n                    continue\n\n                processed_examples.append(out)\n\n            except Exception as e:\n                logger.warning(f\"Error processing CSM example {idx}: {e}\")\n                skipped += 1\n                continue\n\n            if (idx + 1) % 100 == 0:\n                self._update_progress(\n                    status_message = f\"Preprocessing CSM... 
{idx + 1}/{len(dataset)}\"\n                )\n\n        if not processed_examples:\n            raise ValueError(\n                f\"No valid examples after CSM preprocessing (skipped {skipped})\"\n            )\n\n        result_dataset = Dataset.from_list(processed_examples)\n        logger.info(\n            f\"CSM preprocessing complete: {len(result_dataset)} examples \"\n            f\"({skipped} skipped)\\n\"\n        )\n        return result_dataset\n\n    def _format_audio_vlm_dataset(self, dataset, custom_format_mapping = None):\n        \"\"\"Format dataset as audio chat messages for multimodal models (e.g. Gemma 3N).\n\n        Expects columns: audio (Audio), text (str).\n        Produces: messages column with system/user/assistant chat format.\n        \"\"\"\n        from datasets import Audio\n\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        if not audio_col or not text_col:\n            raise ValueError(\n                f\"Audio VLM dataset needs 'audio' and 'text' columns, got: {dataset.column_names}\"\n            )\n\n        # Store resolved audio column name for the collator closure\n        self._audio_vlm_audio_col = audio_col\n\n        # Cast audio to 16kHz (standard for speech models)\n        dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 16000))\n\n        def format_messages(samples):\n            formatted = {\"messages\": []}\n            for idx in range(len(samples[audio_col])):\n                audio = samples[audio_col][idx][\"array\"]\n                label = str(samples[text_col][idx])\n                message = [\n                    {\n                        \"role\": \"system\",\n                        \"content\": [\n                            {\n                                \"type\": \"text\",\n                                \"text\": \"You are an assistant that transcribes speech accurately.\",\n                            }\n                        ],\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": [\n                            {\"type\": \"audio\", \"audio\": audio},\n                            {\"type\": \"text\", \"text\": \"Please transcribe this audio.\"},\n                        ],\n                    },\n                    {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": label}]},\n                ]\n                formatted[\"messages\"].append(message)\n            return formatted\n\n        self._update_progress(status_message = \"Formatting audio VLM dataset...\")\n        dataset = dataset.map(\n            format_messages, batched = True, batch_size = 4, num_proc = safe_num_proc(4)\n        )\n        logger.info(f\"Audio VLM dataset formatted: {len(dataset)} examples\\n\")\n        return dataset\n\n    def _preprocess_snac_dataset(self, dataset, custom_format_mapping = None):\n        \"\"\"Preprocess dataset for Orpheus TTS training with SNAC codec.\n\n        Mirrors Orpheus_(3B)-TTS.ipynb: encode audio with SNAC (24kHz, 3 hierarchical\n        layers), interleave 7 codes per frame, wrap with Orpheus special tokens,\n        train on full sequence (no label masking).\n        \"\"\"\n        import torch\n        import torchaudio.transforms as T\n\n        SNAC_MODEL_NAME = \"hubertsiuzdak/snac_24khz\"\n        SNAC_SAMPLE_RATE = 24000\n        device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"\n        max_length = self.max_seq_length or 2048\n        tokenizer = self.tokenizer\n\n        # Orpheus special token IDs (hardcoded in tokenizer vocabulary)\n        START_OF_HUMAN = 128259\n        END_OF_HUMAN = 128260\n        START_OF_AI = 128261\n        END_OF_AI = 128262\n        START_OF_SPEECH = 128257\n        END_OF_SPEECH = 128258\n        END_OF_TEXT = 128009\n        AUDIO_OFFSET = 128266\n\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        speaker_col = resolved[\"speaker_col\"]\n        has_source = speaker_col is not None\n        if not audio_col or not text_col:\n            raise ValueError(\n                f\"SNAC dataset needs 'audio' and 'text' columns, got: {dataset.column_names}\"\n            )\n\n        # Cast audio column so datasets 4.x AudioDecoder objects are decoded to dicts\n        from datasets import Audio\n\n        dataset = dataset.cast_column(audio_col, Audio(sampling_rate = SNAC_SAMPLE_RATE))\n\n        # Get dataset sample rate from first example (after cast, always SNAC_SAMPLE_RATE)\n        first_audio = dataset[0][audio_col]\n        ds_sample_rate = (\n            first_audio.get(\"sampling_rate\", SNAC_SAMPLE_RATE)\n            if isinstance(first_audio, dict)\n            else SNAC_SAMPLE_RATE\n        )\n\n        # Load SNAC codec model\n        self._update_progress(status_message = \"Loading SNAC codec model...\")\n        logger.info(\"Loading SNAC codec model...\\n\")\n        from snac import SNAC\n\n        snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME)\n        snac_model = snac_model.to(device).eval()\n\n        # Resample transform (created once)\n        resample_transform = (\n            T.Resample(orig_freq = ds_sample_rate, new_freq = SNAC_SAMPLE_RATE)\n            if ds_sample_rate != SNAC_SAMPLE_RATE\n            else None\n        )\n\n        self._update_progress(status_message = \"Encoding audio with SNAC...\")\n        logger.info(\n            f\"SNAC preprocessing: audio_col='{audio_col}', text_col='{text_col}', \"\n            f\"has_source={has_source}, ds_sample_rate={ds_sample_rate}\\n\"\n        )\n\n        processed_examples = []\n        skipped = 0\n        for idx in range(len(dataset)):\n            if self.should_stop:\n                logger.info(\"Stopped during SNAC preprocessing\\n\")\n                break\n\n            example = dataset[idx]\n            try:\n                text = example.get(text_col)\n                if not text:\n                    skipped += 1\n                    continue\n\n                audio_data = example.get(audio_col)\n                if audio_data is None or audio_data.get(\"array\") is None:\n                    skipped += 1\n                    continue\n\n                # --- Encode audio with SNAC (notebook lines 122-142) ---\n                waveform = (\n                    torch.from_numpy(audio_data[\"array\"])\n                    .unsqueeze(0)\n                    .to(dtype = torch.float32)\n                )\n                if resample_transform is not None:\n                    waveform = resample_transform(waveform)\n\n                waveform = waveform.unsqueeze(0).to(device)\n                with torch.inference_mode():\n                    codes = snac_model.encode(waveform)\n\n                # Interleave 7 codes per frame with layer offsets (notebook lines 134-142)\n        
        all_codes = []\n                for i in range(codes[0].shape[1]):\n                    all_codes.append(codes[0][0][i].item() + AUDIO_OFFSET)\n                    all_codes.append(codes[1][0][2 * i].item() + AUDIO_OFFSET + 4096)\n                    all_codes.append(\n                        codes[2][0][4 * i].item() + AUDIO_OFFSET + (2 * 4096)\n                    )\n                    all_codes.append(\n                        codes[2][0][(4 * i) + 1].item() + AUDIO_OFFSET + (3 * 4096)\n                    )\n                    all_codes.append(\n                        codes[1][0][(2 * i) + 1].item() + AUDIO_OFFSET + (4 * 4096)\n                    )\n                    all_codes.append(\n                        codes[2][0][(4 * i) + 2].item() + AUDIO_OFFSET + (5 * 4096)\n                    )\n                    all_codes.append(\n                        codes[2][0][(4 * i) + 3].item() + AUDIO_OFFSET + (6 * 4096)\n                    )\n\n                if len(all_codes) == 0:\n                    skipped += 1\n                    continue\n\n                # Deduplicate consecutive frames with same first code (notebook lines 185-207)\n                deduped = all_codes[:7]\n                for i in range(7, len(all_codes), 7):\n                    if all_codes[i] != deduped[-7]:\n                        deduped.extend(all_codes[i : i + 7])\n                all_codes = deduped\n\n                # --- Build text tokens (notebook lines 217-224) ---\n                text_prompt = (\n                    f\"{example[speaker_col]}: {text}\"\n                    if has_source and example.get(speaker_col)\n                    else text\n                )\n                text_ids = tokenizer.encode(text_prompt, add_special_tokens = True)\n                text_ids.append(END_OF_TEXT)\n\n                # --- Build full input_ids (notebook lines 225-234) ---\n                input_ids = (\n                    [START_OF_HUMAN]\n                    + text_ids\n                    + [END_OF_HUMAN]\n                    + [START_OF_AI]\n                    + [START_OF_SPEECH]\n                    + all_codes\n                    + [END_OF_SPEECH]\n                    + [END_OF_AI]\n                )\n\n                # Truncate to max_length\n                input_ids = input_ids[:max_length]\n\n                # Labels = input_ids (no masking — Orpheus trains on full sequence)\n                labels = list(input_ids)\n                attention_mask = [1] * len(input_ids)\n\n                processed_examples.append(\n                    {\n                        \"input_ids\": input_ids,\n                        \"labels\": labels,\n                        \"attention_mask\": attention_mask,\n                    }\n                )\n\n            except Exception as e:\n                logger.warning(f\"Error processing SNAC example {idx}: {e}\")\n                skipped += 1\n                continue\n\n            # Progress update every 100 examples\n            if (idx + 1) % 100 == 0:\n                self._update_progress(\n                    status_message = f\"Encoding audio... 
{idx + 1}/{len(dataset)}\"\n                )\n\n        # Free SNAC model from GPU\n        logger.info(\"Freeing SNAC codec model from GPU...\\n\")\n        snac_model.to(\"cpu\")\n        del snac_model\n        import gc\n\n        gc.collect()\n        torch.cuda.empty_cache()\n        self._cuda_audio_used = True\n\n        if not processed_examples:\n            raise ValueError(\n                f\"No valid examples after SNAC preprocessing (skipped {skipped})\"\n            )\n\n        result_dataset = Dataset.from_list(processed_examples)\n        logger.info(\n            f\"SNAC preprocessing complete: {len(result_dataset)} examples \"\n            f\"({skipped} skipped)\\n\"\n        )\n        return result_dataset\n\n    def _preprocess_bicodec_dataset(self, dataset, custom_format_mapping = None):\n        \"\"\"Preprocess dataset for Spark-TTS training with BiCodec tokenizer.\n\n        Mirrors Spark_TTS_(0_5B).ipynb: encode audio with BiCodec (semantic + global tokens),\n        format as special-token text strings for SFTTrainer with dataset_text_field=\"text\".\n        \"\"\"\n        import sys\n        import torch\n        import numpy as np\n        import torchaudio.transforms as T\n\n        import subprocess\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # The sparktts Python package lives in the SparkAudio/Spark-TTS GitHub repo,\n        # NOT in the unsloth/Spark-TTS-0.5B HF model repo. Clone it if needed.\n        spark_code_dir = os.path.join(\n            os.path.dirname(self._spark_tts_repo_dir), \"Spark-TTS\"\n        )\n        sparktts_pkg = os.path.join(spark_code_dir, \"sparktts\")\n        if not os.path.isdir(sparktts_pkg):\n            self._update_progress(status_message = \"Cloning Spark-TTS code repo...\")\n            logger.info(f\"Cloning SparkAudio/Spark-TTS to {spark_code_dir}...\\n\")\n            subprocess.run(\n                [\n                    \"git\",\n                    \"clone\",\n                    \"--depth\",\n                    \"1\",\n                    \"https://github.com/SparkAudio/Spark-TTS\",\n                    spark_code_dir,\n                ],\n                check = True,\n            )\n\n        if spark_code_dir not in sys.path:\n            sys.path.insert(0, spark_code_dir)\n\n        from sparktts.models.audio_tokenizer import BiCodecTokenizer\n        from sparktts.utils.audio import audio_volume_normalize\n\n        # Resolve audio and text columns\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        speaker_col = resolved[\"speaker_col\"]\n        has_source = speaker_col is not None\n        if not audio_col or not text_col:\n            raise ValueError(\n                f\"BiCodec dataset needs 'audio' and 'text' columns, got: {dataset.column_names}\"\n            )\n\n        # Cast audio column so datasets 4.x AudioDecoder objects are decoded to dicts.\n        # Don't resample here — BiCodec's target_sr may differ; the loop handles resampling.\n        from datasets import Audio\n\n        dataset = dataset.cast_column(audio_col, Audio())\n\n        # Load BiCodec tokenizer\n        self._update_progress(status_message = \"Loading BiCodec tokenizer...\")\n        logger.info(\"Loading BiCodec tokenizer...\\n\")\n        audio_tokenizer = BiCodecTokenizer(self._spark_tts_repo_dir, device)\n\n        target_sr = 
audio_tokenizer.config[\"sample_rate\"]\n\n        self._update_progress(status_message = \"Encoding audio with BiCodec...\")\n        logger.info(\n            f\"BiCodec preprocessing: audio_col='{audio_col}', text_col='{text_col}', \"\n            f\"has_source={has_source}, target_sr={target_sr}\\n\"\n        )\n\n        def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor:\n            \"\"\"Extract wav2vec2 features (average of layers 11, 14, 16).\"\"\"\n            if wavs.shape[0] != 1:\n                raise ValueError(f\"Expected batch size 1, but got shape {wavs.shape}\")\n            wav_np = wavs.squeeze(0).cpu().numpy()\n\n            processed = audio_tokenizer.processor(\n                wav_np,\n                sampling_rate = 16000,\n                return_tensors = \"pt\",\n                padding = True,\n            )\n            input_values = processed.input_values.to(\n                audio_tokenizer.feature_extractor.device\n            )\n            model_output = audio_tokenizer.feature_extractor(input_values)\n\n            if model_output.hidden_states is None:\n                raise ValueError(\"Wav2Vec2Model did not return hidden states.\")\n\n            feats_mix = (\n                model_output.hidden_states[11]\n                + model_output.hidden_states[14]\n                + model_output.hidden_states[16]\n            ) / 3\n            return feats_mix\n\n        processed_examples = []\n        skipped = 0\n        for idx in range(len(dataset)):\n            if self.should_stop:\n                logger.info(\"Stopped during BiCodec preprocessing\\n\")\n                break\n\n            example = dataset[idx]\n            try:\n                text = example.get(text_col)\n                if not text:\n                    skipped += 1\n                    continue\n\n                audio_data = example.get(audio_col)\n                if audio_data is None or audio_data.get(\"array\") is None:\n                    skipped += 1\n                    continue\n\n                audio_array = audio_data[\"array\"]\n                sampling_rate = audio_data.get(\"sampling_rate\", target_sr)\n\n                # Resample if needed\n                if sampling_rate != target_sr:\n                    resampler = T.Resample(orig_freq = sampling_rate, new_freq = target_sr)\n                    audio_tensor_temp = torch.from_numpy(audio_array).float()\n                    audio_array = resampler(audio_tensor_temp).numpy()\n\n                # Volume normalize if configured\n                if audio_tokenizer.config.get(\"volume_normalize\", False):\n                    audio_array = audio_volume_normalize(audio_array)\n\n                # Get reference clip\n                ref_wav_np = audio_tokenizer.get_ref_clip(audio_array)\n\n                # Prepare tensors\n                audio_tensor = (\n                    torch.from_numpy(audio_array).unsqueeze(0).float().to(device)\n                )\n                ref_wav_tensor = (\n                    torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(device)\n                )\n\n                # Extract wav2vec2 features\n                feat = extract_wav2vec2_features(audio_tensor)\n\n                batch = {\n                    \"wav\": audio_tensor,\n                    \"ref_wav\": ref_wav_tensor,\n                    \"feat\": feat.to(device),\n                }\n\n                # BiCodec tokenize\n                semantic_token_ids, global_token_ids = 
audio_tokenizer.model.tokenize(\n                    batch\n                )\n\n                global_tokens = \"\".join(\n                    [\n                        f\"<|bicodec_global_{i}|>\"\n                        for i in global_token_ids.squeeze().cpu().numpy()\n                    ]\n                )\n                semantic_tokens = \"\".join(\n                    [\n                        f\"<|bicodec_semantic_{i}|>\"\n                        for i in semantic_token_ids.squeeze().cpu().numpy()\n                    ]\n                )\n\n                # Format text with source prefix if available\n                text_content = (\n                    f\"{example[speaker_col]}: {text}\"\n                    if has_source and example.get(speaker_col)\n                    else text\n                )\n\n                formatted = \"\".join(\n                    [\n                        \"<|task_tts|>\",\n                        \"<|start_content|>\",\n                        text_content,\n                        \"<|end_content|>\",\n                        \"<|start_global_token|>\",\n                        global_tokens,\n                        \"<|end_global_token|>\",\n                        \"<|start_semantic_token|>\",\n                        semantic_tokens,\n                        \"<|end_semantic_token|>\",\n                        \"<|im_end|>\",\n                    ]\n                )\n\n                processed_examples.append({\"text\": formatted})\n\n            except Exception as e:\n                logger.warning(f\"Error processing BiCodec example {idx}: {e}\")\n                skipped += 1\n                continue\n\n            # Progress update every 100 examples\n            if (idx + 1) % 100 == 0:\n                self._update_progress(\n                    status_message = f\"Encoding audio with BiCodec... {idx + 1}/{len(dataset)}\"\n                )\n\n        # Free BiCodec model from GPU\n        logger.info(\"Freeing BiCodec tokenizer from GPU...\\n\")\n        audio_tokenizer.model.cpu()\n        audio_tokenizer.feature_extractor.cpu()\n        del audio_tokenizer\n        import gc\n\n        gc.collect()\n        torch.cuda.empty_cache()\n        self._cuda_audio_used = True\n\n        if not processed_examples:\n            raise ValueError(\n                f\"No valid examples after BiCodec preprocessing (skipped {skipped})\"\n            )\n\n        result_dataset = Dataset.from_list(processed_examples)\n        logger.info(\n            f\"BiCodec preprocessing complete: {len(result_dataset)} examples \"\n            f\"({skipped} skipped)\\n\"\n        )\n        # Debug: show first example text (truncated)\n        sample = result_dataset[0][\"text\"]\n        logger.info(f\"Sample text (first 200 chars): {sample[:200]}...\\n\")\n        logger.info(f\"Sample text length: {len(sample)} chars\\n\")\n        return result_dataset\n\n    def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None):\n        \"\"\"Preprocess dataset for OuteTTS training with DAC codec.\n\n        Mirrors Oute_TTS_(1B).ipynb DataCreationV3: uses Whisper for word timings,\n        OuteTTS AudioProcessor for speaker representations, PromptProcessor for\n        training prompts. 
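Audio is cast to 24 kHz and transcribed with word-level\n        timestamps before the training prompt is built.\n        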
Outputs text strings for SFTTrainer with dataset_text_field=\"text\".\n        \"\"\"\n        import sys\n        import io\n        import tempfile\n        import torch\n        import numpy as np\n        import soundfile as sf\n        from datasets import Dataset as HFDataset\n        from utils.paths import ensure_dir, tmp_root\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # Clone OuteTTS repo (same as audio_codecs._load_dac)\n        import subprocess\n\n        base_dir = os.path.dirname(os.path.abspath(__file__))\n        outetts_code_dir = os.path.join(base_dir, \"inference\", \"OuteTTS\")\n        outetts_pkg = os.path.join(outetts_code_dir, \"outetts\")\n        if not os.path.isdir(outetts_pkg):\n            self._update_progress(status_message = \"Cloning OuteTTS code repo...\")\n            logger.info(f\"Cloning edwko/OuteTTS to {outetts_code_dir}...\\n\")\n            subprocess.run(\n                [\n                    \"git\",\n                    \"clone\",\n                    \"--depth\",\n                    \"1\",\n                    \"https://github.com/edwko/OuteTTS\",\n                    outetts_code_dir,\n                ],\n                check = True,\n            )\n            for fpath in [\n                os.path.join(outetts_pkg, \"models\", \"gguf_model.py\"),\n                os.path.join(outetts_pkg, \"interface.py\"),\n                os.path.join(outetts_pkg, \"__init__.py\"),\n            ]:\n                if os.path.exists(fpath):\n                    os.remove(fpath)\n                    logger.info(f\"Removed {fpath}\\n\")\n\n        if outetts_code_dir not in sys.path:\n            sys.path.insert(0, outetts_code_dir)\n\n        from outetts.version.v3.audio_processor import AudioProcessor\n        from outetts.version.v3.prompt_processor import PromptProcessor\n        from outetts.models.config import ModelConfig as OuteTTSModelConfig\n        from outetts.utils.preprocessing import text_normalizations\n\n        # Resolve audio and text columns\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        if not audio_col or not text_col:\n            raise ValueError(\n                f\"DAC dataset needs 'audio' and 'text' columns, got: {dataset.column_names}\"\n            )\n\n        # Cast audio to 24kHz (notebook: dataset.cast_column(\"audio\", Audio(sampling_rate=24000)))\n        from datasets import Audio\n\n        dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 24000))\n        logger.info(\"Cast audio column to 24kHz\\n\")\n\n        # Load Whisper for word timings\n        self._update_progress(\n            status_message = \"Loading Whisper model for word timings...\"\n        )\n        logger.info(\"Loading Whisper model for word timings...\\n\")\n        import whisper\n\n        whisper_model = whisper.load_model(\"turbo\", device = device)\n\n        # Load OuteTTS AudioProcessor + PromptProcessor\n        self._update_progress(status_message = \"Loading OuteTTS AudioProcessor...\")\n        logger.info(\"Loading OuteTTS AudioProcessor...\\n\")\n        model_tokenizer_path = \"OuteAI/Llama-OuteTTS-1.0-1B\"\n        dummy_config = OuteTTSModelConfig(\n            tokenizer_path = model_tokenizer_path,\n            device = device,\n            audio_codec_path = None,\n        )\n        audio_processor = AudioProcessor(config = dummy_config)\n      
  prompt_processor = PromptProcessor(model_tokenizer_path)\n\n        self._update_progress(status_message = \"Preprocessing audio with OuteTTS...\")\n        logger.info(\n            f\"DAC preprocessing: audio_col='{audio_col}', text_col='{text_col}'\\n\"\n        )\n\n        processed_examples = []\n        skipped = 0\n        for idx in range(len(dataset)):\n            if self.should_stop:\n                logger.info(\"Stopped during DAC preprocessing\\n\")\n                break\n\n            example = dataset[idx]\n            try:\n                text = example.get(text_col)\n                if not text or not isinstance(text, str):\n                    skipped += 1\n                    continue\n\n                audio_data = example.get(audio_col)\n                if audio_data is None or audio_data.get(\"array\") is None:\n                    skipped += 1\n                    continue\n\n                audio_array = np.array(audio_data[\"array\"], dtype = np.float32)\n                sampling_rate = audio_data.get(\"sampling_rate\", 24000)\n\n                # Convert to WAV bytes (Whisper needs a file path)\n                buf = io.BytesIO()\n                sf.write(buf, audio_array, sampling_rate, format = \"WAV\", subtype = \"FLOAT\")\n                buf.seek(0)\n                audio_bytes = buf.getvalue()\n\n                # 1. Get word timings from Whisper\n                with tempfile.NamedTemporaryFile(\n                    suffix = \".wav\",\n                    delete = False,\n                    dir = str(ensure_dir(tmp_root())),\n                ) as tmp:\n                    tmp.write(audio_bytes)\n                    tmp.flush()\n                    tmp_path = tmp.name\n                try:\n                    whisper_result = whisper_model.transcribe(\n                        tmp_path, word_timestamps = True\n                    )\n                finally:\n                    Path(tmp_path).unlink(missing_ok = True)\n\n                normalized_transcript = text_normalizations(text)\n                words_with_timings = []\n                if whisper_result and \"segments\" in whisper_result:\n                    for segment in whisper_result[\"segments\"]:\n                        for word_info in segment.get(\"words\", []):\n                            cleaned = word_info[\"word\"].strip()\n                            if cleaned:\n                                words_with_timings.append(\n                                    {\n                                        \"word\": cleaned,\n                                        \"start\": float(word_info[\"start\"]),\n                                        \"end\": float(word_info[\"end\"]),\n                                    }\n                                )\n\n                if not words_with_timings:\n                    skipped += 1\n                    continue\n\n                # 2. Create speaker representation with AudioProcessor\n                speaker_data_dict = {\n                    \"audio\": {\"bytes\": audio_bytes},\n                    \"text\": normalized_transcript,\n                    \"words\": words_with_timings,\n                }\n                speaker = audio_processor.create_speaker_from_dict(speaker_data_dict)\n                if speaker is None:\n                    skipped += 1\n                    continue\n\n                # 3. 
Get training prompt from PromptProcessor\n                prompt = prompt_processor.get_training_prompt(speaker)\n                if prompt:\n                    processed_examples.append({\"text\": prompt})\n\n            except Exception as e:\n                logger.warning(f\"Error processing DAC example {idx}: {e}\")\n                skipped += 1\n                continue\n\n            if (idx + 1) % 100 == 0:\n                self._update_progress(\n                    status_message = f\"Preprocessing audio with OuteTTS... {idx + 1}/{len(dataset)}\"\n                )\n\n        # Free Whisper from GPU (notebook: data_processor.whisper_model.to('cpu'))\n        logger.info(\"Moving Whisper model to CPU...\\n\")\n        whisper_model.to(\"cpu\")\n        del whisper_model\n        del audio_processor\n        del prompt_processor\n        import gc\n\n        gc.collect()\n        torch.cuda.empty_cache()\n        self._cuda_audio_used = True\n\n        if not processed_examples:\n            raise ValueError(\n                f\"No valid examples after DAC preprocessing (skipped {skipped})\"\n            )\n\n        result_dataset = HFDataset.from_list(processed_examples)\n        logger.info(\n            f\"DAC preprocessing complete: {len(result_dataset)} examples \"\n            f\"({skipped} skipped)\\n\"\n        )\n        sample = result_dataset[0][\"text\"]\n        logger.info(f\"Sample text (first 200 chars): {sample[:200]}...\\n\")\n        return result_dataset\n\n    def _preprocess_whisper_dataset(\n        self, dataset, eval_split = None, custom_format_mapping = None\n    ):\n        \"\"\"Preprocess dataset for Whisper speech-to-text training.\n\n        Mirrors Whisper.ipynb: extract audio features with Whisper's feature\n        extractor, tokenize text labels. 
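Audio is cast to 16 kHz before\n        feature extraction. 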
Returns (train_data, eval_data) where\n        each is a list of dicts with 'input_features' and 'labels'.\n        \"\"\"\n        from datasets import Audio\n\n        WHISPER_SAMPLE_RATE = 16000\n\n        resolved = self._resolve_audio_columns(dataset, custom_format_mapping)\n        audio_col = resolved[\"audio_col\"]\n        text_col = resolved[\"text_col\"]\n        if not audio_col or not text_col:\n            raise ValueError(\n                f\"Whisper dataset needs 'audio' and 'text' columns, got: {dataset.column_names}\"\n            )\n\n        # Cast audio to 16kHz (Whisper's expected sample rate)\n        dataset = dataset.cast_column(\n            audio_col, Audio(sampling_rate = WHISPER_SAMPLE_RATE)\n        )\n\n        # Train/eval split (notebook does dataset.train_test_split)\n        eval_dataset_raw = None\n        if eval_split:\n            splits = dataset.train_test_split(test_size = 0.06, seed = 42)\n            dataset = splits[\"train\"]\n            eval_dataset_raw = splits[\"test\"]\n\n        self._update_progress(status_message = \"Processing audio for Whisper...\")\n        logger.info(\n            f\"Whisper preprocessing: audio_col='{audio_col}', text_col='{text_col}', \"\n            f\"samples={len(dataset)}\\n\"\n        )\n\n        def process_split(ds, split_name = \"train\"):\n            processed = []\n            skipped = 0\n            for idx in range(len(ds)):\n                if self.should_stop:\n                    logger.info(f\"Stopped during Whisper {split_name} preprocessing\\n\")\n                    break\n\n                example = ds[idx]\n                try:\n                    audio_data = example.get(audio_col)\n                    text = example.get(text_col)\n                    if (\n                        audio_data is None\n                        or audio_data.get(\"array\") is None\n                        or not text\n                    ):\n                        skipped += 1\n                        continue\n\n                    # Extract audio features (notebook line 112-115)\n                    features = self.tokenizer.feature_extractor(\n                        audio_data[\"array\"], sampling_rate = audio_data[\"sampling_rate\"]\n                    )\n                    # Tokenize text (notebook line 116)\n                    tokenized_text = self.tokenizer.tokenizer(text)\n\n                    processed.append(\n                        {\n                            \"input_features\": features.input_features[0],\n                            \"labels\": tokenized_text.input_ids,\n                        }\n                    )\n                except Exception as e:\n                    logger.warning(\n                        f\"Error processing Whisper {split_name} example {idx}: {e}\"\n                    )\n                    skipped += 1\n                    continue\n\n                if (idx + 1) % 100 == 0:\n                    self._update_progress(\n                        status_message = f\"Processing {split_name} audio... 
{idx + 1}/{len(ds)}\"\n                    )\n\n            logger.info(\n                f\"Whisper {split_name} preprocessing: {len(processed)} examples ({skipped} skipped)\\n\"\n            )\n            return processed\n\n        train_data = process_split(dataset, \"train\")\n        eval_data = (\n            process_split(eval_dataset_raw, \"eval\") if eval_dataset_raw else None\n        )\n\n        if not train_data:\n            raise ValueError(\"No valid examples after Whisper preprocessing\")\n\n        return (train_data, eval_data)\n\n    @staticmethod\n    def _resolve_local_files(file_paths: list) -> list[str]:\n        \"\"\"Resolve a list of local dataset paths to concrete file paths.\"\"\"\n        all_files: list[str] = []\n        for dataset_file in file_paths:\n            if os.path.isabs(dataset_file):\n                file_path = dataset_file\n            else:\n                file_path = str(resolve_dataset_path(dataset_file))\n\n            file_path_obj = Path(file_path)\n\n            if file_path_obj.is_dir():\n                parquet_dir = (\n                    file_path_obj / \"parquet-files\"\n                    if (file_path_obj / \"parquet-files\").exists()\n                    else file_path_obj\n                )\n                parquet_files = sorted(parquet_dir.glob(\"*.parquet\"))\n                if parquet_files:\n                    all_files.extend(str(p) for p in parquet_files)\n                    continue\n                candidates: list[Path] = []\n                for ext in (\".json\", \".jsonl\", \".csv\", \".parquet\"):\n                    candidates.extend(sorted(file_path_obj.glob(f\"*{ext}\")))\n                if candidates:\n                    all_files.extend(str(c) for c in candidates)\n                    continue\n                raise ValueError(\n                    f\"No supported data files in directory: {file_path_obj}\"\n                )\n            else:\n                all_files.append(str(file_path_obj))\n        return all_files\n\n    @staticmethod\n    def _loader_for_files(files: list[str]) -> str:\n        \"\"\"Determine the HF datasets loader type from file extensions.\"\"\"\n        first_ext = Path(files[0]).suffix.lower()\n        if first_ext in (\".json\", \".jsonl\"):\n            return \"json\"\n        elif first_ext == \".csv\":\n            return \"csv\"\n        elif first_ext == \".parquet\":\n            return \"parquet\"\n        raise ValueError(f\"Unsupported dataset format: {files[0]}\")\n\n    def load_and_format_dataset(\n        self,\n        dataset_source: str,\n        format_type: str = \"auto\",\n        local_datasets: list = None,\n        local_eval_datasets: list = None,\n        custom_format_mapping: dict = None,\n        subset: str = None,\n        train_split: str = \"train\",\n        eval_split: str = None,\n        eval_steps: float = 0.00,\n        dataset_slice_start: int = None,\n        dataset_slice_end: int = None,\n    ) -> Optional[tuple]:\n        \"\"\"\n        Load and prepare dataset for training.\n\n        Strategy: format first, then split — ensures both train and eval\n        portions are properly formatted and templated.\n\n        Returns:\n            Tuple of (dataset_info, eval_dataset) or None on error.\n            eval_dataset may be None if no eval split is available.\n        \"\"\"\n        try:\n            dataset = None\n            eval_dataset = None\n            has_separate_eval_source = (\n                False  # True if 
eval comes from a separate HF split\n            )\n            eval_enabled = eval_steps is not None and eval_steps > 0\n\n            if local_datasets:\n                # Load local datasets using load_dataset() so the result is\n                # Arrow-backed (has cache files).  Dataset.from_list() creates\n                # an in-memory dataset with no cache, which forces num_proc=1\n                # during tokenization/map because sharding requires Arrow files.\n                all_files = self._resolve_local_files(local_datasets)\n\n                if all_files:\n                    loader = self._loader_for_files(all_files)\n                    dataset = load_dataset(loader, data_files = all_files, split = \"train\")\n\n                    # Check if stopped during dataset loading\n                    if self.should_stop:\n                        logger.info(\"Stopped during dataset loading\\n\")\n                        return None\n\n                    self._update_progress(\n                        status_message = f\"Loaded {len(dataset)} samples from local files\"\n                    )\n                    logger.info(f\"Loaded {len(dataset)} samples from local files\\n\")\n                    logger.info(f\"[DEBUG] Dataset cache_files: {dataset.cache_files}\\n\")\n\n                # Load local eval datasets if provided\n                if local_eval_datasets and eval_enabled:\n                    eval_all_files = self._resolve_local_files(local_eval_datasets)\n                    if eval_all_files:\n                        eval_loader = self._loader_for_files(eval_all_files)\n                        eval_dataset = load_dataset(\n                            eval_loader, data_files = eval_all_files, split = \"train\"\n                        )\n                        has_separate_eval_source = True\n                        logger.info(\n                            f\"Loaded {len(eval_dataset)} eval samples from local eval files\\n\"\n                        )\n\n            elif dataset_source:\n                # Load from Hugging Face\n                split_name = train_split or \"train\"\n                load_kwargs = {\"path\": dataset_source, \"split\": split_name}\n                if subset:\n                    load_kwargs[\"name\"] = subset\n\n                _slice_start = dataset_slice_start or 0\n                if (\n                    dataset_slice_end is not None\n                    and dataset_slice_end >= 0\n                    and dataset_slice_end >= _slice_start\n                ):\n                    # Manual slice — stream only the rows we need instead of\n                    # downloading the entire dataset.\n                    rows_to_stream = dataset_slice_end + 1\n                    logger.info(\n                        f\"[dataset-slice] Manual slice specified \"\n                        f\"(start={dataset_slice_start}, end={dataset_slice_end}), \"\n                        f\"streaming {rows_to_stream} rows\\n\"\n                    )\n                    stream = load_dataset(**load_kwargs, streaming = True)\n                    dataset = Dataset.from_list(list(stream.take(rows_to_stream)))\n                    logger.info(\n                        f\"[dataset-slice] Downloaded {len(dataset)} rows \"\n                        f\"(requested {rows_to_stream})\\n\"\n                    )\n                    self._update_progress(\n                        status_message = f\"Streamed {len(dataset)} rows from HuggingFace\"\n                    )\n 
               else:\n                    self._update_progress(\n                        status_message = f\"Downloading dataset: {dataset_source}...\"\n                    )\n                    dataset = load_dataset(**load_kwargs)\n\n                # Check if stopped during dataset loading\n                if self.should_stop:\n                    logger.info(\"Stopped during dataset loading\\n\")\n                    return None\n\n                n_rows = len(dataset) if hasattr(dataset, \"__len__\") else 0\n                self._update_progress(\n                    status_message = f\"Downloaded {dataset_source} ({n_rows:,} rows)\"\n                )\n                logger.info(\n                    f\"Loaded dataset from Hugging Face: {dataset_source} ({n_rows:,} rows)\\n\"\n                )\n\n                # Resolve eval split from a separate HF split (explicit or auto-detected)\n                if eval_enabled:\n                    effective_train = train_split or \"train\"\n                    if eval_split and eval_split != effective_train:\n                        # Explicit eval split provided - load it directly\n                        logger.info(f\"Loading explicit eval split: '{eval_split}'\\n\")\n                        eval_load_kwargs = {\"path\": dataset_source, \"split\": eval_split}\n                        if subset:\n                            eval_load_kwargs[\"name\"] = subset\n                        eval_dataset = load_dataset(**eval_load_kwargs)\n                        has_separate_eval_source = True\n                        logger.info(\n                            f\"Loaded eval split '{eval_split}' with {len(eval_dataset)} rows\\n\"\n                        )\n                    elif eval_split and eval_split == effective_train:\n                        # Same split as training — auto-split a small eval set after formatting\n                        logger.info(\n                            f\"Eval split '{eval_split}' is the same as train split — will auto-split a small eval set after formatting\\n\"\n                        )\n                    else:\n                        # Auto-detect eval split from HF (returns a separate dataset, or None)\n                        eval_dataset = self._auto_detect_eval_split_from_hf(\n                            dataset_source = dataset_source,\n                            subset = subset,\n                        )\n                        if eval_dataset is not None:\n                            has_separate_eval_source = True\n                else:\n                    logger.info(\n                        \"Eval disabled (eval_steps <= 0), skipping eval split detection\\n\"\n                    )\n\n            if dataset is None:\n                raise ValueError(\"No dataset provided\")\n\n            # Apply index range slicing if requested (inclusive on both ends)\n            if dataset_slice_start is not None or dataset_slice_end is not None:\n                total_rows = len(dataset)\n                start = dataset_slice_start if dataset_slice_start is not None else 0\n                end = (\n                    dataset_slice_end\n                    if dataset_slice_end is not None\n                    else total_rows - 1\n                )\n                # Clamp to valid range\n                start = max(0, min(start, total_rows - 1))\n                end = max(start, min(end, total_rows - 1))\n                dataset = dataset.select(range(start, end + 1))\n                logger.info(\n                    f\"Sliced dataset to 
rows [{start}, {end}]: {len(dataset)} of {total_rows} rows\\n\"\n                )\n                self._update_progress(\n                    status_message = f\"Sliced dataset to {len(dataset)} rows (indices {start}-{end})\"\n                )\n\n            # Check if stopped before applying template\n            if self.should_stop:\n                logger.info(\"Stopped before applying chat template\\n\")\n                return None\n\n            # ========== AUDIO MODELS: custom preprocessing ==========\n            if self._audio_type == \"csm\":\n                processed = self._preprocess_csm_dataset(dataset, custom_format_mapping)\n                return (processed, None)\n\n            elif self._audio_type == \"whisper\":\n                train_data, eval_data = self._preprocess_whisper_dataset(\n                    dataset,\n                    eval_split = eval_split,\n                    custom_format_mapping = custom_format_mapping,\n                )\n                return (train_data, eval_data)\n\n            elif self._audio_type == \"snac\":\n                processed = self._preprocess_snac_dataset(\n                    dataset, custom_format_mapping\n                )\n                return (processed, None)\n\n            elif self._audio_type == \"bicodec\":\n                processed = self._preprocess_bicodec_dataset(\n                    dataset, custom_format_mapping\n                )\n                return ({\"dataset\": processed, \"final_format\": \"audio_bicodec\"}, None)\n\n            elif self._audio_type == \"dac\":\n                processed = self._preprocess_dac_dataset(dataset, custom_format_mapping)\n                return ({\"dataset\": processed, \"final_format\": \"audio_dac\"}, None)\n\n            elif self.is_audio_vlm:\n                formatted = self._format_audio_vlm_dataset(\n                    dataset, custom_format_mapping\n                )\n                return (formatted, None)\n\n            # ========== FORMAT FIRST ==========\n            logger.info(f\"Formatting dataset with format_type='{format_type}'...\\n\")\n\n            dataset_info = format_and_template_dataset(\n                dataset,\n                model_name = self.model_name,\n                tokenizer = self.tokenizer,\n                is_vlm = self.is_vlm,\n                format_type = format_type,\n                dataset_name = dataset_source,\n                custom_format_mapping = custom_format_mapping,\n                progress_callback = self._update_progress,\n            )\n\n            # Check if stopped during formatting\n            if self.should_stop:\n                logger.info(\"Stopped during dataset formatting\\n\")\n                return None\n\n            # Abort if dataset formatting/conversion failed\n            if not dataset_info.get(\"success\", True):\n                errors = dataset_info.get(\"errors\", [])\n                error_msg = \"; \".join(errors) if errors else \"Dataset formatting failed\"\n                logger.error(f\"Dataset conversion failed: {error_msg}\")\n                self._update_progress(error = error_msg)\n                return None\n\n            detected = dataset_info.get(\"detected_format\", \"unknown\")\n            final_ds = dataset_info.get(\"dataset\")\n            final_n = len(final_ds) if hasattr(final_ds, \"__len__\") else \"?\"\n            self._update_progress(\n                status_message = f\"Dataset ready ({final_n:,} samples, {detected} format)\"\n            )\n       
     logger.info(\n                f\"Dataset formatted successfully ({final_n} samples, {detected})\\n\"\n            )\n\n            # ========== THEN SPLIT ==========\n            if has_separate_eval_source and eval_dataset is not None:\n                # Eval came from a separate HF split — format it too\n                logger.info(f\"Formatting eval dataset ({len(eval_dataset)} rows)...\\n\")\n                eval_info = format_and_template_dataset(\n                    eval_dataset,\n                    model_name = self.model_name,\n                    tokenizer = self.tokenizer,\n                    is_vlm = self.is_vlm,\n                    format_type = format_type,\n                    dataset_name = dataset_source,\n                    custom_format_mapping = custom_format_mapping,\n                )\n                eval_dataset = eval_info[\"dataset\"]\n                logger.info(f\"Eval dataset formatted successfully\\n\")\n            elif eval_enabled and not has_separate_eval_source:\n                # No separate eval source — split the already-formatted dataset\n                formatted_dataset = dataset_info[\"dataset\"]\n                split_result = self._resolve_eval_split_from_dataset(formatted_dataset)\n                if split_result is not None:\n                    train_portion, eval_dataset = split_result\n                    dataset_info[\"dataset\"] = train_portion\n\n            return (dataset_info, eval_dataset)\n\n        except Exception as e:\n            logger.error(f\"Error loading dataset: {e}\")\n            self._update_progress(error = str(e))\n            return None\n\n    def _auto_detect_eval_split_from_hf(\n        self, dataset_source: str, subset: str\n    ) -> Optional[Dataset]:\n        \"\"\"Auto-detect an eval split from HF dataset (separate named split only).\"\"\"\n        try:\n            from datasets import get_dataset_split_names\n\n            load_kwargs = {\"path\": dataset_source}\n            if subset:\n                load_kwargs[\"config_name\"] = subset\n            available_splits = get_dataset_split_names(**load_kwargs)\n            logger.info(f\"Available splits: {available_splits}\\n\")\n\n            # Check for common eval split names\n            for candidate in [\"eval\", \"validation\", \"valid\", \"val\", \"test\"]:\n                if candidate in available_splits:\n                    eval_load_kwargs = {\"path\": dataset_source, \"split\": candidate}\n                    if subset:\n                        eval_load_kwargs[\"name\"] = subset\n                    candidate_ds = load_dataset(**eval_load_kwargs)\n                    if len(candidate_ds) >= 16:\n                        logger.info(\n                            f\"Auto-detected eval split '{candidate}' with {len(candidate_ds)} rows\\n\"\n                        )\n                        return candidate_ds\n                    else:\n                        logger.info(\n                            f\"Found eval split '{candidate}' but only {len(candidate_ds)} rows (< 16), skipping\\n\"\n                        )\n\n        except Exception as e:\n            logger.warning(f\"Could not check dataset splits: {e}\")\n\n        # No separate HF eval split found — caller will handle programmatic splitting\n        return None\n\n    def _resolve_eval_split_from_dataset(self, dataset) -> Optional[tuple]:\n        \"\"\"Split a dataset into train and eval portions.\n\n        Returns:\n            Tuple of (train_dataset, eval_dataset), 
or None if dataset too small.\n        \"\"\"\n        MIN_EVAL_ROWS = 16\n        MIN_TOTAL_ROWS = 32  # Need at least 16 train + 16 eval\n\n        n = len(dataset)\n        if n < MIN_TOTAL_ROWS:\n            logger.info(f\"Dataset too small ({n} rows) for eval split, skipping eval\\n\")\n            return None\n\n        eval_size = max(MIN_EVAL_ROWS, min(128, int(0.05 * n)))\n        # Ensure we don't take more than half the dataset\n        eval_size = min(eval_size, n // 2)\n\n        logger.info(f\"Auto-splitting: {eval_size} rows for eval from {n} total\\n\")\n        split_result = dataset.train_test_split(test_size = eval_size, seed = 3407)\n        logger.info(\n            f\"Split complete: {len(split_result['train'])} train, {len(split_result['test'])} eval\\n\"\n        )\n        return (split_result[\"train\"], split_result[\"test\"])\n\n    def start_training(\n        self,\n        dataset: Dataset,\n        eval_dataset: Dataset = None,\n        eval_steps: float = 0.00,\n        output_dir: str | None = None,\n        num_epochs: int = 3,\n        learning_rate: float = 5e-5,\n        batch_size: int = 2,\n        gradient_accumulation_steps: int = 4,\n        warmup_steps: int = None,\n        warmup_ratio: float = None,\n        max_steps: int = 0,\n        save_steps: int = 0,\n        weight_decay: float = 0.01,\n        random_seed: int = 3407,\n        packing: bool = False,\n        train_on_completions: bool = False,\n        enable_wandb: bool = False,\n        wandb_project: str = \"unsloth-training\",\n        wandb_token: str = None,\n        enable_tensorboard: bool = False,\n        tensorboard_dir: str | None = None,\n        **kwargs,\n    ) -> bool:\n        \"\"\"Start training in a separate thread\"\"\"\n\n        if self.is_training:\n            logger.warning(\"Training already in progress\")\n            return False\n\n        if self.model is None or self.tokenizer is None:\n            self._update_progress(error = \"Model not loaded\")\n            return False\n\n        # Pre-import heavy transformers modules on the main thread.\n        # Unsloth's patched_import hook (deepseek_v3_moe.py) is not thread-safe\n        # with Python's importlib cache, causing KeyError: 'size' if these are\n        # first imported inside the worker thread.\n        import transformers  # noqa: F401 – ensures submodules are cached\n        from transformers import (  # noqa: F401\n            Trainer as _HFTrainer,\n            TrainingArguments as _TrainingArguments,\n            TrainerCallback as _TrainerCallback,\n        )\n\n        if self._audio_type == \"whisper\":\n            from transformers import (  # noqa: F401\n                Seq2SeqTrainer as _Seq2SeqTrainer,\n                Seq2SeqTrainingArguments as _Seq2SeqTrainingArguments,\n            )\n\n        # Start training in separate thread\n        self.training_thread = threading.Thread(\n            target = self._train_worker,\n            args = (dataset,),\n            kwargs = {\n                \"output_dir\": output_dir,\n                \"num_epochs\": num_epochs,\n                \"learning_rate\": learning_rate,\n                \"batch_size\": batch_size,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"warmup_steps\": warmup_steps,\n                \"warmup_ratio\": warmup_ratio,\n                \"max_steps\": max_steps,\n                \"save_steps\": save_steps,\n                \"weight_decay\": weight_decay,\n            
    \"random_seed\": random_seed,\n                \"packing\": packing,\n                \"train_on_completions\": train_on_completions,\n                \"enable_wandb\": enable_wandb,\n                \"wandb_project\": wandb_project,\n                \"wandb_token\": wandb_token,\n                \"enable_tensorboard\": enable_tensorboard,\n                \"tensorboard_dir\": tensorboard_dir,\n                \"eval_dataset\": eval_dataset,\n                \"eval_steps\": eval_steps,\n                **kwargs,\n            },\n        )\n\n        self.should_stop = False\n        self.is_training = True\n        try:\n            self.training_thread.start()\n            return True\n        except Exception as e:\n            self.is_training = False\n            logger.error(f\"Failed to start training thread: {e}\")\n            return False\n\n    def _train_worker(self, dataset: Dataset, **training_args):\n        \"\"\"Worker function for training (runs in separate thread)\"\"\"\n        try:\n            # Store training parameters for metrics calculation\n            self.batch_size = training_args.get(\"batch_size\", 2)\n            self.max_seq_length = training_args.get(\"max_seq_length\", 2048)\n            self.gradient_accumulation_steps = training_args.get(\n                \"gradient_accumulation_steps\", 4\n            )\n\n            # Set training start time\n            self.training_start_time = time.time()\n\n            self._update_progress(is_training = True, error = None)\n\n            # Setup logging\n            if training_args.get(\"enable_wandb\", False) and training_args.get(\n                \"wandb_token\"\n            ):\n                os.environ[\"WANDB_API_KEY\"] = training_args[\"wandb_token\"]\n                import wandb\n\n                wandb.init(\n                    project = training_args.get(\"wandb_project\", \"unsloth-training\")\n                )\n\n            # Create output directory\n            output_dir = str(resolve_output_dir(training_args.get(\"output_dir\")))\n            ensure_dir(Path(output_dir))\n\n            # ========== AUDIO TRAINER BRANCH ==========\n            if self._audio_type == \"csm\":\n                # CSM uses plain HF Trainer (NOT SFTTrainer)\n                # Needs remove_unused_columns=False for depth decoder (input_values + cutoffs)\n                from transformers import Trainer as HFTrainer, TrainingArguments\n\n                self._apply_csm_forward_fix()\n\n                config = self._build_audio_training_args(\n                    training_args,\n                    output_dir,\n                    extra_args = {\n                        \"remove_unused_columns\": False,\n                    },\n                )\n                self.trainer = HFTrainer(\n                    model = self.model,\n                    train_dataset = dataset,\n                    args = TrainingArguments(**config),\n                )\n                self.trainer.add_callback(self._create_progress_callback())\n\n                batch_size = training_args.get(\"batch_size\", 2)\n                total = self._calculate_total_steps(\n                    len(dataset),\n                    batch_size,\n                    training_args.get(\"gradient_accumulation_steps\", 4),\n                    training_args.get(\"num_epochs\", 3),\n                    training_args.get(\"max_steps\", 0),\n                )\n                self._update_progress(\n                    total_steps = total, 
status_message = \"Starting CSM training...\"\n                )\n                logger.info(f\"CSM training config: {config}\\n\")\n                self.trainer.train()\n                self._finalize_training(output_dir, \"CSM\")\n                return\n\n            elif self._audio_type == \"snac\":\n                # Orpheus: language model with SNAC codec tokens — plain HF Trainer\n                # DataCollatorForSeq2Seq dynamically pads variable-length sequences per batch\n                # (text + audio codes vary in length) and pads labels with -100.\n                from transformers import (\n                    Trainer as HFTrainer,\n                    TrainingArguments,\n                    DataCollatorForSeq2Seq,\n                )\n\n                config = self._build_audio_training_args(training_args, output_dir)\n                self.trainer = HFTrainer(\n                    model = self.model,\n                    train_dataset = dataset,\n                    args = TrainingArguments(**config),\n                    data_collator = DataCollatorForSeq2Seq(\n                        tokenizer = self.tokenizer,\n                        padding = True,\n                        pad_to_multiple_of = 8,\n                    ),\n                )\n                self.trainer.add_callback(self._create_progress_callback())\n\n                batch_size = training_args.get(\"batch_size\", 2)\n                total = self._calculate_total_steps(\n                    len(dataset),\n                    batch_size,\n                    training_args.get(\"gradient_accumulation_steps\", 4),\n                    training_args.get(\"num_epochs\", 3),\n                    training_args.get(\"max_steps\", 0),\n                )\n                self._update_progress(\n                    total_steps = total, status_message = \"Starting SNAC training...\"\n                )\n                logger.info(f\"SNAC training config: {config}\\n\")\n                self.trainer.train()\n                self._finalize_training(output_dir, \"SNAC\")\n                return\n\n            elif self._audio_type == \"whisper\":\n                # Whisper: Seq2SeqTrainer with custom speech collator\n                from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n                from utils.datasets import DataCollatorSpeechSeq2SeqWithPadding\n\n                eval_dataset = training_args.get(\"eval_dataset\", None)\n                extra = {\"remove_unused_columns\": False, \"label_names\": [\"labels\"]}\n                if eval_dataset:\n                    extra[\"eval_strategy\"] = \"steps\"\n                    extra[\"eval_steps\"] = training_args.get(\"eval_steps\", 5)\n\n                config = self._build_audio_training_args(\n                    training_args, output_dir, extra_args = extra\n                )\n\n                trainer_kwargs = {\n                    \"model\": self.model,\n                    \"train_dataset\": dataset,\n                    \"data_collator\": DataCollatorSpeechSeq2SeqWithPadding(\n                        processor = self.tokenizer\n                    ),\n                    \"processing_class\": self.tokenizer.feature_extractor,\n                    \"args\": Seq2SeqTrainingArguments(**config),\n                }\n                if eval_dataset:\n                    trainer_kwargs[\"eval_dataset\"] = eval_dataset\n\n                self.trainer = Seq2SeqTrainer(**trainer_kwargs)\n                
self.trainer.add_callback(self._create_progress_callback())\n\n                batch_size = training_args.get(\"batch_size\", 2)\n                total = self._calculate_total_steps(\n                    len(dataset),\n                    batch_size,\n                    training_args.get(\"gradient_accumulation_steps\", 4),\n                    training_args.get(\"num_epochs\", 3),\n                    training_args.get(\"max_steps\", 0),\n                )\n                self._update_progress(\n                    total_steps = total, status_message = \"Starting Whisper training...\"\n                )\n                logger.info(f\"Whisper training config: {config}\\n\")\n                self.trainer.train()\n                self._finalize_training(output_dir, \"Whisper\")\n                return\n\n            elif self._audio_type is not None and self._audio_type not in (\n                \"bicodec\",\n                \"dac\",\n            ):\n                # bicodec/dac use the standard SFTTrainer text path below\n                raise NotImplementedError(\n                    f\"Audio training for '{self._audio_type}' not yet implemented\"\n                )\n\n            # ========== DATA COLLATOR SELECTION ==========\n            # Detect special model types\n            model_name_lower = self.model_name.lower()\n            is_deepseek_ocr = (\n                \"deepseek\" in model_name_lower and \"ocr\" in model_name_lower\n            )\n\n            logger.info(\"Configuring data collator...\\n\")\n\n            data_collator = None  # Default to built-in data collator\n            if is_deepseek_ocr:\n                # Special DeepSeek OCR collator - auto-install if needed\n                logger.info(\"Detected DeepSeek OCR model\\n\")\n                # Ensure DeepSeek OCR module is installed\n                if not _ensure_deepseek_ocr_installed():\n                    error_msg = (\n                        \"Failed to install DeepSeek OCR module. 
\"\n                        \"Please install manually: \"\n                        \"from huggingface_hub import snapshot_download; \"\n                        \"snapshot_download('unsloth/DeepSeek-OCR', local_dir='deepseek_ocr')\"\n                    )\n                    logger.error(error_msg)\n                    self._update_progress(error = error_msg, is_training = False)\n                    return\n\n                try:\n                    from backend.data_utils import DeepSeekOCRDataCollator\n\n                    logger.info(\"Configuring DeepSeek OCR data collator...\\n\")\n                    FastVisionModel.for_training(self.model)\n                    data_collator = DeepSeekOCRDataCollator(\n                        tokenizer = self.tokenizer,\n                        model = self.model,\n                        image_size = 640,\n                        base_size = 1024,\n                        crop_mode = True,\n                        train_on_responses_only = training_args.get(\n                            \"train_on_completions\", False\n                        ),\n                    )\n                    logger.info(\"DeepSeek OCR data collator configured successfully\\n\")\n\n                except Exception as e:\n                    logger.error(f\"Failed to configure DeepSeek OCR collator: {e}\")\n                    error_msg = f\"Error configuring DeepSeek OCR: {str(e)}\"\n                    self._update_progress(error = error_msg, is_training = False)\n                    return\n\n            elif self.is_audio_vlm:\n                # Audio VLM collator (e.g. Gemma 3N with audio data)\n                # Mirrors the collate_fn from Gemma3N_(4B)-Audio notebook\n                logger.info(\"Configuring audio VLM data collator...\\n\")\n                processor = self.tokenizer  # FastModel returns processor as tokenizer\n\n                audio_col_name = getattr(self, \"_audio_vlm_audio_col\", \"audio\")\n\n                def audio_vlm_collate_fn(examples):\n                    texts = []\n                    audios = []\n                    for example in examples:\n                        text = processor.apply_chat_template(\n                            example[\"messages\"],\n                            tokenize = False,\n                            add_generation_prompt = False,\n                        ).strip()\n                        texts.append(text)\n                        audios.append(example[audio_col_name][\"array\"])\n\n                    batch = processor(\n                        text = texts, audio = audios, return_tensors = \"pt\", padding = True\n                    )\n\n                    # Labels = input_ids with special tokens masked\n                    labels = batch[\"input_ids\"].clone()\n                    labels[labels == processor.tokenizer.pad_token_id] = -100\n                    for attr in (\n                        \"audio_token_id\",\n                        \"image_token_id\",\n                        \"boi_token_id\",\n                        \"eoi_token_id\",\n                    ):\n                        token_id = getattr(processor.tokenizer, attr, None)\n                        if token_id is not None:\n                            labels[labels == token_id] = -100\n                    batch[\"labels\"] = labels\n                    return batch\n\n                data_collator = audio_vlm_collate_fn\n                logger.info(\"Audio VLM data collator configured\\n\")\n\n            elif 
self.is_vlm:\n                # Standard VLM collator (images)\n                logger.info(\"Using UnslothVisionDataCollator for vision model\\n\")\n                from unsloth.trainer import UnslothVisionDataCollator\n\n                FastVisionModel.for_training(self.model)\n                data_collator = UnslothVisionDataCollator(self.model, self.tokenizer)\n                logger.info(\"Vision data collator configured\\n\")\n\n            # ========== TRAINING CONFIGURATION ==========\n            # Handle warmup_steps vs warmup_ratio\n            warmup_steps_val = training_args.get(\"warmup_steps\", None)\n            warmup_ratio_val = training_args.get(\"warmup_ratio\", None)\n\n            lr_value = training_args.get(\"learning_rate\", 2e-4)\n            logger.info(\n                f\"[DEBUG] learning_rate from training_args: {lr_value} (type: {type(lr_value).__name__})\\n\"\n            )\n\n            config_args = {\n                \"per_device_train_batch_size\": training_args.get(\"batch_size\", 2),\n                \"gradient_accumulation_steps\": training_args.get(\n                    \"gradient_accumulation_steps\", 4\n                ),\n                \"num_train_epochs\": training_args.get(\n                    \"num_epochs\", 3\n                ),  # Default to epochs\n                \"learning_rate\": lr_value,\n                \"fp16\": not is_bfloat16_supported(),\n                \"bf16\": is_bfloat16_supported(),\n                \"logging_steps\": 1,\n                \"weight_decay\": training_args.get(\"weight_decay\", 0.01),\n                \"seed\": training_args.get(\"random_seed\", 3407),\n                \"output_dir\": output_dir,\n                \"report_to\": _build_report_targets(training_args),\n                \"include_num_input_tokens_seen\": True,  # Enable token counting\n                \"dataset_num_proc\": 1\n                if (self.is_audio or self.is_audio_vlm or self._cuda_audio_used)\n                else safe_num_proc(max(1, os.cpu_count() // 4)),\n                \"max_seq_length\": training_args.get(\"max_seq_length\", 2048),\n            }\n            if training_args.get(\"enable_tensorboard\", False):\n                config_args[\"logging_dir\"] = str(\n                    resolve_tensorboard_dir(training_args.get(\"tensorboard_dir\"))\n                )\n            logger.info(\n                f\"[DEBUG] dataset_num_proc={config_args['dataset_num_proc']} (is_audio={self.is_audio}, is_audio_vlm={self.is_audio_vlm}, _cuda_audio_used={self._cuda_audio_used})\"\n            )\n\n            # On Windows with transformers 5.x, disable DataLoader multiprocessing\n            # to avoid issues with modified sys.path (.venv_t5) in spawned workers.\n            if sys.platform == \"win32\":\n                import transformers as _tf\n\n                if _tf.__version__.startswith(\"5.\"):\n                    config_args[\"dataloader_num_workers\"] = 0\n\n            # Add warmup parameter - use warmup_ratio if provided, otherwise warmup_steps\n            if warmup_ratio_val is not None:\n                config_args[\"warmup_ratio\"] = warmup_ratio_val\n                logger.info(f\"Using warmup_ratio: {warmup_ratio_val}\\n\")\n            elif warmup_steps_val is not None:\n                config_args[\"warmup_steps\"] = warmup_steps_val\n                logger.info(f\"Using warmup_steps: {warmup_steps_val}\\n\")\n            else:\n                # Default to warmup_steps if neither provided\n                
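# (5 steps matches the warmup used in Unsloth's example notebooks.)\n                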
config_args[\"warmup_steps\"] = 5\n                logger.info(f\"Using default warmup_steps: 5\\n\")\n\n            # Add save_steps if specified\n            save_steps_val = training_args.get(\"save_steps\", 0)\n            if save_steps_val and save_steps_val > 0:\n                config_args[\"save_steps\"] = save_steps_val\n                config_args[\"save_strategy\"] = \"steps\"\n\n            #  If max_steps is specified, use it instead of epochs\n            max_steps_val = training_args.get(\"max_steps\", 0)\n            if max_steps_val and max_steps_val > 0:\n                del config_args[\"num_train_epochs\"]  # Remove epochs\n                config_args[\"max_steps\"] = max_steps_val  # Use steps instead\n                logger.info(f\"Training for {max_steps_val} steps\\n\")\n            else:\n                logger.info(f\"Training for {config_args['num_train_epochs']} epochs\\n\")\n\n            # ========== EVAL CONFIGURATION ==========\n            eval_dataset = training_args.get(\"eval_dataset\", None)\n            eval_steps_val = training_args.get(\"eval_steps\", 0.00)\n            if eval_dataset is not None:\n                if eval_steps_val > 0:\n                    config_args[\"eval_strategy\"] = \"steps\"\n                    config_args[\"eval_steps\"] = eval_steps_val\n                    logger.info(\n                        f\"✅ Evaluation enabled: eval_steps={eval_steps_val} (fraction of total steps)\\n\"\n                    )\n                    logger.info(f\"Eval dataset: {len(eval_dataset)} rows\\n\")\n                else:\n                    logger.info(\n                        f\"⚠️  Eval dataset provided but eval_steps={eval_steps_val} (disabled)\\n\"\n                    )\n                    logger.info(\"To enable evaluation, set eval_steps > 0.0\\n\")\n            else:\n                logger.info(\"No eval dataset — evaluation disabled\\n\")\n\n            # Add model-specific parameters\n            # Use optim and lr_scheduler_type from training_args if provided, otherwise use defaults\n            optim_value = training_args.get(\"optim\", \"adamw_8bit\")\n            lr_scheduler_type_value = training_args.get(\"lr_scheduler_type\", \"linear\")\n\n            if self.is_vlm or self.is_audio_vlm:\n                # Vision / audio VLM config (both need skip_prepare_dataset + remove_unused_columns)\n                label = \"audio VLM\" if self.is_audio_vlm else \"vision\"\n                logger.info(f\"Configuring {label} model training parameters\\n\")\n                # Use provided values or defaults for vision models\n                optim_value = training_args.get(\"optim\", \"adamw_torch_fused\")\n                lr_scheduler_type_value = training_args.get(\n                    \"lr_scheduler_type\", \"cosine\"\n                )\n                config_args.update(\n                    {\n                        \"optim\": optim_value,\n                        \"lr_scheduler_type\": lr_scheduler_type_value,\n                        \"gradient_checkpointing\": True,\n                        \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n                        \"max_grad_norm\": 0.3,\n                        \"remove_unused_columns\": False,\n                        \"dataset_text_field\": \"\",\n                        \"dataset_kwargs\": {\"skip_prepare_dataset\": True},\n                        \"max_length\": training_args.get(\"max_seq_length\", 2048),\n                    }\n                )\n       
     else:\n                logger.info(\"Configuring text model training parameters\\n\")\n                config_args.update(\n                    {\n                        \"optim\": optim_value,\n                        \"lr_scheduler_type\": lr_scheduler_type_value,\n                        \"dataset_text_field\": \"text\",\n                    }\n                )\n\n                # Only add packing for text models (not DeepSeek OCR which is VLM)\n                if not is_deepseek_ocr:\n                    packing_enabled = training_args.get(\"packing\", False)\n                    config_args[\"packing\"] = packing_enabled\n                    logger.info(\n                        f\"Sequence packing: {'enabled' if packing_enabled else 'disabled'}\\n\"\n                    )\n\n            # Audio codec overrides — BiCodec/DAC use the text SFTTrainer path\n            if self._audio_type == \"bicodec\":\n                config_args[\"packing\"] = False\n                logger.info(\"Applied BiCodec overrides: packing=False\\n\")\n            elif self._audio_type == \"dac\":\n                config_args[\"packing\"] = False\n                logger.info(\"Applied DAC overrides: packing=False\\n\")\n\n            logger.info(f\"The configuration is: {config_args}\")\n\n            logger.info(\"Training configuration prepared\\n\")\n            # ========== TRAINER INITIALIZATION ==========\n            if self.is_audio_vlm:\n                # Audio VLM (e.g. Gemma 3N + audio): raw Dataset from _format_audio_vlm_dataset\n                # Notebook uses processing_class=processor.tokenizer (text tokenizer only)\n                train_dataset = (\n                    dataset if isinstance(dataset, Dataset) else dataset[\"dataset\"]\n                )\n                processing_class = (\n                    self.tokenizer.tokenizer\n                    if hasattr(self.tokenizer, \"tokenizer\")\n                    else self.tokenizer\n                )\n                trainer_kwargs = {\n                    \"model\": self.model,\n                    \"train_dataset\": train_dataset,\n                    \"processing_class\": processing_class,\n                    \"data_collator\": data_collator,\n                    \"args\": SFTConfig(**config_args),\n                }\n                if eval_dataset is not None:\n                    trainer_kwargs[\"eval_dataset\"] = eval_dataset\n                self.trainer = SFTTrainer(**trainer_kwargs)\n            elif self.is_vlm:\n                # Image VLM: dataset is dict wrapper from format_and_template_dataset\n                train_dataset = (\n                    dataset[\"dataset\"] if isinstance(dataset, dict) else dataset\n                )\n                trainer_kwargs = {\n                    \"model\": self.model,\n                    \"train_dataset\": train_dataset,\n                    \"processing_class\": self.tokenizer,\n                    \"data_collator\": data_collator,\n                    \"args\": SFTConfig(**config_args),\n                }\n                if eval_dataset is not None:\n                    trainer_kwargs[\"eval_dataset\"] = eval_dataset\n                self.trainer = SFTTrainer(**trainer_kwargs)\n            else:\n                # For text-only training, if the tokenizer is actually a Processor\n                # (e.g., Gemma-3 returns a ProcessorMixin even for text), we must\n                # unwrap to the raw tokenizer. 
Otherwise Unsloth's SFTTrainer detects\n                # ProcessorMixin → sets _is_vlm=True → skips _prepare_dataset entirely,\n                # and the 'text' column never gets tokenized to 'input_ids'.\n                from transformers import ProcessorMixin\n\n                sft_tokenizer = self.tokenizer\n                if isinstance(self.tokenizer, ProcessorMixin) and hasattr(\n                    self.tokenizer, \"tokenizer\"\n                ):\n                    logger.info(\n                        f\"  ⚠️ Unwrapping Processor → raw tokenizer for text-only SFTTrainer\"\n                    )\n                    sft_tokenizer = self.tokenizer.tokenizer\n\n                trainer_kwargs = {\n                    \"model\": self.model,\n                    \"tokenizer\": sft_tokenizer,\n                    \"train_dataset\": dataset[\"dataset\"],\n                    \"data_collator\": data_collator,\n                    \"args\": SFTConfig(**config_args),\n                }\n                if eval_dataset is not None:\n                    trainer_kwargs[\"eval_dataset\"] = eval_dataset\n                self.trainer = SFTTrainer(**trainer_kwargs)\n                # Restore the full processor as processing_class so checkpoint\n                # saves include preprocessor_config.json (needed for GGUF export).\n                if sft_tokenizer is not self.tokenizer:\n                    self.trainer.processing_class = self.tokenizer\n            logger.info(\"Trainer initialized\\n\")\n\n            # ========== TRAIN ON RESPONSES ONLY ==========\n            # Determine if we should train on responses only\n            instruction_part = None\n            response_part = None\n            train_on_responses_enabled = training_args.get(\n                \"train_on_completions\", False\n            )\n\n            # DeepSeek OCR handles this internally in its collator, so skip\n            # Audio VLM handles label masking in its collator, so skip\n            if (\n                train_on_responses_enabled\n                and not self.is_audio_vlm\n                and not self.is_audio\n                and not (is_deepseek_ocr or dataset[\"final_format\"].lower() == \"alpaca\")\n            ):\n                try:\n                    logger.info(\"Configuring train on responses only...\\n\")\n\n                    # Get the template mapping for this model\n                    model_name_lower = self.model_name.lower()\n\n                    if model_name_lower in MODEL_TO_TEMPLATE_MAPPER:\n                        template_name = MODEL_TO_TEMPLATE_MAPPER[model_name_lower]\n                        logger.info(f\"Detected template: {template_name}\\n\")\n\n                        if template_name in TEMPLATE_TO_RESPONSES_MAPPER:\n                            instruction_part = TEMPLATE_TO_RESPONSES_MAPPER[\n                                template_name\n                            ][\"instruction\"]\n                            response_part = TEMPLATE_TO_RESPONSES_MAPPER[template_name][\n                                \"response\"\n                            ]\n\n                            logger.info(\n                                f\"Instruction marker: {instruction_part[:50]}...\\n\"\n                            )\n                            logger.info(f\"Response marker: {response_part[:50]}...\\n\")\n                        else:\n                            logger.info(\n                                f\"No response mapping found for template: 
{template_name}\\n\"\n                            )\n                            train_on_responses_enabled = False\n                    else:\n                        logger.info(\n                            f\"No template mapping found for model: {self.model_name}\\n\"\n                        )\n                        train_on_responses_enabled = False\n\n                except Exception as e:\n                    logger.warning(f\"Could not configure train on responses: {e}\")\n                    train_on_responses_enabled = False\n\n            # Apply train on responses only if we have valid parts\n            if (\n                train_on_responses_enabled\n                and instruction_part\n                and response_part\n                and not self.is_audio_vlm\n                and not self.is_audio\n                and not (is_deepseek_ocr or dataset[\"final_format\"].lower() == \"alpaca\")\n            ):\n                try:\n                    from unsloth.chat_templates import train_on_responses_only\n\n                    self.trainer = train_on_responses_only(\n                        self.trainer,\n                        instruction_part = instruction_part,\n                        response_part = response_part,\n                        num_proc = config_args[\"dataset_num_proc\"],\n                    )\n                    logger.info(\"Train on responses only configured successfully\\n\")\n\n                    # ── Safety net: check if all samples were filtered out ──\n                    # Unsloth's train_on_responses_only masks non-response\n                    # tokens with -100. If max_seq_length is too short and the\n                    # response portion gets truncated away, EVERY sample ends\n                    # up with all labels == -100 and Unsloth removes them,\n                    # leaving 0 usable training samples.\n                    filtered_len = len(self.trainer.train_dataset)\n                    original_len = len(dataset[\"dataset\"])\n                    dropped = original_len - filtered_len\n                    drop_pct = (\n                        round(100 * dropped / original_len, 1)\n                        if original_len > 0\n                        else 0\n                    )\n\n                    if filtered_len == 0 or drop_pct > 30:\n                        max_seq = training_args.get(\"max_seq_length\", 2048)\n                        error_msg = (\n                            f\"{dropped}/{original_len} samples ({drop_pct}%) \"\n                            f\"were dropped after applying 'train on responses \"\n                            f\"only' — only {filtered_len} remain. This usually \"\n                            f\"means max_seq_length ({max_seq}) is too short \"\n                            f\"and the response portion is being truncated \"\n                            f\"away. Try increasing max_seq_length (e.g. 8192) \"\n                            f\"or disabling 'Train on completions'.\"\n                        )\n                        logger.error(error_msg)\n                        self._update_progress(error = error_msg, is_training = False)\n                        return\n\n                    if dropped > 0:\n                        logger.info(\n                            f\"⚠️ {dropped}/{original_len} samples \"\n                            f\"({drop_pct}%) were dropped (all labels \"\n                            f\"masked). 
{filtered_len} samples remain.\\n\"\n                        )\n                    logger.info(f\"Post-filter dataset size: {filtered_len} samples\\n\")\n\n                    # [DEBUG] Decode first sample AFTER train_on_completions applied\n                    # try:\n                    #     _row = self.trainer.train_dataset[0]\n                    #     _space = self.tokenizer(\n                    #         \" \", add_special_tokens = False\n                    #     ).input_ids[0]\n                    #     print(\"[DEBUG] === After train_on_completions ===\", flush = True)\n                    #     print(\n                    #         f\"[DEBUG] input_ids decoded:\\n{self.tokenizer.decode(_row['input_ids'])}\\n\",\n                    #         flush = True,\n                    #     )\n                    #     print(\n                    #         f\"[DEBUG] labels decoded (-100 → space):\\n{self.tokenizer.decode([_space if x == -100 else x for x in _row['labels']])}\\n\",\n                    #         flush = True,\n                    #     )\n                    # except Exception as _dbg_e:\n                    #     print(\n                    #         f\"[DEBUG] Could not decode post-completions sample: {_dbg_e}\",\n                    #         flush = True,\n                    #     )\n\n                except Exception as e:\n                    logger.warning(f\"Failed to apply train on responses only: {e}\")\n                    train_on_responses_enabled = False\n            else:\n                if train_on_responses_enabled and is_deepseek_ocr:\n                    logger.info(\"Train on responses handled by DeepSeek OCR collator\\n\")\n                else:\n                    logger.info(\"Training on full sequences (including prompts)\\n\")\n\n            # ========== PROGRESS TRACKING ==========\n            self.trainer.add_callback(self._create_progress_callback())\n\n            num_samples = len(\n                dataset[\"dataset\"] if isinstance(dataset, dict) else dataset\n            )\n            batch_size = training_args.get(\"batch_size\", 2)\n            total_steps = self._calculate_total_steps(\n                num_samples,\n                batch_size,\n                training_args.get(\"gradient_accumulation_steps\", 4),\n                training_args.get(\"num_epochs\", 3),\n                training_args.get(\"max_steps\", 0),\n            )\n            self._update_progress(total_steps = total_steps)\n\n            # ========== START TRAINING ==========\n            self._update_progress(status_message = \"Starting training...\")\n            logger.info(\"Starting training...\\n\")\n            self.trainer.train()\n\n            # ========== SAVE MODEL ==========\n            self._finalize_training(output_dir)\n\n        except Exception as e:\n            import traceback\n\n            logger.error(f\"Training error: {e}\")\n            logger.error(f\"Full traceback:\\n{traceback.format_exc()}\")\n            self._update_progress(is_training = False, error = str(e))\n\n        finally:\n            self.is_training = False\n\n    def _patch_adapter_config(self, output_dir: str) -> None:\n        \"\"\"Patch adapter_config.json with unsloth_training_method.\n\n        Values: 'qlora', 'lora', 'FT', 'CPT', 'DPO', 'GRPO', etc.\n        For LoRA/QLoRA, the distinction comes from load_in_4bit.\n        \"\"\"\n        config_path = os.path.join(output_dir, \"adapter_config.json\")\n        if not os.path.exists(config_path):\n      
      logger.info(\"No adapter_config.json found — skipping training method patch\")\n            return\n\n        try:\n            with open(config_path, \"r\") as f:\n                config = json.load(f)\n\n            # Determine the training method\n            if self.load_in_4bit:\n                method = \"qlora\"\n            else:\n                method = \"lora\"\n\n            config[\"unsloth_training_method\"] = method\n            logger.info(\n                f\"Patching adapter_config.json with unsloth_training_method='{method}'\"\n            )\n\n            with open(config_path, \"w\") as f:\n                json.dump(config, f, indent = 2)\n\n        except Exception as e:\n            logger.warning(f\"Failed to patch adapter_config.json: {e}\")\n\n    def stop_training(self, save: bool = True):\n        \"\"\"Stop ongoing training\"\"\"\n        logger.info(f\"\\nStopping training (save={save})...\")\n        self.should_stop = True\n        self.save_on_stop = save\n        stop_msg = (\n            \"Stopping training and saving checkpoint...\"\n            if save\n            else \"Cancelling training...\"\n        )\n        self._update_progress(status_message = stop_msg)\n\n        # If trainer exists, try to stop it gracefully\n        if self.trainer:\n            try:\n                # The callback will catch should_stop flag and stop the training loop\n                logger.info(\"Training will stop at next step...\\n\")\n            except Exception as e:\n                logger.error(f\"Error stopping trainer: {e}\")\n\n    def get_training_progress(self) -> TrainingProgress:\n        \"\"\"Get current training progress\"\"\"\n        with self._lock:\n            return self.training_progress\n\n    def cleanup(self):\n        \"\"\"Cleanup resources\"\"\"\n        if self.trainer:\n            self.trainer = None\n        if self.model:\n            self.model = None\n        if self.tokenizer:\n            self.tokenizer = None\n\n        # Clear GPU memory\n        clear_gpu_cache()\n\n\ndef _ensure_deepseek_ocr_installed():\n    \"\"\"\n    Auto-install DeepSeek OCR module if not available.\n    Downloads from HuggingFace hub as a local module.\n\n    Returns:\n        bool: True if available (either already installed or just installed)\n    \"\"\"\n    try:\n        # Try importing to see if already available\n        from deepseek_ocr.modeling_deepseekocr import format_messages\n\n        logger.info(\"DeepSeek OCR module already available\")\n        return True\n    except ImportError:\n        pass\n\n    try:\n        logger.info(\n            \"DeepSeek OCR module not found. 
Auto-installing from HuggingFace...\"\n        )\n        logger.info(\"\\n Downloading DeepSeek OCR module from HuggingFace...\\n\")\n\n        from huggingface_hub import snapshot_download\n        import sys\n        import os\n\n        # Get the script directory to install locally\n        script_dir = os.path.dirname(os.path.abspath(__file__))\n        parent_dir = os.path.dirname(script_dir)  # Go up to project root\n\n        # Download to project root as 'deepseek_ocr' folder\n        local_dir = os.path.join(parent_dir, \"deepseek_ocr\")\n\n        snapshot_download(\n            \"unsloth/DeepSeek-OCR\", local_dir = local_dir, local_dir_use_symlinks = False\n        )\n\n        # Add to sys.path if not already there\n        if parent_dir not in sys.path:\n            sys.path.insert(0, parent_dir)\n\n        # Try importing again\n        from deepseek_ocr.modeling_deepseekocr import format_messages\n\n        logger.info(\"DeepSeek OCR module installed successfully\")\n        logger.info(\"DeepSeek OCR module installed successfully!\\n\")\n        return True\n\n    except Exception as e:\n        logger.error(f\"Failed to install DeepSeek OCR module: {e}\")\n        logger.info(f\"\\n❌ Failed to install DeepSeek OCR module: {e}\\n\")\n        return False\n\n\n# Global trainer instance\n_trainer_instance = None\n\n\ndef get_trainer() -> UnslothTrainer:\n    \"\"\"Get global trainer instance\"\"\"\n    global _trainer_instance\n    if _trainer_instance is None:\n        _trainer_instance = UnslothTrainer()\n    return _trainer_instance\n"
  },
  {
    "path": "studio/backend/core/training/training.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTraining backend — subprocess orchestrator.\n\nEach training job runs in a fresh subprocess (mp.get_context(\"spawn\")),\nsolving the transformers version-switching problem. The old in-process\nUnslothTrainer singleton is only used inside the subprocess (worker.py).\n\nThis file orchestrates the subprocess lifecycle, pumps events from the\nworker's mp.Queue, and exposes the same API surface to routes/training.py.\n\nPattern follows core/data_recipe/jobs/manager.py.\n\"\"\"\n\nimport math\nimport multiprocessing as mp\nimport queue\nimport threading\nimport time\nimport structlog\nfrom loggers import get_logger\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Any\n\nimport matplotlib.pyplot as plt\n\nlogger = get_logger(__name__)\n\n_CTX = mp.get_context(\"spawn\")\n\n# Plot styling constants\nPLOT_WIDTH = 8\nPLOT_HEIGHT = 3.5\n\n\n@dataclass\nclass TrainingProgress:\n    \"\"\"Mirror of trainer.TrainingProgress — kept here so the parent process\n    never needs to import the heavy ML modules.\"\"\"\n\n    epoch: float = 0\n    step: int = 0\n    total_steps: int = 0\n    loss: float = 0.0\n    learning_rate: float = 0.0\n    is_training: bool = False\n    is_completed: bool = False\n    error: Optional[str] = None\n    status_message: str = \"Ready to train\"\n    elapsed_seconds: Optional[float] = None\n    eta_seconds: Optional[float] = None\n    grad_norm: Optional[float] = None\n    num_tokens: Optional[int] = None\n    eval_loss: Optional[float] = None\n\n\nclass TrainingBackend:\n    \"\"\"\n    Training orchestration backend — subprocess-based.\n    Launches a fresh subprocess per training job, communicates via mp.Queue.\n    \"\"\"\n\n    def __init__(self):\n        # Subprocess state\n        self._proc: Optional[mp.Process] = None\n        self._event_queue: Any = None\n        self._stop_queue: Any = None\n        self._pump_thread: Optional[threading.Thread] = None\n        self._lock = threading.Lock()\n\n        # Progress state (updated by pump thread from subprocess events)\n        self._progress = TrainingProgress()\n        self._should_stop = False\n        self._cancel_requested = False  # True only for stop(save=False)\n\n        # Training Metrics (consumed by routes for SSE and /metrics)\n        self.loss_history: list = []\n        self.lr_history: list = []\n        self.step_history: list = []\n        self.grad_norm_history: list = []\n        self.grad_norm_step_history: list = []\n        self.eval_loss_history: list = []\n        self.eval_step_history: list = []\n        self.eval_enabled: bool = False\n        self.current_theme: str = \"light\"\n\n        # Job metadata\n        self.current_job_id: Optional[str] = None\n        self._output_dir: Optional[str] = None\n\n        logger.info(\"TrainingBackend initialized (subprocess mode)\")\n\n    # ------------------------------------------------------------------\n    # Public API (called by routes/training.py)\n    # ------------------------------------------------------------------\n\n    def start_training(self, **kwargs) -> bool:\n        \"\"\"Spawn a subprocess to run the full training pipeline.\n\n        All kwargs are serialized into a config dict and sent to the worker.\n        Returns True if the subprocess was started successfully.\n        \"\"\"\n        with 
self._lock:\n            if self._proc is not None and self._proc.is_alive():\n                logger.warning(\"Training subprocess already running\")\n                return False\n\n        # Join prior pump thread to prevent it from consuming events\n        # from the new job's queue (it reads self._event_queue dynamically).\n        if self._pump_thread is not None and self._pump_thread.is_alive():\n            self._pump_thread.join(timeout = 5.0)\n            if self._pump_thread.is_alive():\n                logger.warning(\"Previous pump thread did not exit within 5s\")\n        self._pump_thread = None\n\n        # Reset state\n        self._should_stop = False\n        self._cancel_requested = False\n        self._progress = TrainingProgress(\n            is_training = True, status_message = \"Initializing training...\"\n        )\n        self.loss_history.clear()\n        self.lr_history.clear()\n        self.step_history.clear()\n        self.grad_norm_history.clear()\n        self.grad_norm_step_history.clear()\n        self.eval_loss_history.clear()\n        self.eval_step_history.clear()\n        self.eval_enabled = False\n        self._output_dir = None\n\n        # Build config dict for the subprocess\n        config = {\n            \"model_name\": kwargs[\"model_name\"],\n            \"training_type\": kwargs.get(\"training_type\", \"LoRA/QLoRA\"),\n            \"hf_token\": kwargs.get(\"hf_token\", \"\"),\n            \"load_in_4bit\": kwargs.get(\"load_in_4bit\", True),\n            \"max_seq_length\": kwargs.get(\"max_seq_length\", 2048),\n            \"hf_dataset\": kwargs.get(\"hf_dataset\", \"\"),\n            \"local_datasets\": kwargs.get(\"local_datasets\"),\n            \"local_eval_datasets\": kwargs.get(\"local_eval_datasets\"),\n            \"format_type\": kwargs.get(\"format_type\", \"\"),\n            \"subset\": kwargs.get(\"subset\"),\n            \"train_split\": kwargs.get(\"train_split\", \"train\"),\n            \"eval_split\": kwargs.get(\"eval_split\"),\n            \"eval_steps\": kwargs.get(\"eval_steps\", 0.00),\n            \"dataset_slice_start\": kwargs.get(\"dataset_slice_start\"),\n            \"dataset_slice_end\": kwargs.get(\"dataset_slice_end\"),\n            \"custom_format_mapping\": kwargs.get(\"custom_format_mapping\"),\n            \"is_dataset_image\": kwargs.get(\"is_dataset_image\", False),\n            \"is_dataset_audio\": kwargs.get(\"is_dataset_audio\", False),\n            \"is_embedding\": kwargs.get(\"is_embedding\", False),\n            \"num_epochs\": kwargs.get(\"num_epochs\", 3),\n            \"learning_rate\": kwargs.get(\"learning_rate\", \"2e-4\"),\n            \"batch_size\": kwargs.get(\"batch_size\", 2),\n            \"gradient_accumulation_steps\": kwargs.get(\"gradient_accumulation_steps\", 4),\n            \"warmup_steps\": kwargs.get(\"warmup_steps\"),\n            \"warmup_ratio\": kwargs.get(\"warmup_ratio\"),\n            \"max_steps\": kwargs.get(\"max_steps\", 0),\n            \"save_steps\": kwargs.get(\"save_steps\", 0),\n            \"weight_decay\": kwargs.get(\"weight_decay\", 0.01),\n            \"random_seed\": kwargs.get(\"random_seed\", 3407),\n            \"packing\": kwargs.get(\"packing\", False),\n            \"optim\": kwargs.get(\"optim\", \"adamw_8bit\"),\n            \"lr_scheduler_type\": kwargs.get(\"lr_scheduler_type\", \"linear\"),\n            \"use_lora\": kwargs.get(\"use_lora\", True),\n            \"lora_r\": kwargs.get(\"lora_r\", 16),\n            \"lora_alpha\": 
kwargs.get(\"lora_alpha\", 16),\n            \"lora_dropout\": kwargs.get(\"lora_dropout\", 0.0),\n            \"target_modules\": kwargs.get(\"target_modules\"),\n            \"gradient_checkpointing\": kwargs.get(\"gradient_checkpointing\", \"unsloth\"),\n            \"use_rslora\": kwargs.get(\"use_rslora\", False),\n            \"use_loftq\": kwargs.get(\"use_loftq\", False),\n            \"train_on_completions\": kwargs.get(\"train_on_completions\", False),\n            \"finetune_vision_layers\": kwargs.get(\"finetune_vision_layers\", True),\n            \"finetune_language_layers\": kwargs.get(\"finetune_language_layers\", True),\n            \"finetune_attention_modules\": kwargs.get(\n                \"finetune_attention_modules\", True\n            ),\n            \"finetune_mlp_modules\": kwargs.get(\"finetune_mlp_modules\", True),\n            \"enable_wandb\": kwargs.get(\"enable_wandb\", False),\n            \"wandb_token\": kwargs.get(\"wandb_token\"),\n            \"wandb_project\": kwargs.get(\"wandb_project\", \"unsloth-training\"),\n            \"enable_tensorboard\": kwargs.get(\"enable_tensorboard\", False),\n            \"tensorboard_dir\": kwargs.get(\"tensorboard_dir\", \"runs\"),\n            \"trust_remote_code\": kwargs.get(\"trust_remote_code\", False),\n        }\n\n        # Derive load_in_4bit from training_type\n        if config[\"training_type\"] != \"LoRA/QLoRA\":\n            config[\"load_in_4bit\"] = False\n\n        # Spawn subprocess\n        from .worker import run_training_process\n\n        self._event_queue = _CTX.Queue()\n        self._stop_queue = _CTX.Queue()\n\n        self._proc = _CTX.Process(\n            target = run_training_process,\n            kwargs = {\n                \"event_queue\": self._event_queue,\n                \"stop_queue\": self._stop_queue,\n                \"config\": config,\n            },\n            daemon = True,\n        )\n        self._proc.start()\n        logger.info(\"Training subprocess started (pid=%s)\", self._proc.pid)\n\n        # Start event pump thread\n        self._pump_thread = threading.Thread(target = self._pump_loop, daemon = True)\n        self._pump_thread.start()\n\n        return True\n\n    def stop_training(self, save: bool = True) -> bool:\n        \"\"\"Send stop signal to the training subprocess.\"\"\"\n        self._should_stop = True\n        if not save:\n            self._cancel_requested = True\n        with self._lock:\n            if self._stop_queue is not None:\n                try:\n                    self._stop_queue.put({\"type\": \"stop\", \"save\": save})\n                except (OSError, ValueError):\n                    pass\n            # Update progress immediately for responsive UI\n            self._progress.status_message = (\n                \"Stopping training and saving checkpoint...\"\n                if save\n                else \"Cancelling training...\"\n            )\n        return True\n\n    def force_terminate(self) -> None:\n        \"\"\"Force-kill the training subprocess so state can be reset immediately.\"\"\"\n        with self._lock:\n            if self._proc is not None and self._proc.is_alive():\n                logger.info(\n                    \"Force-terminating training subprocess (pid=%s)\", self._proc.pid\n                )\n                self._proc.terminate()\n            proc = self._proc\n\n        if proc is not None:\n            proc.join(timeout = 5.0)\n            if proc.is_alive():\n                proc.kill()\n            
    proc.join(timeout = 2.0)\n\n    def is_training_active(self) -> bool:\n        \"\"\"Check if training is currently active.\"\"\"\n        with self._lock:\n            # Subprocess alive = active\n            if self._proc is not None and self._proc.is_alive():\n                return True\n\n            # Stop was requested and process exited → inactive\n            if self._should_stop:\n                return False\n\n            # Check progress state\n            p = self._progress\n            if p.is_training:\n                return True\n            if p.is_completed or p.error:\n                return False\n\n            # Check status message for activity indicators\n            status_lower = (p.status_message or \"\").lower()\n            if any(\n                k in status_lower\n                for k in [\n                    \"cancelled\",\n                    \"canceled\",\n                    \"stopped\",\n                    \"completed\",\n                    \"ready to train\",\n                ]\n            ):\n                return False\n            if any(\n                k in status_lower\n                for k in [\n                    \"loading\",\n                    \"preparing\",\n                    \"training\",\n                    \"configuring\",\n                    \"tokenizing\",\n                    \"starting\",\n                    \"importing\",\n                ]\n            ):\n                return True\n\n            return False\n\n    def get_training_status(self, theme: str = \"light\") -> Tuple:\n        \"\"\"Get current training status and loss plot.\"\"\"\n        with self._lock:\n            progress = self._progress\n\n        if not (progress.is_training or progress.is_completed or progress.error):\n            return (None, progress)\n\n        plot = self._create_loss_plot(progress, theme)\n        return (plot, progress)\n\n    def refresh_plot_for_theme(self, theme: str) -> Optional[plt.Figure]:\n        \"\"\"Refresh plot with new theme.\"\"\"\n        if theme and isinstance(theme, str) and theme in [\"light\", \"dark\"]:\n            self.current_theme = theme\n        if self.loss_history:\n            with self._lock:\n                progress = self._progress\n            return self._create_loss_plot(progress, self.current_theme)\n        return None\n\n    # ------------------------------------------------------------------\n    # Compatibility shims — routes/training.py accesses these\n    # ------------------------------------------------------------------\n\n    class _TrainerShim:\n        \"\"\"Minimal shim so routes that access backend.trainer.* still work.\"\"\"\n\n        def __init__(self, backend: \"TrainingBackend\"):\n            self._backend = backend\n            self.should_stop = False\n\n        @property\n        def training_progress(self):\n            return self._backend._progress\n\n        @training_progress.setter\n        def training_progress(self, value):\n            self._backend._progress = value\n\n        def get_training_progress(self):\n            return self._backend._progress\n\n        def _update_progress(self, **kwargs):\n            with self._backend._lock:\n                for key, value in kwargs.items():\n                    if hasattr(self._backend._progress, key):\n                        setattr(self._backend._progress, key, value)\n\n    @property\n    def trainer(self):\n        \"\"\"Compatibility shim for routes that access backend.trainer.*\"\"\"\n        
return self._TrainerShim(self)\n\n    # ------------------------------------------------------------------\n    # Event pump (background thread)\n    # ------------------------------------------------------------------\n\n    def _pump_loop(self) -> None:\n        \"\"\"Background thread: consume events from subprocess → update state.\"\"\"\n        while True:\n            if self._proc is None or self._event_queue is None:\n                return\n\n            # Try to read an event\n            event = self._read_queue(self._event_queue, timeout_sec = 0.25)\n            if event is not None:\n                self._handle_event(event)\n                continue\n\n            # No event — check if process is still alive\n            if self._proc.is_alive():\n                continue\n\n            # Process exited — drain remaining events\n            for e in self._drain_queue(self._event_queue):\n                self._handle_event(e)\n\n            # Mark as done if no explicit complete/error was received\n            with self._lock:\n                if self._progress.is_training:\n                    if self._should_stop:\n                        self._progress.is_training = False\n                        self._progress.status_message = \"Training stopped.\"\n                    else:\n                        self._progress.is_training = False\n                        self._progress.error = (\n                            self._progress.error\n                            or \"Training process exited unexpectedly\"\n                        )\n            return\n\n    def _handle_event(self, event: dict) -> None:\n        \"\"\"Apply a subprocess event to local state.\"\"\"\n        etype = event.get(\"type\")\n\n        with self._lock:\n            if etype == \"progress\":\n                self._progress.step = event.get(\"step\", self._progress.step)\n                self._progress.epoch = event.get(\"epoch\", self._progress.epoch)\n                self._progress.loss = event.get(\"loss\", self._progress.loss)\n                self._progress.learning_rate = event.get(\n                    \"learning_rate\", self._progress.learning_rate\n                )\n                self._progress.total_steps = event.get(\n                    \"total_steps\", self._progress.total_steps\n                )\n                self._progress.elapsed_seconds = event.get(\"elapsed_seconds\")\n                self._progress.eta_seconds = event.get(\"eta_seconds\")\n                self._progress.grad_norm = event.get(\"grad_norm\")\n                self._progress.num_tokens = event.get(\"num_tokens\")\n                self._progress.eval_loss = event.get(\"eval_loss\")\n                self._progress.is_training = True\n                status = event.get(\"status_message\", \"\")\n                if status:\n                    self._progress.status_message = status\n\n                # Update metric histories\n                step = event.get(\"step\", 0)\n                loss = event.get(\"loss\", 0.0)\n                lr = event.get(\"learning_rate\", 0.0)\n                if step >= 0 and loss > 0:\n                    self.loss_history.append(loss)\n                    self.lr_history.append(lr)\n                    self.step_history.append(step)\n\n                grad_norm = event.get(\"grad_norm\")\n                if grad_norm is not None:\n                    try:\n                        gn = float(grad_norm)\n                    except (TypeError, ValueError):\n                        gn = 
None\n                    if gn is not None and math.isfinite(gn):\n                        self.grad_norm_history.append(gn)\n                        self.grad_norm_step_history.append(step)\n\n                eval_loss = event.get(\"eval_loss\")\n                if eval_loss is not None:\n                    self.eval_loss_history.append(eval_loss)\n                    self.eval_step_history.append(step)\n                    self.eval_enabled = True\n\n            elif etype == \"eval_configured\":\n                self.eval_enabled = True\n\n            elif etype == \"status\":\n                self._progress.status_message = event.get(\"message\", \"\")\n                self._progress.is_training = True\n\n            elif etype == \"complete\":\n                self._progress.is_training = False\n                self._progress.is_completed = True\n                self._output_dir = event.get(\"output_dir\")\n                msg = event.get(\"status_message\", \"Training completed\")\n                self._progress.status_message = msg\n\n            elif etype == \"error\":\n                self._progress.is_training = False\n                self._progress.error = event.get(\"error\", \"Unknown error\")\n                logger.error(\"Training error: %s\", event.get(\"error\"))\n                stack = event.get(\"stack\", \"\")\n                if stack:\n                    logger.error(\"Stack trace:\\n%s\", stack)\n\n    @staticmethod\n    def _read_queue(q: Any, timeout_sec: float) -> Optional[dict]:\n        try:\n            return q.get(timeout = timeout_sec)\n        except queue.Empty:\n            return None\n        except (EOFError, OSError, ValueError):\n            return None\n\n    @staticmethod\n    def _drain_queue(q: Any) -> list:\n        events = []\n        while True:\n            try:\n                events.append(q.get_nowait())\n            except queue.Empty:\n                return events\n            except (EOFError, OSError, ValueError):\n                return events\n\n    # ------------------------------------------------------------------\n    # Plot generation (unchanged from original)\n    # ------------------------------------------------------------------\n\n    def _create_loss_plot(\n        self, progress: TrainingProgress, theme: str = \"light\"\n    ) -> plt.Figure:\n        \"\"\"Create training loss plot with theme-aware styling.\"\"\"\n        plt.close(\"all\")\n\n        LIGHT_STYLE = {\n            \"facecolor\": \"#ffffff\",\n            \"grid_color\": \"#d1d5db\",\n            \"line\": \"#16b88a\",\n            \"text\": \"#1f2937\",\n            \"empty_text\": \"#6b7280\",\n        }\n        DARK_STYLE = {\n            \"facecolor\": \"#292929\",\n            \"grid_color\": \"#404040\",\n            \"line\": \"#4ade80\",\n            \"text\": \"#e5e7eb\",\n            \"empty_text\": \"#9ca3af\",\n        }\n\n        style = LIGHT_STYLE if theme == \"light\" else DARK_STYLE\n\n        fig, ax = plt.subplots(figsize = (PLOT_WIDTH, PLOT_HEIGHT))\n        fig.patch.set_facecolor(style[\"facecolor\"])\n        ax.set_facecolor(style[\"facecolor\"])\n\n        if self.loss_history:\n            steps = self.step_history\n            losses = self.loss_history\n            scatter_color = \"#60a5fa\"\n            ax.scatter(\n                steps,\n                losses,\n                s = 16,\n                alpha = 0.6,\n                color = scatter_color,\n                linewidths = 0,\n                label = 
\"Training Loss (raw)\",\n            )\n\n            MA_WINDOW = 20\n            window = min(MA_WINDOW, len(losses))\n\n            if window >= 2:\n                cumsum = [0.0]\n                for v in losses:\n                    cumsum.append(cumsum[-1] + float(v))\n\n                ma = []\n                for i in range(len(losses)):\n                    start = max(0, i - window + 1)\n                    denom = i - start + 1\n                    ma.append((cumsum[i + 1] - cumsum[start]) / denom)\n\n                ax.plot(\n                    steps,\n                    ma,\n                    color = style[\"line\"],\n                    linewidth = 2.5,\n                    alpha = 0.95,\n                    label = f\"Moving Avg ({ma[-1]:.4f})\",\n                )\n\n                leg = ax.legend(frameon = False, fontsize = 9)\n                for t in leg.get_texts():\n                    t.set_color(style[\"text\"])\n\n            ax.set_xlabel(\"Steps\", fontsize = 10, color = style[\"text\"])\n            ax.set_ylabel(\"Loss\", fontsize = 10, color = style[\"text\"])\n\n            if progress.error:\n                title = f\"Error: {progress.error}\"\n            elif progress.is_completed:\n                title = f\"Training completed! Final loss: {progress.loss:.4f}\"\n            elif progress.status_message:\n                title = progress.status_message\n            elif progress.step > 0:\n                title = f\"Epoch: {progress.epoch} | Step: {progress.step}/{progress.total_steps} | Loss: {progress.loss:.4f}\"\n            else:\n                title = \"Training Loss\"\n\n            ax.set_title(\n                title, fontsize = 11, fontweight = \"bold\", pad = 10, color = style[\"text\"]\n            )\n            ax.grid(True, alpha = 0.4, linestyle = \"--\", color = style[\"grid_color\"])\n            ax.tick_params(colors = style[\"text\"], which = \"both\")\n            ax.spines[\"top\"].set_visible(False)\n            ax.spines[\"right\"].set_visible(False)\n            ax.spines[\"bottom\"].set_color(style[\"text\"])\n            ax.spines[\"left\"].set_color(style[\"text\"])\n        else:\n            display_msg = (\n                progress.status_message\n                if progress.status_message\n                else \"Waiting for training data...\"\n            )\n            ax.text(\n                0.5,\n                0.5,\n                display_msg,\n                ha = \"center\",\n                va = \"center\",\n                fontsize = 16,\n                color = style[\"empty_text\"],\n                transform = ax.transAxes,\n            )\n            ax.set_xticks([])\n            ax.set_yticks([])\n            for spine in ax.spines.values():\n                spine.set_visible(False)\n\n        fig.tight_layout()\n        return fig\n\n    def _transfer_to_inference_backend(self) -> bool:\n        \"\"\"Transfer model to inference backend.\n\n        With subprocess-based training, the model lives in the subprocess\n        and is freed when it exits. Inference must load from the saved\n        checkpoint on disk. 
This is a no-op placeholder.\n        \"\"\"\n        logger.info(\n            \"_transfer_to_inference_backend: subprocess training — \"\n            \"model must be loaded from disk (output_dir=%s)\",\n            self._output_dir,\n        )\n        return False\n\n\n# ========== GLOBAL INSTANCE ==========\n_training_backend = None\n\n\ndef get_training_backend() -> TrainingBackend:\n    \"\"\"Get global training backend instance\"\"\"\n    global _training_backend\n    if _training_backend is None:\n        _training_backend = TrainingBackend()\n    return _training_backend\n"
  },
  {
    "path": "studio/backend/core/training/worker.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTraining subprocess entry point.\n\nEach training job runs in a fresh subprocess (mp.get_context(\"spawn\")).\nThis gives us a clean Python interpreter with no stale module state —\nsolving the transformers version-switching problem completely.\n\nPattern follows core/data_recipe/jobs/worker.py.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport structlog\nfrom loggers import get_logger\nimport os\nimport sys\nimport time\nimport traceback\nfrom pathlib import Path\nfrom typing import Any\n\nlogger = get_logger(__name__)\n\n\ndef _activate_transformers_version(model_name: str) -> None:\n    \"\"\"Activate the correct transformers version BEFORE any ML imports.\n\n    If the model needs transformers 5.x, prepend the pre-installed .venv_t5/\n    directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).\n    \"\"\"\n    # Ensure backend is on path for utils imports\n    backend_path = str(Path(__file__).resolve().parent.parent.parent)\n    if backend_path not in sys.path:\n        sys.path.insert(0, backend_path)\n\n    from utils.transformers_version import (\n        needs_transformers_5,\n        _resolve_base_model,\n        _ensure_venv_t5_exists,\n        _VENV_T5_DIR,\n    )\n\n    resolved = _resolve_base_model(model_name)\n    if needs_transformers_5(resolved):\n        if not _ensure_venv_t5_exists():\n            raise RuntimeError(\n                f\"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}\"\n            )\n        if _VENV_T5_DIR not in sys.path:\n            sys.path.insert(0, _VENV_T5_DIR)\n        logger.info(\"Activated transformers 5.x from %s\", _VENV_T5_DIR)\n        # Propagate to child subprocesses (e.g. GGUF converter)\n        _pp = os.environ.get(\"PYTHONPATH\", \"\")\n        os.environ[\"PYTHONPATH\"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else \"\")\n    else:\n        logger.info(\"Using default transformers (4.57.x) for %s\", model_name)\n\n\ndef run_training_process(\n    *,\n    event_queue: Any,\n    stop_queue: Any,\n    config: dict,\n) -> None:\n    \"\"\"Subprocess entrypoint. Fresh Python — no stale module state.\n\n    Args:\n        event_queue: mp.Queue for sending progress/status/error events to parent.\n        stop_queue: mp.Queue for receiving stop commands from parent.\n        config: Training configuration dict with all parameters.\n    \"\"\"\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    os.environ[\"PYTHONWARNINGS\"] = (\n        \"ignore\"  # Suppress warnings at C-level before imports\n    )\n\n    import warnings\n    from loggers.config import LogConfig\n\n    if os.getenv(\"ENVIRONMENT_TYPE\", \"production\") == \"production\":\n        warnings.filterwarnings(\"ignore\")\n\n    LogConfig.setup_logging(\n        service_name = \"unsloth-studio-training-worker\",\n        env = os.getenv(\"ENVIRONMENT_TYPE\", \"production\"),\n    )\n\n    model_name = config[\"model_name\"]\n\n    # ── 1. 
Activate correct transformers version BEFORE any ML imports ──\n    try:\n        _activate_transformers_version(model_name)\n    except Exception as exc:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to activate transformers version: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    # ── 1a. Auto-enable trust_remote_code for unsloth/* transformers 5.x models ──\n    # Some newer architectures (e.g. NemotronH) have config parsing bugs in\n    # transformers that require trust_remote_code=True as a workaround.\n    # Only auto-enable for unsloth/* prefixed models (trusted source).\n    from utils.transformers_version import needs_transformers_5\n\n    if (\n        needs_transformers_5(model_name)\n        and model_name.lower().startswith(\"unsloth/\")\n        and not config.get(\"trust_remote_code\", False)\n    ):\n        config[\"trust_remote_code\"] = True\n        logger.info(\n            \"Auto-enabled trust_remote_code for unsloth/* transformers 5.x model: %s\",\n            model_name,\n        )\n\n    # ── 1b. Auto-install mamba-ssm for SSM/hybrid models (NemotronH, Falcon-H1) ──\n    _SSM_MODEL_SUBSTRINGS = (\"nemotron_h\", \"nemotron-3-nano\", \"falcon_h1\", \"falcon-h1\")\n    if any(sub in model_name.lower() for sub in _SSM_MODEL_SUBSTRINGS):\n        try:\n            import mamba_ssm  # noqa: F401\n\n            logger.info(\"mamba-ssm already installed\")\n        except ImportError:\n            logger.info(\n                \"SSM model detected — installing mamba-ssm and causal-conv1d (this may take several minutes)...\"\n            )\n            _send_status(\n                event_queue, \"Installing mamba-ssm (first time only, ~7 min)...\"\n            )\n            import subprocess as _sp\n\n            # --no-build-isolation: compile against current torch (no version conflicts)\n            # --no-deps: don't pull in torch/transformers/triton (already installed)\n            for _pkg in [\"causal_conv1d\", \"mamba_ssm\"]:\n                _r = _sp.run(\n                    [\n                        sys.executable,\n                        \"-m\",\n                        \"pip\",\n                        \"install\",\n                        \"--no-build-isolation\",\n                        \"--no-deps\",\n                        \"--no-cache-dir\",\n                        _pkg,\n                    ],\n                    stdout = _sp.PIPE,\n                    stderr = _sp.STDOUT,\n                    text = True,\n                )\n                if _r.returncode != 0:\n                    logger.error(\"Failed to install %s:\\n%s\", _pkg, _r.stdout)\n                else:\n                    logger.info(\"Installed %s successfully\", _pkg)\n            logger.info(\"mamba-ssm installation complete\")\n\n    # ── 1c. Set fork start method so dataset.map() can multiprocess ──\n    # The parent launched us via spawn (clean process), but the compiled\n    # SFTTrainer checks get_start_method() and disables num_proc if not \"fork\".\n    # Linux only: fork is the default start method and is safe here (no CUDA\n    # context exists yet). macOS defaults to spawn since Python 3.8 because\n    # fork is unsafe with macOS frameworks (Metal/MPS, CoreFoundation) --\n    # do NOT override on macOS. 
Windows has no fork at all.\n    if sys.platform == \"linux\":\n        import multiprocessing as _mp\n\n        try:\n            _mp.set_start_method(\"fork\", force = True)\n        except RuntimeError:\n            pass  # Already set\n\n    # ── 1d. On Windows, check Triton availability (must be before import torch) ──\n    if sys.platform == \"win32\":\n        try:\n            import triton  # noqa: F401\n\n            logger.info(\"Triton available — torch.compile enabled\")\n        except ImportError:\n            os.environ[\"TORCHDYNAMO_DISABLE\"] = \"1\"\n            logger.warning(\n                \"Triton not found on Windows — torch.compile disabled. \"\n                'Install for better performance: pip install \"triton-windows<3.7\"'\n            )\n\n    # ── 2. Now import ML libraries (fresh in this clean process) ──\n    try:\n        _send_status(event_queue, \"Importing Unsloth...\")\n\n        backend_path = str(Path(__file__).resolve().parent.parent.parent)\n        if backend_path not in sys.path:\n            sys.path.insert(0, backend_path)\n\n        from core.training.trainer import UnslothTrainer, TrainingProgress\n        from utils.paths import (\n            ensure_dir,\n            resolve_output_dir,\n            resolve_tensorboard_dir,\n            datasets_root,\n        )\n\n        import transformers\n\n        logger.info(\"Subprocess loaded transformers %s\", transformers.__version__)\n    except Exception as exc:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to import ML libraries: {exc}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    # ── 2b. EMBEDDING MODEL FAST-PATH ──\n    # Embedding models use a completely different pipeline (FastSentenceTransformer\n    # + SentenceTransformerTrainer + MultipleNegativesRankingLoss) so we branch\n    # early and handle the entire flow in a self-contained function.\n    if config.get(\"is_embedding\", False):\n        try:\n            _run_embedding_training(event_queue, stop_queue, config)\n        except Exception as exc:\n            event_queue.put(\n                {\n                    \"type\": \"error\",\n                    \"error\": str(exc),\n                    \"stack\": traceback.format_exc(limit = 20),\n                    \"ts\": time.time(),\n                }\n            )\n        return\n\n    # ── 3. 
Create a fresh trainer instance ──\n    trainer = UnslothTrainer()\n\n    # Wire up progress callback → event_queue\n    def _on_progress(progress: TrainingProgress):\n        has_train_loss = progress.step >= 0 and progress.loss > 0\n        has_eval_loss = progress.eval_loss is not None\n        if has_train_loss or has_eval_loss:\n            event_queue.put(\n                {\n                    \"type\": \"progress\",\n                    \"step\": progress.step,\n                    \"epoch\": progress.epoch,\n                    \"loss\": progress.loss,\n                    \"learning_rate\": progress.learning_rate,\n                    \"total_steps\": progress.total_steps,\n                    \"elapsed_seconds\": progress.elapsed_seconds,\n                    \"eta_seconds\": progress.eta_seconds,\n                    \"grad_norm\": progress.grad_norm,\n                    \"num_tokens\": progress.num_tokens,\n                    \"eval_loss\": progress.eval_loss,\n                    \"status_message\": progress.status_message,\n                    \"ts\": time.time(),\n                }\n            )\n        if progress.status_message:\n            _send_status(event_queue, progress.status_message)\n\n    trainer.add_progress_callback(_on_progress)\n\n    # Wire up stop_queue polling to trainer.should_stop\n    import threading\n    import queue as _queue\n\n    def _poll_stop():\n        while True:\n            try:\n                msg = stop_queue.get(timeout = 1.0)\n                if msg and msg.get(\"type\") == \"stop\":\n                    save = msg.get(\"save\", True)\n                    trainer.should_stop = True\n                    trainer.save_on_stop = save\n                    logger.info(\"Stop signal received (save=%s)\", save)\n                    return\n            except _queue.Empty:\n                continue\n            except (EOFError, OSError):\n                return\n\n    stop_thread = threading.Thread(target = _poll_stop, daemon = True)\n    stop_thread.start()\n\n    # ── 4. Execute the training pipeline ──\n    # Order: detect → dataset → model → prepare → train\n    # Dataset processing (including LLM-assisted detection) runs BEFORE model\n    # loading so both never occupy VRAM at the same time.\n    try:\n        hf_token = config.get(\"hf_token\", \"\")\n        hf_token = hf_token if hf_token and hf_token.strip() else None\n\n        # ── 4a. Lightweight detection + tokenizer (no VRAM) ──\n        _send_status(event_queue, \"Detecting model type...\")\n        trainer.pre_detect_and_load_tokenizer(\n            model_name = model_name,\n            max_seq_length = config[\"max_seq_length\"],\n            hf_token = hf_token,\n            is_dataset_image = config.get(\"is_dataset_image\", False),\n            is_dataset_audio = config.get(\"is_dataset_audio\", False),\n            trust_remote_code = config.get(\"trust_remote_code\", False),\n        )\n        if trainer.should_stop:\n            event_queue.put({\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()})\n            return\n\n        # ── 4b. 
Load and format dataset (LLM helper may use VRAM briefly) ──\n        _send_status(event_queue, \"Loading and formatting dataset...\")\n        hf_dataset = config.get(\"hf_dataset\", \"\")\n        dataset_result = trainer.load_and_format_dataset(\n            dataset_source = hf_dataset if hf_dataset and hf_dataset.strip() else None,\n            format_type = config.get(\"format_type\", \"\"),\n            local_datasets = config.get(\"local_datasets\") or None,\n            local_eval_datasets = config.get(\"local_eval_datasets\") or None,\n            custom_format_mapping = config.get(\"custom_format_mapping\"),\n            subset = config.get(\"subset\"),\n            train_split = config.get(\"train_split\", \"train\"),\n            eval_split = config.get(\"eval_split\"),\n            eval_steps = config.get(\"eval_steps\", 0.00),\n            dataset_slice_start = config.get(\"dataset_slice_start\"),\n            dataset_slice_end = config.get(\"dataset_slice_end\"),\n        )\n\n        if isinstance(dataset_result, tuple):\n            dataset, eval_dataset = dataset_result\n        else:\n            dataset = dataset_result\n            eval_dataset = None\n\n        # [DEBUG] Print first sample before model is loaded\n        # dataset is a dict {\"dataset\": <Dataset>, \"detected_format\": ..., ...}\n        # or a raw Dataset for audio paths\n        # try:\n        #     ds = dataset[\"dataset\"] if isinstance(dataset, dict) else dataset\n        #     print(\n        #         f\"\\n[DEBUG] Dataset loaded BEFORE model. type={type(ds).__name__}, len={len(ds)}\",\n        #         flush = True,\n        #     )\n        #     print(f\"[DEBUG] Columns: {ds.column_names}\", flush = True)\n        #     sample = ds[0]\n        #     preview = {k: str(v)[:300] for k, v in sample.items()}\n        #     print(f\"[DEBUG] First sample: {preview}\\n\", flush = True)\n        # except Exception as e:\n        #     print(\n        #         f\"[DEBUG] Could not preview first sample: {type(e).__name__}: {e}\",\n        #         flush = True,\n        #     )\n\n        # Disable eval if eval_steps <= 0\n        eval_steps = config.get(\"eval_steps\", 0.00)\n        if eval_steps is not None and float(eval_steps) <= 0:\n            eval_dataset = None\n\n        # Tell the parent process that eval is configured so the frontend\n        # shows \"Waiting for first evaluation step...\" instead of \"not configured\"\n        if eval_dataset is not None:\n            event_queue.put(\n                {\n                    \"type\": \"eval_configured\",\n                    \"ts\": time.time(),\n                }\n            )\n\n        if dataset is None or trainer.should_stop:\n            if trainer.should_stop:\n                event_queue.put(\n                    {\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()}\n                )\n            else:\n                event_queue.put(\n                    {\n                        \"type\": \"error\",\n                        \"error\": trainer.training_progress.error\n                        or \"Failed to load dataset\",\n                        \"stack\": \"\",\n                        \"ts\": time.time(),\n                    }\n                )\n            return\n\n        # ── Start tqdm monitor early so it captures download + tokenization bars ──\n        import threading as _th\n\n        _tqdm_stop = _th.Event()\n\n        def _monitor_tqdm():\n            from tqdm.auto import tqdm as 
_tqdm_cls\n\n            while not _tqdm_stop.is_set():\n                for bar in list(getattr(_tqdm_cls, \"_instances\", set())):\n                    try:\n                        n, total = bar.n or 0, bar.total or 0\n                        desc = getattr(bar, \"desc\", \"\") or \"\"\n                        if total > 0 and n > 0 and desc:\n                            pct = min(int(n * 100 / total), 100)\n                            _send_status(\n                                event_queue, f\"{desc.strip()} {pct}% ({n:,}/{total:,})\"\n                            )\n                    except (AttributeError, ReferenceError):\n                        pass\n                _tqdm_stop.wait(3)\n\n        _tqdm_thread = _th.Thread(target = _monitor_tqdm, daemon = True)\n        _tqdm_thread.start()\n\n        training_type = config.get(\"training_type\", \"LoRA/QLoRA\")\n        use_lora = training_type == \"LoRA/QLoRA\"\n\n        # ── 4c. Load training model (uses VRAM — dataset already formatted) ──\n        _send_status(event_queue, \"Loading model...\")\n        success = trainer.load_model(\n            model_name = model_name,\n            max_seq_length = config[\"max_seq_length\"],\n            load_in_4bit = config[\"load_in_4bit\"],\n            full_finetuning = not use_lora,\n            hf_token = hf_token,\n            is_dataset_image = config.get(\"is_dataset_image\", False),\n            is_dataset_audio = config.get(\"is_dataset_audio\", False),\n            trust_remote_code = config.get(\"trust_remote_code\", False),\n        )\n        if not success or trainer.should_stop:\n            if trainer.should_stop:\n                event_queue.put(\n                    {\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()}\n                )\n            else:\n                error_msg = trainer.training_progress.error or \"Failed to load model\"\n                event_queue.put(\n                    {\n                        \"type\": \"error\",\n                        \"error\": error_msg,\n                        \"stack\": \"\",\n                        \"ts\": time.time(),\n                    }\n                )\n            return\n\n        # ── 4d. 
Prepare model (LoRA or full finetuning) ──\n        if use_lora:\n            _send_status(event_queue, \"Configuring LoRA adapters...\")\n            success = trainer.prepare_model_for_training(\n                use_lora = True,\n                finetune_vision_layers = config.get(\"finetune_vision_layers\", True),\n                finetune_language_layers = config.get(\"finetune_language_layers\", True),\n                finetune_attention_modules = config.get(\n                    \"finetune_attention_modules\", True\n                ),\n                finetune_mlp_modules = config.get(\"finetune_mlp_modules\", True),\n                target_modules = config.get(\"target_modules\"),\n                lora_r = config.get(\"lora_r\", 16),\n                lora_alpha = config.get(\"lora_alpha\", 16),\n                lora_dropout = config.get(\"lora_dropout\", 0.0),\n                use_gradient_checkpointing = config.get(\n                    \"gradient_checkpointing\", \"unsloth\"\n                ),\n                use_rslora = config.get(\"use_rslora\", False),\n                use_loftq = config.get(\"use_loftq\", False),\n            )\n        else:\n            _send_status(event_queue, \"Preparing model for full finetuning...\")\n            success = trainer.prepare_model_for_training(use_lora = False)\n\n        if not success or trainer.should_stop:\n            if trainer.should_stop:\n                event_queue.put(\n                    {\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()}\n                )\n            else:\n                event_queue.put(\n                    {\n                        \"type\": \"error\",\n                        \"error\": trainer.training_progress.error\n                        or \"Failed to prepare model\",\n                        \"stack\": \"\",\n                        \"ts\": time.time(),\n                    }\n                )\n            return\n\n        # Convert learning rate\n        try:\n            lr_value = float(config.get(\"learning_rate\", \"2e-4\"))\n        except ValueError:\n            event_queue.put(\n                {\n                    \"type\": \"error\",\n                    \"error\": f\"Invalid learning rate: {config.get('learning_rate')}\",\n                    \"stack\": \"\",\n                    \"ts\": time.time(),\n                }\n            )\n            return\n\n        # Generate output dir\n        output_dir = config.get(\"output_dir\")\n        if not output_dir:\n            output_dir = f\"{model_name.replace('/', '_')}_{int(time.time())}\"\n        output_dir = str(resolve_output_dir(output_dir))\n        ensure_dir(Path(output_dir))\n\n        tensorboard_dir = config.get(\"tensorboard_dir\")\n        if config.get(\"enable_tensorboard\", False):\n            tensorboard_dir = str(resolve_tensorboard_dir(tensorboard_dir))\n            ensure_dir(Path(tensorboard_dir))\n\n        # Start training (directly — no inner thread, we ARE the subprocess)\n        dataset_display = (\n            config.get(\"hf_dataset\", \"\") or config.get(\"uploaded_file\", \"\") or \"\"\n        )\n        _send_status(\n            event_queue,\n            f'Training \"{model_name}\"'\n            + (f\"\\nDataset = {dataset_display}\" if dataset_display else \"\"),\n        )\n        max_steps = config.get(\"max_steps\", 0)\n        save_steps = config.get(\"save_steps\", 0)\n\n        trainer._train_worker(\n            dataset,\n            output_dir = output_dir,\n       
     num_epochs = config.get(\"num_epochs\", 3),\n            learning_rate = lr_value,\n            batch_size = config.get(\"batch_size\", 2),\n            gradient_accumulation_steps = config.get(\"gradient_accumulation_steps\", 4),\n            warmup_steps = config.get(\"warmup_steps\"),\n            warmup_ratio = config.get(\"warmup_ratio\"),\n            max_steps = max_steps if max_steps and max_steps > 0 else 0,\n            save_steps = save_steps if save_steps and save_steps > 0 else 0,\n            weight_decay = config.get(\"weight_decay\", 0.01),\n            random_seed = config.get(\"random_seed\", 3407),\n            packing = config.get(\"packing\", False),\n            train_on_completions = config.get(\"train_on_completions\", False),\n            enable_wandb = config.get(\"enable_wandb\", False),\n            wandb_project = config.get(\"wandb_project\", \"unsloth-training\"),\n            wandb_token = config.get(\"wandb_token\"),\n            enable_tensorboard = config.get(\"enable_tensorboard\", False),\n            tensorboard_dir = tensorboard_dir,\n            eval_dataset = eval_dataset,\n            eval_steps = eval_steps,\n            max_seq_length = config.get(\"max_seq_length\", 2048),\n            optim = config.get(\"optim\", \"adamw_8bit\"),\n            lr_scheduler_type = config.get(\"lr_scheduler_type\", \"linear\"),\n        )\n\n        _tqdm_stop.set()\n\n        # Check final state\n        progress = trainer.get_training_progress()\n        if progress.error:\n            event_queue.put(\n                {\n                    \"type\": \"error\",\n                    \"error\": progress.error,\n                    \"stack\": \"\",\n                    \"ts\": time.time(),\n                }\n            )\n        else:\n            event_queue.put(\n                {\n                    \"type\": \"complete\",\n                    \"output_dir\": output_dir,\n                    \"status_message\": progress.status_message or \"Training completed\",\n                    \"ts\": time.time(),\n                }\n            )\n\n    except Exception as exc:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": str(exc),\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n\n\ndef _send_status(event_queue: Any, message: str) -> None:\n    \"\"\"Send a status update to the parent process.\"\"\"\n    event_queue.put(\n        {\n            \"type\": \"status\",\n            \"message\": message,\n            \"ts\": time.time(),\n        }\n    )\n\n\ndef _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) -> None:\n    \"\"\"Self-contained embedding model training pipeline.\n\n    Uses FastSentenceTransformer + SentenceTransformerTrainer +\n    MultipleNegativesRankingLoss — completely separate from the\n    LLM/VLM/audio paths in UnslothTrainer.\n\n    Mirrors the pattern from the reference embedding notebooks:\n      All_MiniLM_L6_v2.py, BGE_M3.py, EmbeddingGemma_300M.py,\n      ModernBert.py, Qwen3_Embedding_0_6B.py\n    \"\"\"\n    import math\n    import queue as _queue\n    import threading\n\n    model_name = config[\"model_name\"]\n    training_start_time = time.time()\n\n    # ── 1. 
Import embedding-specific libraries ──\n    _send_status(event_queue, \"Importing embedding libraries...\")\n    try:\n        from unsloth import FastSentenceTransformer, is_bfloat16_supported\n        from sentence_transformers import (\n            SentenceTransformerTrainer,\n            SentenceTransformerTrainingArguments,\n        )\n        from sentence_transformers.losses import MultipleNegativesRankingLoss\n        from sentence_transformers.training_args import BatchSamplers\n        from datasets import load_dataset, Dataset\n        from transformers import TrainerCallback\n        from utils.paths import datasets_root, resolve_output_dir\n    except ImportError as e:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to import embedding libraries: {e}. \"\n                \"Ensure 'sentence_transformers' and 'unsloth' are installed.\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    # ── Stop signal handling ──\n    _should_stop = False\n    _save_on_stop = True\n\n    def _poll_stop():\n        nonlocal _should_stop, _save_on_stop\n        while True:\n            try:\n                msg = stop_queue.get(timeout = 1.0)\n                if msg and msg.get(\"type\") == \"stop\":\n                    _save_on_stop = msg.get(\"save\", True)\n                    _should_stop = True\n                    logger.info(\n                        \"Embedding training: stop signal received (save=%s)\",\n                        _save_on_stop,\n                    )\n                    return\n            except _queue.Empty:\n                continue\n            except (EOFError, OSError):\n                return\n\n    stop_thread = threading.Thread(target = _poll_stop, daemon = True)\n    stop_thread.start()\n\n    # ── 2. Load model ──\n    _send_status(event_queue, \"Loading embedding model...\")\n    try:\n        hf_token = config.get(\"hf_token\", \"\")\n        hf_token = hf_token if hf_token and hf_token.strip() else None\n        max_seq_length = config.get(\"max_seq_length\", 512)\n        training_type = config.get(\"training_type\", \"LoRA/QLoRA\")\n        use_lora = training_type == \"LoRA/QLoRA\"\n\n        model = FastSentenceTransformer.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            full_finetuning = not use_lora,\n            token = hf_token,\n        )\n    except Exception as e:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to load embedding model '{model_name}': {e}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    if _should_stop:\n        event_queue.put({\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()})\n        return\n\n    # ── 3. 
Apply LoRA ──\n    if use_lora:\n        _send_status(event_queue, \"Configuring LoRA adapters (FEATURE_EXTRACTION)...\")\n        try:\n            gradient_checkpointing = config.get(\"gradient_checkpointing\", False)\n            # Normalize: \"none\" or empty → False\n            if gradient_checkpointing in (\"none\", \"\", None):\n                gradient_checkpointing = False\n\n            model = FastSentenceTransformer.get_peft_model(\n                model,\n                r = config.get(\"lora_r\", 32),\n                target_modules = config.get(\"target_modules\")\n                or [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n                lora_alpha = config.get(\"lora_alpha\", 64),\n                lora_dropout = config.get(\"lora_dropout\", 0.0),\n                bias = \"none\",\n                use_gradient_checkpointing = gradient_checkpointing,\n                random_state = config.get(\"random_seed\", 3407),\n                use_rslora = config.get(\"use_rslora\", False),\n                loftq_config = {\"loftq_bits\": 4, \"loftq_iter\": 1}\n                if config.get(\"use_loftq\")\n                else None,\n                task_type = \"FEATURE_EXTRACTION\",\n            )\n        except Exception as e:\n            event_queue.put(\n                {\n                    \"type\": \"error\",\n                    \"error\": f\"Failed to configure LoRA for embedding model: {e}\",\n                    \"stack\": traceback.format_exc(limit = 20),\n                    \"ts\": time.time(),\n                }\n            )\n            return\n\n    if _should_stop:\n        event_queue.put({\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()})\n        return\n\n    # ── 4. Load dataset ──\n    _send_status(event_queue, \"Loading dataset...\")\n    try:\n        hf_dataset = config.get(\"hf_dataset\", \"\")\n        local_datasets = config.get(\"local_datasets\") or []\n        subset = config.get(\"subset\") or None\n        train_split = config.get(\"train_split\", \"train\") or \"train\"\n\n        if hf_dataset and hf_dataset.strip():\n            hf_token = config.get(\"hf_token\", \"\")\n            hf_token = hf_token if hf_token and hf_token.strip() else None\n            dataset = load_dataset(\n                hf_dataset.strip(),\n                subset,\n                split = train_split,\n                token = hf_token,\n            )\n        elif local_datasets:\n            # Load from local file(s) — mirrors the non-embedding pipeline's\n            # directory handling so recipe outputs (parquet-files/) work.\n            all_files: list[str] = []\n            for dataset_file in local_datasets:\n                file_path = (\n                    dataset_file\n                    if os.path.isabs(dataset_file)\n                    else os.path.join(\n                        str(datasets_root()),\n                        dataset_file,\n                    )\n                )\n                if os.path.isdir(file_path):\n                    file_path_obj = Path(file_path)\n                    parquet_dir = (\n                        file_path_obj / \"parquet-files\"\n                        if (file_path_obj / \"parquet-files\").exists()\n                        else file_path_obj\n                    )\n                    parquet_files = sorted(parquet_dir.glob(\"*.parquet\"))\n                    if parquet_files:\n                        all_files.extend(str(p) for p in parquet_files)\n                        
continue\n                    candidates: list[Path] = []\n                    for ext in (\".json\", \".jsonl\", \".csv\", \".parquet\"):\n                        candidates.extend(sorted(file_path_obj.glob(f\"*{ext}\")))\n                    if candidates:\n                        all_files.extend(str(c) for c in candidates)\n                        continue\n                    raise ValueError(\n                        f\"No supported data files in directory: {file_path_obj}\"\n                    )\n                else:\n                    all_files.append(file_path)\n\n            if all_files:\n                first_ext = Path(all_files[0]).suffix.lower()\n                if first_ext in (\".json\", \".jsonl\"):\n                    loader = \"json\"\n                elif first_ext == \".csv\":\n                    loader = \"csv\"\n                elif first_ext == \".parquet\":\n                    loader = \"parquet\"\n                else:\n                    raise ValueError(\n                        f\"Unsupported local dataset format: {all_files[0]}\"\n                    )\n                dataset = load_dataset(loader, data_files = all_files, split = \"train\")\n        else:\n            event_queue.put(\n                {\n                    \"type\": \"error\",\n                    \"error\": \"No dataset specified for embedding training.\",\n                    \"stack\": \"\",\n                    \"ts\": time.time(),\n                }\n            )\n            return\n\n        # Apply dataset slicing if specified\n        slice_start = config.get(\"dataset_slice_start\")\n        slice_end = config.get(\"dataset_slice_end\")\n        if slice_start is not None or slice_end is not None:\n            start = slice_start if slice_start is not None else 0\n            end = slice_end if slice_end is not None else len(dataset)\n            dataset = dataset.select(range(start, min(end + 1, len(dataset))))\n\n        logger.info(f\"Embedding dataset loaded: {len(dataset)} samples\")\n    except Exception as e:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Failed to load dataset: {e}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    if _should_stop:\n        event_queue.put({\"type\": \"complete\", \"output_dir\": None, \"ts\": time.time()})\n        return\n\n    # ── 5. Create loss function ──\n    loss = MultipleNegativesRankingLoss(model)\n\n    # ── 6. 
Build training arguments ──\n    _send_status(event_queue, \"Configuring training...\")\n    try:\n        lr_value = float(config.get(\"learning_rate\", \"2e-4\"))\n    except ValueError:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Invalid learning rate: {config.get('learning_rate')}\",\n                \"stack\": \"\",\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    output_dir = config.get(\"output_dir\")\n    if not output_dir:\n        output_dir = str(\n            resolve_output_dir(f\"{model_name.replace('/', '_')}_{int(time.time())}\")\n        )\n\n    num_epochs = config.get(\"num_epochs\", 2)\n    batch_size = config.get(\"batch_size\", 256)\n    gradient_accumulation_steps = config.get(\"gradient_accumulation_steps\", 1)\n    max_steps_val = config.get(\"max_steps\", 0)\n    save_steps_val = config.get(\"save_steps\", 0)\n    warmup_ratio = config.get(\"warmup_ratio\", 0.03)\n    warmup_steps_val = config.get(\"warmup_steps\")\n    log_frequency = config.get(\"log_frequency\", 50)\n\n    # Build args dict\n    training_args_kwargs = {\n        \"output_dir\": output_dir,\n        \"per_device_train_batch_size\": batch_size,\n        \"gradient_accumulation_steps\": gradient_accumulation_steps,\n        \"learning_rate\": lr_value,\n        \"fp16\": not is_bfloat16_supported(),\n        \"bf16\": is_bfloat16_supported(),\n        \"logging_steps\": 1,\n        \"report_to\": [\"wandb\"] if config.get(\"enable_wandb\") else \"none\",\n        \"lr_scheduler_type\": config.get(\"lr_scheduler_type\", \"linear\"),\n        \"batch_sampler\": BatchSamplers.NO_DUPLICATES,\n        \"optim\": config.get(\"optim\", \"adamw_8bit\"),\n        \"weight_decay\": config.get(\"weight_decay\", 0.01),\n        \"seed\": config.get(\"random_seed\", 3407),\n    }\n\n    # max_steps vs epochs\n    if max_steps_val and max_steps_val > 0:\n        training_args_kwargs[\"max_steps\"] = max_steps_val\n    else:\n        training_args_kwargs[\"num_train_epochs\"] = num_epochs if num_epochs > 0 else 2\n\n    # warmup: prefer warmup_ratio (standard for embedding scripts), fallback to steps\n    if warmup_ratio is not None and warmup_ratio > 0:\n        training_args_kwargs[\"warmup_ratio\"] = warmup_ratio\n    elif warmup_steps_val is not None and warmup_steps_val > 0:\n        training_args_kwargs[\"warmup_steps\"] = warmup_steps_val\n\n    # save_steps\n    if save_steps_val and save_steps_val > 0:\n        training_args_kwargs[\"save_steps\"] = save_steps_val\n        training_args_kwargs[\"save_strategy\"] = \"steps\"\n\n    args = SentenceTransformerTrainingArguments(**training_args_kwargs)\n\n    # ── 7. Calculate total steps for progress tracking ──\n    if max_steps_val and max_steps_val > 0:\n        total_steps = max_steps_val\n    else:\n        effective_epochs = num_epochs if num_epochs > 0 else 2\n        len_dataloader = math.ceil(len(dataset) / batch_size)\n        steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)\n        total_steps = steps_per_epoch * effective_epochs\n\n    # ── 8. 
Create progress callback ──\n    class _EmbeddingProgressCallback(TrainerCallback):\n        \"\"\"Sends training progress events to the parent process via event_queue.\"\"\"\n\n        def on_log(self, args, state, control, logs = None, **kwargs):\n            if not logs:\n                return\n            loss_value = logs.get(\"loss\", logs.get(\"train_loss\", 0.0))\n            current_step = state.global_step\n\n            elapsed = time.time() - training_start_time\n            eta = None\n            if current_step > 0 and total_steps > 0:\n                remaining = total_steps - current_step\n                if remaining > 0:\n                    eta = (elapsed / current_step) * remaining\n\n            event_queue.put(\n                {\n                    \"type\": \"progress\",\n                    \"step\": current_step,\n                    \"epoch\": round(state.epoch, 2) if state.epoch else 0,\n                    \"loss\": loss_value,\n                    \"learning_rate\": logs.get(\"learning_rate\", 0.0),\n                    \"total_steps\": total_steps,\n                    \"elapsed_seconds\": elapsed,\n                    \"eta_seconds\": eta,\n                    \"grad_norm\": logs.get(\"grad_norm\"),\n                    \"num_tokens\": getattr(state, \"num_input_tokens_seen\", None),\n                    \"eval_loss\": logs.get(\"eval_loss\"),\n                    \"status_message\": \"\",\n                    \"ts\": time.time(),\n                }\n            )\n\n        def on_step_end(self, args, state, control, **kwargs):\n            if _should_stop:\n                logger.info(\"Embedding training: stop at step %d\", state.global_step)\n                control.should_training_stop = True\n                return control\n\n    # ── 9. Create trainer and train ──\n    _send_status(event_queue, \"Starting embedding training...\")\n    try:\n        trainer = SentenceTransformerTrainer(\n            model = model,\n            train_dataset = dataset,\n            loss = loss,\n            args = args,\n            callbacks = [_EmbeddingProgressCallback()],\n        )\n\n        trainer.train()\n    except Exception as e:\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Embedding training failed: {e}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    # ── 10. Save model ──\n    if _should_stop and not _save_on_stop:\n        event_queue.put(\n            {\n                \"type\": \"complete\",\n                \"output_dir\": None,\n                \"status_message\": \"Training cancelled\",\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    _send_status(event_queue, \"Saving model...\")\n    try:\n        model.save_pretrained(output_dir)\n        model.tokenizer.save_pretrained(output_dir)\n        logger.info(\"Embedding model saved to %s\", output_dir)\n    except Exception as e:\n        logger.error(\"Failed to save embedding model: %s\", e)\n        event_queue.put(\n            {\n                \"type\": \"error\",\n                \"error\": f\"Training completed but failed to save: {e}\",\n                \"stack\": traceback.format_exc(limit = 20),\n                \"ts\": time.time(),\n            }\n        )\n        return\n\n    # ── 11. 
Done ──\n    event_queue.put(\n        {\n            \"type\": \"complete\",\n            \"output_dir\": output_dir,\n            \"status_message\": \"Embedding training completed\",\n            \"ts\": time.time(),\n        }\n    )\n"
  },
  {
    "path": "studio/backend/loggers/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/loggers/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom .handlers import get_logger\n\n__all__ = [\"get_logger\"]\n"
  },
  {
    "path": "studio/backend/loggers/config.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Logging configuration for structured logging with structlog.\n\nThis module provides centralized logging configuration with environment-specific\nformats and processors. Supports both development and production environments\nwith consistent structured logging.\n\nKey Features:\n- Environment-specific formatting (JSON for production, console for development)\n- Timestamp standardization (ISO format)\n- Context variable integration\n- Log level filtering\n- Logger caching for performance\n\"\"\"\n\nimport logging\nimport os\nimport sys\nfrom typing import Optional\n\nimport structlog\n\n\nclass LogConfig:\n    \"\"\"Structured logging configuration for the application.\n\n    Provides static method to configure structlog with environment-specific\n    formatting and processors for consistent structured logging.\n    \"\"\"\n\n    @staticmethod\n    def setup_logging(\n        service_name: str = \"unsloth-studio-backend\", env: Optional[str] = None\n    ) -> structlog.BoundLogger:\n        \"\"\"Configure structured logging for the application.\n        Args:\n            service_name: Name of the service for logging identification\n            env: Environment (development/production), affects logging format\n        \"\"\"\n        # Determine log level from environment\n        log_level_name = os.getenv(\"LOG_LEVEL\", \"INFO\").upper()\n        # Fallback to INFO if an invalid level is provided\n        log_level = getattr(logging, log_level_name, logging.INFO)\n\n        structlog.configure(\n            processors = [\n                # Reorder processors to control field order\n                structlog.processors.TimeStamper(fmt = \"iso\"),  # timestamp first\n                structlog.processors.add_log_level,  # level second\n                structlog.contextvars.merge_contextvars,\n                # Custom processor to flatten the extra field\n                lambda logger, method_name, event_dict: {\n                    \"timestamp\": event_dict.get(\"timestamp\"),\n                    \"level\": event_dict.get(\"level\"),\n                    \"event\": event_dict.get(\"event\"),\n                    **(event_dict.get(\"extra\", {})),  # Flatten extra into main dict\n                    **{\n                        k: v\n                        for k, v in event_dict.items()\n                        if k not in [\"timestamp\", \"level\", \"event\", \"extra\"]\n                    },\n                },\n                (\n                    structlog.processors.JSONRenderer(sort_keys = False)  # Preserve order\n                    if env == \"production\"\n                    else structlog.dev.ConsoleRenderer()\n                ),\n            ],\n            wrapper_class = structlog.make_filtering_bound_logger(log_level),\n            logger_factory = structlog.PrintLoggerFactory(file = sys.stdout),\n            cache_logger_on_first_use = True,\n        )\n\n        return structlog.get_logger(service_name)\n"
  },
  {
    "path": "studio/backend/loggers/handlers.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Logging handlers and middleware for structured logging.\n\nThis module provides FastAPI middleware and structlog processors for:\n- Request/response logging with timing\n- Sensitive data filtering in logs\n- Structured logging configuration\n- Error handling with detailed context\n\nKey Components:\n- LoggingMiddleware: FastAPI middleware for request/response logging\n- filter_sensitive_data: Structlog processor for data sanitization\n- get_logger: Factory function for structured loggers\n\"\"\"\n\nimport time\nfrom typing import Callable\n\nimport structlog\nfrom fastapi import Request, Response\nfrom starlette.middleware.base import BaseHTTPMiddleware\n\nlogger = structlog.get_logger(__name__)\n\n\nclass LoggingMiddleware(BaseHTTPMiddleware):\n    async def dispatch(self, request: Request, call_next: Callable) -> Response:\n        start_time = time.time()\n\n        try:\n            response = await call_next(request)\n\n            # Log response\n            process_time = (time.time() - start_time) * 1000\n\n            EXCLUDED_PATHS = {\n                \"/api/train/status\",\n                \"/api/train/metrics\",\n                \"/api/train/hardware\",\n                \"/api/system\",\n            }\n            is_excluded = (\n                request.url.path in EXCLUDED_PATHS\n                or request.url.path.startswith(\"/assets/\")\n                or request.url.path.endswith(\n                    (\".png\", \".jpg\", \".jpeg\", \".ico\", \".woff\", \".woff2\", \".ttf\")\n                )\n            )\n\n            if not is_excluded:\n                logger.info(\n                    \"request_completed\",\n                    method = request.method,\n                    path = request.url.path,\n                    status_code = response.status_code,\n                    process_time_ms = round(process_time, 2),\n                )\n\n            return response\n\n        except Exception as e:\n            logger.error(\n                \"request_failed\",\n                path = request.url.path,\n                method = request.method,\n                error = str(e),\n                exc_info = True,\n            )\n            raise\n\n\ndef filter_sensitive_data(logger, method_name, event_dict):\n    \"\"\"Structlog processor to filter out base64 data from logs.\"\"\"\n\n    def filter_value(value):\n        if (\n            isinstance(value, str)\n            and len(value) > 100\n            and (\",\" in value or \"/\" in value)\n        ):\n            # Likely base64 data, truncate it\n            return value[:20] + \"...\"\n        elif isinstance(value, dict):\n            return {k: filter_value(v) for k, v in value.items()}\n        elif isinstance(value, list):\n            return [filter_value(item) for item in value]\n        return value\n\n    return {k: filter_value(v) for k, v in event_dict.items()}\n\n\ndef get_logger(name: str) -> structlog.BoundLogger:\n    \"\"\"Get a logger instance for a specific module.\n    Args:\n        name: Usually __name__ of the module\n    Returns:\n        A bound structured logger\n    \"\"\"\n    return structlog.get_logger(name)\n"
  },
  {
    "path": "studio/backend/main.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nMain FastAPI application for Unsloth UI Backend\n\"\"\"\n\nimport os\n\n# Suppress annoying C-level dependency warnings globally\nos.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n\nimport shutil\nimport sys\nimport warnings\nfrom contextlib import asynccontextmanager\n\n# Suppress annoying dependency warnings in production\nif os.getenv(\"ENVIRONMENT_TYPE\", \"production\") == \"production\":\n    warnings.filterwarnings(\"ignore\")\n    # Alternatively, you can be more specific:\n    # warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n    # warnings.filterwarnings(\"ignore\", module=\"triton.*\")\n\nfrom fastapi import FastAPI\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.staticfiles import StaticFiles\nfrom fastapi.responses import FileResponse, HTMLResponse, Response\nfrom pathlib import Path\nfrom datetime import datetime\n\n# Import routers\nfrom routes import (\n    auth_router,\n    data_recipe_router,\n    datasets_router,\n    export_router,\n    inference_router,\n    models_router,\n    training_router,\n)\nfrom auth import storage\nfrom utils.hardware import detect_hardware, get_device, DeviceType\nimport utils.hardware.hardware as _hw_module\n\nfrom utils.cache_cleanup import clear_unsloth_compiled_cache\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    \"\"\"Startup: detect hardware, seed default admin if needed. Shutdown: clean up compiled cache.\"\"\"\n    # Clean up any stale compiled cache from previous runs\n    clear_unsloth_compiled_cache()\n\n    # Remove stale .venv_overlay from previous versions — no longer used.\n    # Version switching now uses .venv_t5/ (pre-installed by setup.sh).\n    overlay_dir = Path(__file__).resolve().parent.parent.parent / \".venv_overlay\"\n    if overlay_dir.is_dir():\n        shutil.rmtree(overlay_dir, ignore_errors = True)\n\n    # Detect hardware first — sets DEVICE global used everywhere\n    detect_hardware()\n\n    # Pre-cache the helper GGUF model for LLM-assisted dataset detection.\n    # Runs in a background thread so it doesn't block server startup.\n    import threading\n\n    def _precache():\n        try:\n            from utils.datasets.llm_assist import precache_helper_gguf\n\n            precache_helper_gguf()\n        except Exception:\n            pass  # non-critical\n\n    threading.Thread(target = _precache, daemon = True).start()\n\n    if storage.ensure_default_admin():\n        bootstrap_pw = storage.get_bootstrap_password()\n        app.state.bootstrap_password = bootstrap_pw\n        print(\"\\n\" + \"=\" * 60)\n        print(\"DEFAULT ADMIN ACCOUNT CREATED\")\n        print(\n            \"Sign in with the seeded credentials and change the password immediately:\\n\"\n        )\n        print(f\"    username: {storage.DEFAULT_ADMIN_USERNAME}\")\n        print(f\"    password: {bootstrap_pw}\\n\")\n        print(\"=\" * 60 + \"\\n\")\n    else:\n        app.state.bootstrap_password = storage.get_bootstrap_password()\n    yield\n    # Cleanup\n    _hw_module.DEVICE = None\n    clear_unsloth_compiled_cache()\n\n\n# Create FastAPI app\napp = FastAPI(\n    title = \"Unsloth UI Backend\",\n    version = \"1.0.0\",\n    description = \"Backend API for Unsloth UI - Training and Model Management\",\n    lifespan = lifespan,\n)\n\n# Initialize structured logging\nfrom loggers.config import LogConfig\nfrom 
loggers.handlers import LoggingMiddleware\n\nlogger = LogConfig.setup_logging(\n    service_name = \"unsloth-studio-backend\",\n    env = os.getenv(\"ENVIRONMENT_TYPE\", \"production\"),\n)\n\napp.add_middleware(LoggingMiddleware)\n\n# CORS middleware\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins = [\"*\"],  # In production, specify allowed origins\n    allow_credentials = True,\n    allow_methods = [\"*\"],\n    allow_headers = [\"*\"],\n)\n\n# ============ Register API Routes ============\n\n# Register routers\napp.include_router(auth_router, prefix = \"/api/auth\", tags = [\"auth\"])\napp.include_router(training_router, prefix = \"/api/train\", tags = [\"training\"])\napp.include_router(models_router, prefix = \"/api/models\", tags = [\"models\"])\napp.include_router(inference_router, prefix = \"/api/inference\", tags = [\"inference\"])\n\n# OpenAI-compatible endpoints: mount the same inference router at /v1\n# so external tools (Open WebUI, SillyTavern, etc.) can use the\n# standard /v1/chat/completions path.\napp.include_router(inference_router, prefix = \"/v1\", tags = [\"openai-compat\"])\napp.include_router(datasets_router, prefix = \"/api/datasets\", tags = [\"datasets\"])\napp.include_router(data_recipe_router, prefix = \"/api/data-recipe\", tags = [\"data-recipe\"])\napp.include_router(export_router, prefix = \"/api/export\", tags = [\"export\"])\n\n\n# ============ Health and System Endpoints ============\n\n\n@app.get(\"/api/health\")\nasync def health_check():\n    \"\"\"Health check endpoint\"\"\"\n    platform_map = {\"darwin\": \"mac\", \"win32\": \"windows\", \"linux\": \"linux\"}\n    device_type = platform_map.get(sys.platform, sys.platform)\n\n    return {\n        \"status\": \"healthy\",\n        \"timestamp\": datetime.now().isoformat(),\n        \"service\": \"Unsloth UI Backend\",\n        \"device_type\": device_type,\n        \"chat_only\": _hw_module.CHAT_ONLY,\n    }\n\n\n@app.get(\"/api/system\")\nasync def get_system_info():\n    \"\"\"Get system information\"\"\"\n    import platform\n    import subprocess\n    import psutil\n    from utils.hardware import get_device, get_gpu_memory_info, DeviceType\n\n    # GPU Info — query nvidia-smi for physical GPUs, filtered by\n    # CUDA_VISIBLE_DEVICES when set (the frontend uses this for GGUF\n    # fit estimation and llama-server respects CVD too).\n    import os\n\n    gpu_info: dict = {\"available\": False, \"devices\": []}\n\n    device = get_device()\n    if device == DeviceType.CUDA:\n        # Parse CUDA_VISIBLE_DEVICES allowlist\n        allowed_indices = None\n        cvd = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n        if cvd is not None and cvd.strip():\n            try:\n                allowed_indices = set(int(x.strip()) for x in cvd.split(\",\"))\n            except ValueError:\n                pass  # Non-numeric (e.g. 
GPU-uuid), show all\n\n        try:\n            result = subprocess.run(\n                [\n                    \"nvidia-smi\",\n                    \"--query-gpu=index,name,memory.total\",\n                    \"--format=csv,noheader,nounits\",\n                ],\n                capture_output = True,\n                text = True,\n                timeout = 10,\n            )\n            if result.returncode == 0:\n                for line in result.stdout.strip().splitlines():\n                    parts = [p.strip() for p in line.split(\",\")]\n                    if len(parts) == 3:\n                        idx = int(parts[0])\n                        if allowed_indices is not None and idx not in allowed_indices:\n                            continue\n                        gpu_info[\"devices\"].append(\n                            {\n                                \"index\": idx,\n                                \"name\": parts[1],\n                                \"memory_total_gb\": round(int(parts[2]) / 1024, 2),\n                            }\n                        )\n                gpu_info[\"available\"] = len(gpu_info[\"devices\"]) > 0\n        except Exception:\n            pass\n\n    # Fallback to torch-based single-GPU detection\n    if not gpu_info[\"available\"]:\n        mem_info = get_gpu_memory_info()\n        if mem_info.get(\"available\"):\n            gpu_info[\"available\"] = True\n            gpu_info[\"devices\"].append(\n                {\n                    \"index\": mem_info.get(\"device\", 0),\n                    \"name\": mem_info.get(\"device_name\", \"Unknown\"),\n                    \"memory_total_gb\": round(mem_info.get(\"total_gb\", 0), 2),\n                }\n            )\n\n    # CPU & Memory\n    memory = psutil.virtual_memory()\n\n    return {\n        \"platform\": platform.platform(),\n        \"python_version\": platform.python_version(),\n        \"device_backend\": get_device().value,\n        \"cpu_count\": psutil.cpu_count(),\n        \"memory\": {\n            \"total_gb\": round(memory.total / 1e9, 2),\n            \"available_gb\": round(memory.available / 1e9, 2),\n            \"percent_used\": memory.percent,\n        },\n        \"gpu\": gpu_info,\n    }\n\n\n@app.get(\"/api/system/hardware\")\nasync def get_hardware_info():\n    \"\"\"Return GPU name, total VRAM, and key ML package versions.\"\"\"\n    from utils.hardware import get_gpu_summary, get_package_versions\n\n    return {\n        \"gpu\": get_gpu_summary(),\n        \"versions\": get_package_versions(),\n    }\n\n\n# ============ Serve Frontend (Optional) ============\n\n\ndef _strip_crossorigin(html_bytes: bytes) -> bytes:\n    \"\"\"Remove ``crossorigin`` attributes from script/link tags.\n\n    Vite adds ``crossorigin`` by default which forces CORS mode on font\n    subresource loads.  When Studio is served over plain HTTP, Firefox\n    HTTPS-Only Mode does not exempt CORS font requests -- causing all\n    @font-face downloads to fail silently.  
Stripping the attribute\n    makes them regular same-origin fetches that work on any protocol.\n    \"\"\"\n    import re as _re\n\n    html = html_bytes.decode(\"utf-8\")\n    html = _re.sub(r'\\s+crossorigin(?:=\"[^\"]*\")?', \"\", html)\n    return html.encode(\"utf-8\")\n\n\ndef _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes:\n    \"\"\"Inject bootstrap credentials into HTML when password change is required.\n\n    The script tag is only injected while the default admin account still\n    has ``must_change_password=True``.  Once the user changes the password\n    the HTML is served clean — no credentials leak.\n    \"\"\"\n    import json as _json\n\n    if not storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME):\n        return html_bytes\n\n    bootstrap_pw = getattr(app.state, \"bootstrap_password\", None)\n    if not bootstrap_pw:\n        return html_bytes\n\n    payload = _json.dumps(\n        {\n            \"username\": storage.DEFAULT_ADMIN_USERNAME,\n            \"password\": bootstrap_pw,\n        }\n    )\n    tag = f\"<script>window.__UNSLOTH_BOOTSTRAP__={payload}</script>\"\n    html = html_bytes.decode(\"utf-8\")\n    html = html.replace(\"</head>\", f\"{tag}</head>\", 1)\n    return html.encode(\"utf-8\")\n\n\ndef setup_frontend(app: FastAPI, build_path: Path):\n    \"\"\"Mount frontend static files (optional)\"\"\"\n    if not build_path.exists():\n        return False\n\n    # Mount assets\n    assets_dir = build_path / \"assets\"\n    if assets_dir.exists():\n        app.mount(\"/assets\", StaticFiles(directory = assets_dir), name = \"assets\")\n\n    @app.get(\"/\")\n    async def serve_root():\n        content = (build_path / \"index.html\").read_bytes()\n        content = _strip_crossorigin(content)\n        content = _inject_bootstrap(content, app)\n        return Response(\n            content = content,\n            media_type = \"text/html\",\n            headers = {\"Cache-Control\": \"no-cache, no-store, must-revalidate\"},\n        )\n\n    @app.get(\"/{full_path:path}\")\n    async def serve_frontend(full_path: str):\n        if full_path.startswith(\"api\"):\n            return {\"error\": \"API endpoint not found\"}\n\n        file_path = (build_path / full_path).resolve()\n\n        # Block path traversal — ensure resolved path stays inside build_path\n        if not file_path.is_relative_to(build_path.resolve()):\n            return Response(status_code = 403)\n\n        if file_path.is_file():\n            return FileResponse(file_path)\n\n        # Serve index.html as bytes — avoids Content-Length mismatch\n        content = (build_path / \"index.html\").read_bytes()\n        content = _strip_crossorigin(content)\n        content = _inject_bootstrap(content, app)\n        return Response(\n            content = content,\n            media_type = \"text/html\",\n            headers = {\"Cache-Control\": \"no-cache, no-store, must-revalidate\"},\n        )\n\n    return True\n"
  },
  {
    "path": "studio/backend/models/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/models/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic models for API request/response schemas\n\"\"\"\n\nfrom .training import (\n    TrainingStartRequest,\n    TrainingJobResponse,\n    TrainingStatus,\n    TrainingProgress,\n)\nfrom .models import (\n    CheckpointInfo,\n    ModelCheckpoints,\n    CheckpointListResponse,\n    ModelDetails,\n    LocalModelInfo,\n    LocalModelListResponse,\n    LoRAInfo,\n    LoRAScanResponse,\n    ModelListResponse,\n)\nfrom .auth import (\n    AuthLoginRequest,\n    RefreshTokenRequest,\n    AuthStatusResponse,\n    ChangePasswordRequest,\n)\nfrom .export import (\n    LoadCheckpointRequest,\n    ExportStatusResponse,\n    ExportOperationResponse,\n    ExportMergedModelRequest,\n    ExportBaseModelRequest,\n    ExportGGUFRequest,\n    ExportLoRAAdapterRequest,\n)\nfrom .users import Token\nfrom .datasets import (\n    CheckFormatRequest,\n    CheckFormatResponse,\n)\nfrom .inference import (\n    LoadRequest,\n    UnloadRequest,\n    GenerateRequest,\n    LoadResponse,\n    UnloadResponse,\n    InferenceStatusResponse,\n)\nfrom .responses import (\n    TrainingStopResponse,\n    TrainingMetricsResponse,\n    LoRABaseModelResponse,\n    VisionCheckResponse,\n    EmbeddingCheckResponse,\n)\nfrom .data_recipe import (\n    RecipePayload,\n    PreviewResponse,\n    ValidateError,\n    ValidateResponse,\n    JobCreateResponse,\n)\n\n__all__ = [\n    # Training schemas\n    \"TrainingStartRequest\",\n    \"TrainingJobResponse\",\n    \"TrainingStatus\",\n    \"TrainingProgress\",\n    # Model management schemas\n    \"ModelDetails\",\n    \"LocalModelInfo\",\n    \"LocalModelListResponse\",\n    \"LoRAInfo\",\n    \"LoRAScanResponse\",\n    \"ModelListResponse\",\n    # Auth schemas\n    \"AuthLoginRequest\",\n    \"RefreshTokenRequest\",\n    \"AuthStatusResponse\",\n    \"ChangePasswordRequest\",\n    # Export schemas\n    \"CheckpointInfo\",\n    \"ModelCheckpoints\",\n    \"CheckpointListResponse\",\n    \"LoadCheckpointRequest\",\n    \"ExportStatusResponse\",\n    \"ExportOperationResponse\",\n    \"ExportMergedModelRequest\",\n    \"ExportBaseModelRequest\",\n    \"ExportGGUFRequest\",\n    \"ExportLoRAAdapterRequest\",\n    \"Token\",\n    # Dataset schemas\n    \"CheckFormatRequest\",\n    \"CheckFormatResponse\",\n    # Inference schemas\n    \"LoadRequest\",\n    \"UnloadRequest\",\n    \"GenerateRequest\",\n    \"LoadResponse\",\n    \"UnloadResponse\",\n    \"InferenceStatusResponse\",\n    # Response schemas\n    \"TrainingStopResponse\",\n    \"TrainingMetricsResponse\",\n    \"LoRABaseModelResponse\",\n    \"VisionCheckResponse\",\n    \"EmbeddingCheckResponse\",\n    # Data recipe\n    \"RecipePayload\",\n    \"PreviewResponse\",\n    \"ValidateError\",\n    \"ValidateResponse\",\n    \"JobCreateResponse\",\n]\n"
  },
  {
    "path": "studio/backend/models/auth.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Authentication API\n\"\"\"\n\nfrom pydantic import BaseModel, Field\n\n\nclass AuthLoginRequest(BaseModel):\n    \"\"\"Login payload: username/password to obtain a JWT.\"\"\"\n\n    username: str = Field(..., description = \"Username\")\n    password: str = Field(..., description = \"Password\")\n\n\nclass RefreshTokenRequest(BaseModel):\n    \"\"\"Refresh token payload to obtain new access + refresh tokens.\"\"\"\n\n    refresh_token: str = Field(\n        ..., description = \"Refresh token from a previous login or refresh\"\n    )\n\n\nclass AuthStatusResponse(BaseModel):\n    \"\"\"Indicate whether the seeded admin auth flow is ready.\"\"\"\n\n    initialized: bool = Field(\n        ..., description = \"True if the auth database contains a login user\"\n    )\n    default_username: str = Field(..., description = \"Default seeded admin username\")\n    requires_password_change: bool = Field(\n        ...,\n        description = \"True if the seeded admin must still change the default password\",\n    )\n\n\nclass ChangePasswordRequest(BaseModel):\n    \"\"\"Change the current user's password, typically on first login.\"\"\"\n\n    current_password: str = Field(\n        ..., min_length = 8, description = \"Existing password for the authenticated user\"\n    )\n    new_password: str = Field(\n        ..., min_length = 8, description = \"Replacement password (minimum 8 characters)\"\n    )\n"
  },
  {
    "path": "studio/backend/models/data_recipe.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Data Recipe (DataDesigner) API.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom pydantic import BaseModel, Field\n\n\nclass RecipePayload(BaseModel):\n    recipe: dict[str, Any] = Field(default_factory = dict)\n    run: dict[str, Any] | None = None\n    ui: dict[str, Any] | None = None\n\n\nclass PreviewResponse(BaseModel):\n    dataset: list[dict[str, Any]] = Field(default_factory = list)\n    processor_artifacts: dict[str, Any] | None = None\n    analysis: dict[str, Any] | None = None\n\n\nclass ValidateError(BaseModel):\n    message: str\n    path: str | None = None\n    code: str | None = None\n\n\nclass ValidateResponse(BaseModel):\n    valid: bool\n    errors: list[ValidateError] = Field(default_factory = list)\n    raw_detail: str | None = None\n\n\nclass JobCreateResponse(BaseModel):\n    job_id: str\n\n\nclass PublishDatasetRequest(BaseModel):\n    repo_id: str = Field(min_length = 3, description = \"Hugging Face dataset repo ID\")\n    description: str = Field(\n        min_length = 1,\n        max_length = 4000,\n        description = \"Short dataset description for the dataset card\",\n    )\n    hf_token: str | None = Field(\n        default = None,\n        description = \"Optional Hugging Face token for private or write-protected repos\",\n    )\n    private: bool = Field(\n        default = False,\n        description = \"Create or update the dataset repo as private\",\n    )\n    artifact_path: str | None = Field(\n        default = None,\n        description = \"Execution artifact path captured by the UI for completed runs\",\n    )\n\n\nclass PublishDatasetResponse(BaseModel):\n    success: bool = True\n    url: str\n    message: str\n\n\nclass SeedInspectRequest(BaseModel):\n    dataset_name: str = Field(min_length = 1)\n    hf_token: str | None = None\n    subset: str | None = None\n    split: str | None = \"train\"\n    preview_size: int = Field(default = 10, ge = 1, le = 50)\n\n\nclass SeedInspectUploadRequest(BaseModel):\n    filename: str = Field(min_length = 1)\n    content_base64: str = Field(min_length = 1)\n    preview_size: int = Field(default = 10, ge = 1, le = 50)\n    seed_source_type: str | None = None\n    unstructured_chunk_size: int | None = Field(default = None, ge = 1, le = 20000)\n    unstructured_chunk_overlap: int | None = Field(default = None, ge = 0, le = 20000)\n\n\nclass SeedInspectResponse(BaseModel):\n    dataset_name: str\n    resolved_path: str\n    columns: list[str] = Field(default_factory = list)\n    preview_rows: list[dict[str, Any]] = Field(default_factory = list)\n    split: str | None = None\n    subset: str | None = None\n\n\nclass McpToolsListRequest(BaseModel):\n    mcp_providers: list[dict[str, Any]] = Field(default_factory = list)\n    timeout_sec: float | None = Field(default = None, gt = 0)\n\n\nclass McpToolsProviderResult(BaseModel):\n    name: str\n    tools: list[str] = Field(default_factory = list)\n    error: str | None = None\n\n\nclass McpToolsListResponse(BaseModel):\n    providers: list[McpToolsProviderResult] = Field(default_factory = list)\n    duplicate_tools: dict[str, list[str]] = Field(default_factory = dict)\n"
  },
  {
    "path": "studio/backend/models/datasets.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nDataset-related Pydantic models for API requests and responses.\n\"\"\"\n\nfrom typing import Any, Dict, List, Optional\n\nfrom pydantic import BaseModel, Field, model_validator\n\n\nclass CheckFormatRequest(BaseModel):\n    \"\"\"Request for dataset format check\"\"\"\n\n    dataset_name: str  # HuggingFace dataset name or local path\n    is_vlm: bool = False\n    hf_token: Optional[str] = None\n    subset: Optional[str] = None\n    train_split: Optional[str] = \"train\"\n\n    @model_validator(mode = \"before\")\n    @classmethod\n    def _compat_split(cls, values: Any) -> Any:\n        \"\"\"Accept legacy 'split' field as alias for 'train_split'.\"\"\"\n        if isinstance(values, dict) and \"split\" in values:\n            values.setdefault(\"train_split\", values.pop(\"split\"))\n        return values\n\n\nclass CheckFormatResponse(BaseModel):\n    \"\"\"Response for dataset format check\"\"\"\n\n    requires_manual_mapping: bool\n    detected_format: str\n    columns: List[str]\n    is_image: bool = False\n    is_audio: bool = False\n    multimodal_columns: Optional[List[str]] = None\n    suggested_mapping: Optional[Dict[str, str]] = None\n    detected_image_column: Optional[str] = None\n    detected_audio_column: Optional[str] = None\n    detected_text_column: Optional[str] = None\n    detected_speaker_column: Optional[str] = None\n    preview_samples: Optional[List[Dict]] = None\n    total_rows: Optional[int] = None\n    warning: Optional[str] = None\n\n\nclass AiAssistMappingRequest(BaseModel):\n    \"\"\"Request for LLM-assisted column classification (user-triggered).\"\"\"\n\n    columns: List[str]\n    samples: List[Dict[str, Any]]  # Preview rows already loaded in the dialog\n    dataset_name: Optional[str] = None  # For LLM context\n    hf_token: Optional[str] = None  # For fetching dataset card\n    model_name: Optional[str] = None\n    model_type: Optional[str] = None\n\n\nclass AiAssistMappingResponse(BaseModel):\n    \"\"\"Response from LLM-assisted column classification and conversion advice.\"\"\"\n\n    success: bool\n    suggested_mapping: Optional[Dict[str, str]] = None\n    warning: Optional[str] = None\n    # Conversion advisor fields\n    system_prompt: Optional[str] = None\n    label_mapping: Optional[Dict[str, Dict[str, str]]] = None\n    dataset_type: Optional[str] = None\n    is_conversational: Optional[bool] = None\n    user_notification: Optional[str] = None\n\n\nclass UploadDatasetResponse(BaseModel):\n    \"\"\"Response with stored dataset path for training.\"\"\"\n\n    filename: str = Field(..., description = \"Original filename\")\n    stored_path: str = Field(..., description = \"Absolute path stored on backend\")\n\n\nclass LocalDatasetItem(BaseModel):\n    class Metadata(BaseModel):\n        actual_num_records: Optional[int] = None\n        target_num_records: Optional[int] = None\n        total_num_batches: Optional[int] = None\n        num_completed_batches: Optional[int] = None\n        columns: Optional[List[str]] = None\n\n    id: str\n    label: str\n    path: str\n    rows: Optional[int] = None\n    updated_at: Optional[float] = None\n    metadata: Optional[Metadata] = None\n\n\nclass LocalDatasetsResponse(BaseModel):\n    datasets: List[LocalDatasetItem] = Field(default_factory = list)\n"
  },
  {
    "path": "studio/backend/models/export.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Export API.\n\"\"\"\n\nfrom pydantic import BaseModel, Field\nfrom typing import List, Optional, Literal, Dict, Any\n\n\nclass LoadCheckpointRequest(BaseModel):\n    \"\"\"Request for loading a checkpoint into the export backend.\"\"\"\n\n    checkpoint_path: str = Field(..., description = \"Path to the checkpoint directory\")\n    max_seq_length: int = Field(\n        2048,\n        ge = 128,\n        le = 32768,\n        description = \"Maximum sequence length for loading the model\",\n    )\n    load_in_4bit: bool = Field(\n        True,\n        description = \"Whether to load the model in 4-bit quantization\",\n    )\n    trust_remote_code: bool = Field(\n        False,\n        description = \"Allow loading models with custom code. Only enable for checkpoints/base models you trust.\",\n    )\n\n\nclass ExportStatusResponse(BaseModel):\n    \"\"\"Current export backend status.\"\"\"\n\n    current_checkpoint: Optional[str] = Field(\n        None,\n        description = \"Path to the currently loaded checkpoint, if any\",\n    )\n    is_vision: bool = Field(\n        False,\n        description = \"True if the loaded checkpoint is a vision model\",\n    )\n    is_peft: bool = Field(\n        False,\n        description = \"True if the loaded checkpoint is a PEFT (LoRA) model\",\n    )\n\n\nclass ExportOperationResponse(BaseModel):\n    \"\"\"Generic response for export operations.\"\"\"\n\n    success: bool = Field(..., description = \"True if the operation succeeded\")\n    message: str = Field(..., description = \"Human-readable status or error message\")\n    details: Optional[Dict[str, Any]] = Field(\n        default = None,\n        description = \"Optional extra details about the operation\",\n    )\n\n\nclass ExportCommonOptions(BaseModel):\n    \"\"\"Common options for export operations that save locally and/or push to Hub.\"\"\"\n\n    save_directory: str = Field(\n        ...,\n        description = \"Local directory where the exported artifacts will be written\",\n    )\n    push_to_hub: bool = Field(\n        False,\n        description = \"If True, also push the exported model to the Hugging Face Hub\",\n    )\n    repo_id: Optional[str] = Field(\n        None,\n        description = \"Hugging Face Hub repository ID (username/model-name)\",\n    )\n    hf_token: Optional[str] = Field(\n        None,\n        description = \"Hugging Face access token used for Hub operations\",\n    )\n    private: bool = Field(\n        False,\n        description = \"If True, create a private repository on the Hub (where applicable)\",\n    )\n    base_model_id: Optional[str] = Field(\n        None,\n        description = \"HuggingFace model ID of the base model (for model card metadata)\",\n    )\n\n\nclass ExportMergedModelRequest(ExportCommonOptions):\n    \"\"\"Request for exporting a merged PEFT model.\"\"\"\n\n    format_type: Literal[\"16-bit (FP16)\", \"4-bit (FP4)\"] = Field(\n        \"16-bit (FP16)\",\n        description = \"Export precision / format for the merged model\",\n    )\n\n\nclass ExportBaseModelRequest(ExportCommonOptions):\n    \"\"\"Request for exporting a non-PEFT (base) model.\"\"\"\n\n    # Uses fields from ExportCommonOptions only\n\n\nclass ExportGGUFRequest(BaseModel):\n    \"\"\"Request for exporting the current model to GGUF format.\"\"\"\n\n    save_directory: str = 
Field(\n        ...,\n        description = \"Directory where GGUF files will be saved\",\n    )\n    quantization_method: str = Field(\n        \"Q4_K_M\",\n        description = 'GGUF quantization method (e.g. \"Q4_K_M\")',\n    )\n    push_to_hub: bool = Field(\n        False,\n        description = \"If True, also push GGUF artifacts to the Hugging Face Hub\",\n    )\n    repo_id: Optional[str] = Field(\n        None,\n        description = \"Hugging Face Hub repository ID for GGUF upload\",\n    )\n    hf_token: Optional[str] = Field(\n        None,\n        description = \"Hugging Face token for GGUF upload\",\n    )\n\n\nclass ExportLoRAAdapterRequest(ExportCommonOptions):\n    \"\"\"Request for exporting only the LoRA adapter (not merged).\"\"\"\n\n    # Uses fields from ExportCommonOptions only\n"
  },
  {
    "path": "studio/backend/models/inference.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Inference API\n\"\"\"\n\nfrom __future__ import annotations\n\nimport time\nimport uuid\nfrom typing import Annotated, Any, Dict, Literal, Optional, List, Union\n\nfrom pydantic import BaseModel, Discriminator, Field, Tag\n\n\nclass LoadRequest(BaseModel):\n    \"\"\"Request to load a model for inference\"\"\"\n\n    model_path: str = Field(..., description = \"Model identifier or local path\")\n    hf_token: Optional[str] = Field(\n        None, description = \"HuggingFace token for gated models\"\n    )\n    max_seq_length: int = Field(\n        4096, ge = 128, le = 32768, description = \"Maximum sequence length\"\n    )\n    load_in_4bit: bool = Field(True, description = \"Load model in 4-bit quantization\")\n    is_lora: bool = Field(False, description = \"Whether this is a LoRA adapter\")\n    gguf_variant: Optional[str] = Field(\n        None, description = \"GGUF quantization variant (e.g. 'Q4_K_M')\"\n    )\n    trust_remote_code: bool = Field(\n        False,\n        description = \"Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust.\",\n    )\n    chat_template_override: Optional[str] = Field(\n        None,\n        description = \"Custom Jinja2 chat template to use instead of the model's default\",\n    )\n    cache_type_kv: Optional[str] = Field(\n        None,\n        description = \"KV cache data type for both K and V (e.g. 'f16', 'bf16', 'q8_0', 'q4_1', 'q5_1')\",\n    )\n\n\nclass UnloadRequest(BaseModel):\n    \"\"\"Request to unload a model\"\"\"\n\n    model_path: str = Field(..., description = \"Model identifier to unload\")\n\n\nclass ValidateModelRequest(BaseModel):\n    \"\"\"\n    Lightweight validation request to check whether a model identifier\n    *can be resolved* into a ModelConfig.\n\n    This does NOT actually load weights into GPU memory.\n    \"\"\"\n\n    model_path: str = Field(..., description = \"Model identifier or local path\")\n    hf_token: Optional[str] = Field(\n        None, description = \"HuggingFace token for gated models\"\n    )\n    gguf_variant: Optional[str] = Field(\n        None, description = \"GGUF quantization variant (e.g. 
'Q4_K_M')\"\n    )\n\n\nclass ValidateModelResponse(BaseModel):\n    \"\"\"\n    Result of model validation.\n\n    valid == True means ModelConfig.from_identifier() succeeded and basic\n    introspection (GGUF / LoRA / vision flags) is available.\n    \"\"\"\n\n    valid: bool = Field(..., description = \"Whether the model identifier looks valid\")\n    message: str = Field(..., description = \"Human-readable validation message\")\n    identifier: Optional[str] = Field(None, description = \"Resolved model identifier\")\n    display_name: Optional[str] = Field(\n        None, description = \"Display name derived from identifier\"\n    )\n    is_gguf: bool = Field(False, description = \"Whether this is a GGUF model (llama.cpp)\")\n    is_lora: bool = Field(False, description = \"Whether this is a LoRA adapter\")\n    is_vision: bool = Field(False, description = \"Whether this is a vision-capable model\")\n\n\nclass GenerateRequest(BaseModel):\n    \"\"\"Request for text generation (legacy /generate/stream endpoint)\"\"\"\n\n    messages: List[dict] = Field(..., description = \"Chat messages in OpenAI format\")\n    system_prompt: str = Field(\"\", description = \"System prompt\")\n    temperature: float = Field(0.6, ge = 0.0, le = 2.0, description = \"Sampling temperature\")\n    top_p: float = Field(0.95, ge = 0.0, le = 1.0, description = \"Top-p sampling\")\n    top_k: int = Field(20, ge = -1, le = 100, description = \"Top-k sampling\")\n    max_new_tokens: int = Field(\n        2048, ge = 1, le = 4096, description = \"Maximum tokens to generate\"\n    )\n    repetition_penalty: float = Field(\n        1.0, ge = 1.0, le = 2.0, description = \"Repetition penalty\"\n    )\n    presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = \"Presence penalty\")\n    image_base64: Optional[str] = Field(\n        None, description = \"Base64 encoded image for vision models\"\n    )\n\n\nclass LoadResponse(BaseModel):\n    \"\"\"Response after loading a model\"\"\"\n\n    status: str = Field(..., description = \"Load status\")\n    model: str = Field(..., description = \"Model identifier\")\n    display_name: str = Field(..., description = \"Display name of the model\")\n    is_vision: bool = Field(False, description = \"Whether model is a vision model\")\n    is_lora: bool = Field(False, description = \"Whether model is a LoRA adapter\")\n    is_gguf: bool = Field(\n        False, description = \"Whether model is a GGUF model (llama.cpp)\"\n    )\n    is_audio: bool = Field(False, description = \"Whether model is a TTS audio model\")\n    audio_type: Optional[str] = Field(\n        None, description = \"Audio codec type: snac, csm, bicodec, dac\"\n    )\n    has_audio_input: bool = Field(\n        False, description = \"Whether model accepts audio input (ASR)\"\n    )\n    inference: dict = Field(\n        ..., description = \"Inference parameters (temperature, top_p, top_k, min_p)\"\n    )\n    context_length: Optional[int] = Field(\n        None, description = \"Model's native context length (from GGUF metadata)\"\n    )\n    supports_reasoning: bool = Field(\n        False,\n        description = \"Whether model supports thinking/reasoning mode (enable_thinking)\",\n    )\n    supports_tools: bool = Field(\n        False,\n        description = \"Whether model supports tool calling (web search, etc.)\",\n    )\n    cache_type_kv: Optional[str] = Field(\n        None,\n        description = \"KV cache data type for K and V (e.g. 
'f16', 'bf16', 'q8_0')\",\n    )\n    chat_template: Optional[str] = Field(\n        None,\n        description = \"Jinja2 chat template string (from GGUF metadata or tokenizer)\",\n    )\n\n\nclass UnloadResponse(BaseModel):\n    \"\"\"Response after unloading a model\"\"\"\n\n    status: str = Field(..., description = \"Unload status\")\n    model: str = Field(..., description = \"Model identifier that was unloaded\")\n\n\nclass InferenceStatusResponse(BaseModel):\n    \"\"\"Current inference backend status\"\"\"\n\n    active_model: Optional[str] = Field(\n        None, description = \"Currently active model identifier\"\n    )\n    is_vision: bool = Field(\n        False, description = \"Whether the active model is a vision model\"\n    )\n    is_gguf: bool = Field(\n        False, description = \"Whether the active model is a GGUF model (llama.cpp)\"\n    )\n    gguf_variant: Optional[str] = Field(\n        None, description = \"GGUF quantization variant (e.g. Q4_K_M)\"\n    )\n    is_audio: bool = Field(\n        False, description = \"Whether the active model is a TTS audio model\"\n    )\n    audio_type: Optional[str] = Field(\n        None, description = \"Audio codec type: snac, csm, bicodec, dac\"\n    )\n    has_audio_input: bool = Field(\n        False, description = \"Whether model accepts audio input (ASR)\"\n    )\n    loading: List[str] = Field(\n        default_factory = list, description = \"Models currently being loaded\"\n    )\n    loaded: List[str] = Field(\n        default_factory = list, description = \"Models currently loaded\"\n    )\n    inference: Optional[Dict[str, Any]] = Field(\n        None, description = \"Recommended inference parameters for the active model\"\n    )\n    supports_reasoning: bool = Field(\n        False, description = \"Whether the active model supports reasoning/thinking mode\"\n    )\n    supports_tools: bool = Field(\n        False, description = \"Whether the active model supports tool calling\"\n    )\n    context_length: Optional[int] = Field(\n        None, description = \"Context length of the active model\"\n    )\n\n\n# =====================================================================\n# OpenAI-Compatible Chat Completions Models\n# =====================================================================\n\n\n# ── Multimodal content parts (OpenAI vision format) ──────────────\n\n\nclass TextContentPart(BaseModel):\n    \"\"\"Text content part in a multimodal message.\"\"\"\n\n    type: Literal[\"text\"]\n    text: str\n\n\nclass ImageUrl(BaseModel):\n    \"\"\"Image URL object — supports data URIs and remote URLs.\"\"\"\n\n    url: str = Field(..., description = \"data:image/png;base64,... 
or https://...\")\n    detail: Optional[Literal[\"auto\", \"low\", \"high\"]] = \"auto\"\n\n\nclass ImageContentPart(BaseModel):\n    \"\"\"Image content part in a multimodal message.\"\"\"\n\n    type: Literal[\"image_url\"]\n    image_url: ImageUrl\n\n\ndef _content_part_discriminator(v):\n    if isinstance(v, dict):\n        return v.get(\"type\")\n    return getattr(v, \"type\", None)\n\n\nContentPart = Annotated[\n    Union[\n        Annotated[TextContentPart, Tag(\"text\")],\n        Annotated[ImageContentPart, Tag(\"image_url\")],\n    ],\n    Discriminator(_content_part_discriminator),\n]\n\"\"\"Union type for multimodal content parts, discriminated by the 'type' field.\"\"\"\n\n\n# ── Messages ─────────────────────────────────────────────────────\n\n\nclass ChatMessage(BaseModel):\n    \"\"\"\n    A single message in the conversation.\n\n    ``content`` may be a plain string (text-only) or a list of\n    content parts for multimodal messages (OpenAI vision format).\n    \"\"\"\n\n    role: Literal[\"system\", \"user\", \"assistant\"] = Field(\n        ..., description = \"Message role\"\n    )\n    content: Union[str, list[ContentPart]] = Field(\n        ..., description = \"Message content (string or multimodal parts)\"\n    )\n\n\nclass ChatCompletionRequest(BaseModel):\n    \"\"\"\n    OpenAI-compatible chat completion request.\n\n    Extensions (non-OpenAI fields) are marked with 'x-unsloth'.\n    \"\"\"\n\n    model: str = Field(\n        \"default\",\n        description = \"Model identifier (informational; the active model is used)\",\n    )\n    messages: list[ChatMessage] = Field(..., description = \"Conversation messages\")\n    stream: bool = Field(True, description = \"Whether to stream the response via SSE\")\n    temperature: float = Field(0.6, ge = 0.0, le = 2.0)\n    top_p: float = Field(0.95, ge = 0.0, le = 1.0)\n    max_tokens: Optional[int] = Field(\n        None, ge = 1, description = \"Maximum tokens to generate (None = until EOS)\"\n    )\n    presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = \"Presence penalty\")\n\n    # ── Unsloth extensions (ignored by standard OpenAI clients) ──\n    top_k: int = Field(20, ge = -1, le = 100, description = \"[x-unsloth] Top-k sampling\")\n    min_p: float = Field(\n        0.01, ge = 0.0, le = 1.0, description = \"[x-unsloth] Min-p sampling threshold\"\n    )\n    repetition_penalty: float = Field(\n        1.1, ge = 1.0, le = 2.0, description = \"[x-unsloth] Repetition penalty\"\n    )\n    image_base64: Optional[str] = Field(\n        None, description = \"[x-unsloth] Base64-encoded image for vision models\"\n    )\n    audio_base64: Optional[str] = Field(\n        None, description = \"[x-unsloth] Base64-encoded WAV for audio-input models (ASR)\"\n    )\n    use_adapter: Optional[Union[bool, str]] = Field(\n        None,\n        description = (\n            \"[x-unsloth] Adapter control for compare mode. 
\"\n            \"null = no change (default), \"\n            \"false = disable adapters (base model), \"\n            \"true = enable the current adapter, \"\n            \"string = enable a specific adapter by name.\"\n        ),\n    )\n    enable_thinking: Optional[bool] = Field(\n        None,\n        description = \"[x-unsloth] Enable/disable thinking/reasoning mode for supported models\",\n    )\n    enable_tools: Optional[bool] = Field(\n        None,\n        description = \"[x-unsloth] Enable tool calling for supported models\",\n    )\n    enabled_tools: Optional[list[str]] = Field(\n        None,\n        description = \"[x-unsloth] List of enabled tool names (e.g. ['web_search', 'python', 'terminal']). If None, all tools are enabled.\",\n    )\n    auto_heal_tool_calls: Optional[bool] = Field(\n        True,\n        description = \"[x-unsloth] Auto-detect and fix malformed tool calls from model output.\",\n    )\n    max_tool_calls_per_message: Optional[int] = Field(\n        10,\n        ge = 0,\n        description = \"[x-unsloth] Maximum number of tool call iterations per message (0 = disabled, 9999 = unlimited).\",\n    )\n    tool_call_timeout: Optional[int] = Field(\n        300,\n        ge = 1,\n        description = \"[x-unsloth] Timeout in seconds for each tool call execution (9999 = no limit).\",\n    )\n    session_id: Optional[str] = Field(\n        None,\n        description = \"[x-unsloth] Session/thread ID for scoping tool execution sandbox.\",\n    )\n\n\n# ── Streaming response chunks ────────────────────────────────────\n\n\nclass ChoiceDelta(BaseModel):\n    \"\"\"Delta content for a streaming chunk.\"\"\"\n\n    role: Optional[str] = None\n    content: Optional[str] = None\n\n\nclass ChunkChoice(BaseModel):\n    \"\"\"A single choice in a streaming chunk.\"\"\"\n\n    index: int = 0\n    delta: ChoiceDelta\n    finish_reason: Optional[Literal[\"stop\", \"length\"]] = None\n\n\nclass ChatCompletionChunk(BaseModel):\n    \"\"\"A single SSE chunk in OpenAI streaming format.\"\"\"\n\n    id: str = Field(default_factory = lambda: f\"chatcmpl-{uuid.uuid4().hex[:12]}\")\n    object: Literal[\"chat.completion.chunk\"] = \"chat.completion.chunk\"\n    created: int = Field(default_factory = lambda: int(time.time()))\n    model: str = \"default\"\n    choices: list[ChunkChoice]\n\n\n# ── Non-streaming response ───────────────────────────────────────\n\n\nclass CompletionMessage(BaseModel):\n    \"\"\"The assistant's complete response message.\"\"\"\n\n    role: Literal[\"assistant\"] = \"assistant\"\n    content: str\n\n\nclass CompletionChoice(BaseModel):\n    \"\"\"A single choice in a non-streaming response.\"\"\"\n\n    index: int = 0\n    message: CompletionMessage\n    finish_reason: Literal[\"stop\", \"length\"] = \"stop\"\n\n\nclass CompletionUsage(BaseModel):\n    \"\"\"Token usage statistics (approximate).\"\"\"\n\n    prompt_tokens: int = 0\n    completion_tokens: int = 0\n    total_tokens: int = 0\n\n\nclass ChatCompletion(BaseModel):\n    \"\"\"Non-streaming chat completion response.\"\"\"\n\n    id: str = Field(default_factory = lambda: f\"chatcmpl-{uuid.uuid4().hex[:12]}\")\n    object: Literal[\"chat.completion\"] = \"chat.completion\"\n    created: int = Field(default_factory = lambda: int(time.time()))\n    model: str = \"default\"\n    choices: list[CompletionChoice]\n    usage: CompletionUsage = Field(default_factory = CompletionUsage)\n"
  },
  {
    "path": "studio/backend/models/models.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Model Management API\n\"\"\"\n\nfrom pydantic import BaseModel, Field\nfrom typing import Optional, List, Dict, Any, Literal\n\nModelType = Literal[\"text\", \"vision\", \"audio\", \"embeddings\"]\n\n\nclass CheckpointInfo(BaseModel):\n    \"\"\"Information about a discovered checkpoint directory.\"\"\"\n\n    display_name: str = Field(\n        ..., description = \"User-friendly checkpoint name (folder name)\"\n    )\n    path: str = Field(..., description = \"Full path to the checkpoint directory\")\n    loss: Optional[float] = Field(None, description = \"Training loss at this checkpoint\")\n\n\nclass ModelCheckpoints(BaseModel):\n    \"\"\"A training run and its associated checkpoints.\"\"\"\n\n    name: str = Field(..., description = \"Training run folder name\")\n    checkpoints: List[CheckpointInfo] = Field(\n        default_factory = list,\n        description = \"List of checkpoints for this training run (final + intermediate)\",\n    )\n    base_model: Optional[str] = Field(\n        None,\n        description = \"Base model name from adapter_config.json or config.json\",\n    )\n    peft_type: Optional[str] = Field(\n        None,\n        description = \"PEFT type (e.g. LORA) if adapter training, None for full fine-tune\",\n    )\n    lora_rank: Optional[int] = Field(\n        None,\n        description = \"LoRA rank (r) if applicable\",\n    )\n\n\nclass CheckpointListResponse(BaseModel):\n    \"\"\"Response for listing available checkpoints in an outputs directory.\"\"\"\n\n    outputs_dir: str = Field(..., description = \"Directory that was scanned\")\n    models: List[ModelCheckpoints] = Field(\n        default_factory = list,\n        description = \"List of training runs with their checkpoints\",\n    )\n\n\nclass ModelDetails(BaseModel):\n    \"\"\"Detailed model configuration and metadata - can be used for both list and detail views\"\"\"\n\n    id: str = Field(..., description = \"Model identifier\")\n    model_name: Optional[str] = Field(\n        None, description = \"Model identifier (alias for id, for backward compatibility)\"\n    )\n    name: Optional[str] = Field(None, description = \"Display name for the model\")\n    config: Optional[Dict[str, Any]] = Field(\n        None, description = \"Model configuration dictionary\"\n    )\n    is_vision: bool = Field(False, description = \"Whether model is a vision model\")\n    is_embedding: bool = Field(\n        False, description = \"Whether model is an embedding/sentence-transformer model\"\n    )\n    is_lora: bool = Field(False, description = \"Whether model is a LoRA adapter\")\n    is_gguf: bool = Field(\n        False, description = \"Whether model is a GGUF model (llama.cpp format)\"\n    )\n    is_audio: bool = Field(False, description = \"Whether model is a TTS audio model\")\n    audio_type: Optional[str] = Field(\n        None, description = \"Audio codec type: snac, csm, bicodec, dac\"\n    )\n    has_audio_input: bool = Field(\n        False, description = \"Whether model accepts audio input (ASR)\"\n    )\n    model_type: Optional[ModelType] = Field(\n        None, description = \"Collapsed model modality: text, vision, audio, or embeddings\"\n    )\n    base_model: Optional[str] = Field(\n        None, description = \"Base model if this is a LoRA adapter\"\n    )\n    max_position_embeddings: Optional[int] = 
Field(\n        None, description = \"Maximum context length supported by the model\"\n    )\n    model_size_bytes: Optional[int] = Field(\n        None, description = \"Total size of model weight files in bytes\"\n    )\n\n\nclass LoRAInfo(BaseModel):\n    \"\"\"LoRA adapter or exported model information\"\"\"\n\n    display_name: str = Field(..., description = \"Display name for the LoRA\")\n    adapter_path: str = Field(\n        ..., description = \"Path to the LoRA adapter or exported model\"\n    )\n    base_model: Optional[str] = Field(None, description = \"Base model identifier\")\n    source: Optional[str] = Field(None, description = \"'training' or 'exported'\")\n    export_type: Optional[str] = Field(\n        None, description = \"'lora', 'merged', or 'gguf' (for exports)\"\n    )\n\n\nclass LoRAScanResponse(BaseModel):\n    \"\"\"Response schema for scanning trained LoRA adapters\"\"\"\n\n    loras: List[LoRAInfo] = Field(\n        default_factory = list, description = \"List of found LoRA adapters\"\n    )\n    outputs_dir: str = Field(..., description = \"Directory that was scanned\")\n\n\nclass ModelListResponse(BaseModel):\n    \"\"\"Response schema for listing models\"\"\"\n\n    models: List[ModelDetails] = Field(\n        default_factory = list, description = \"List of models\"\n    )\n    default_models: List[str] = Field(\n        default_factory = list, description = \"List of default model IDs\"\n    )\n\n\nclass GgufVariantDetail(BaseModel):\n    \"\"\"A single GGUF quantization variant in a HuggingFace repo.\"\"\"\n\n    filename: str = Field(\n        ..., description = \"GGUF filename (e.g., 'gemma-3-4b-it-Q4_K_M.gguf')\"\n    )\n    quant: str = Field(..., description = \"Quantization label (e.g., 'Q4_K_M')\")\n    size_bytes: int = Field(0, description = \"File size in bytes\")\n    downloaded: bool = Field(\n        False, description = \"Whether this variant is already in the local HF cache\"\n    )\n\n\nclass GgufVariantsResponse(BaseModel):\n    \"\"\"Response for listing GGUF quantization variants in a HuggingFace repo.\"\"\"\n\n    repo_id: str = Field(..., description = \"HuggingFace repo ID\")\n    variants: List[GgufVariantDetail] = Field(\n        default_factory = list, description = \"Available GGUF variants\"\n    )\n    has_vision: bool = Field(\n        False, description = \"Whether the model has vision support (mmproj files)\"\n    )\n    default_variant: Optional[str] = Field(\n        None, description = \"Recommended default quantization variant\"\n    )\n\n\nclass LocalModelInfo(BaseModel):\n    \"\"\"Discovered local model candidate.\"\"\"\n\n    id: str = Field(..., description = \"Identifier to use for loading/training\")\n    display_name: str = Field(..., description = \"Display label\")\n    path: str = Field(..., description = \"Local path where model data was discovered\")\n    source: Literal[\"models_dir\", \"hf_cache\"] = Field(\n        ...,\n        description = \"Discovery source\",\n    )\n    model_id: Optional[str] = Field(\n        None,\n        description = \"HF repo id for cached models, e.g. 
org/model\",\n    )\n    updated_at: Optional[float] = Field(\n        None,\n        description = \"Unix timestamp of latest observed update\",\n    )\n\n\nclass LocalModelListResponse(BaseModel):\n    \"\"\"Response schema for listing local/cached models.\"\"\"\n\n    models_dir: str = Field(\n        ..., description = \"Directory scanned for custom local models\"\n    )\n    hf_cache_dir: Optional[str] = Field(\n        None,\n        description = \"HF cache root that was scanned\",\n    )\n    models: List[LocalModelInfo] = Field(\n        default_factory = list,\n        description = \"Discovered local/cached models\",\n    )\n"
  },
  {
    "path": "studio/backend/models/responses.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic response schemas for endpoints that previously returned raw dicts.\nThese are small response models for training and model management routes.\n\"\"\"\n\nfrom pydantic import BaseModel, Field\nfrom typing import Optional, List\n\n\n# --- Training route response models ---\n\n\nclass TrainingStopResponse(BaseModel):\n    \"\"\"Response for stopping a training job\"\"\"\n\n    status: str = Field(..., description = \"Current status: 'stopped' or 'idle'\")\n    message: str = Field(..., description = \"Human-readable status message\")\n\n\nclass TrainingMetricsResponse(BaseModel):\n    \"\"\"Response for training metrics history\"\"\"\n\n    loss_history: List[float] = Field(\n        default_factory = list, description = \"Loss values per step\"\n    )\n    lr_history: List[float] = Field(\n        default_factory = list, description = \"Learning rate per step\"\n    )\n    step_history: List[int] = Field(default_factory = list, description = \"Step numbers\")\n    grad_norm_history: List[float] = Field(\n        default_factory = list, description = \"Gradient norm values\"\n    )\n    grad_norm_step_history: List[int] = Field(\n        default_factory = list, description = \"Step numbers for gradient norm values\"\n    )\n    current_loss: Optional[float] = Field(None, description = \"Most recent loss value\")\n    current_lr: Optional[float] = Field(None, description = \"Most recent learning rate\")\n    current_step: Optional[int] = Field(None, description = \"Most recent step number\")\n\n\n# --- Model management route response models ---\n\n\nclass LoRABaseModelResponse(BaseModel):\n    \"\"\"Response for getting a LoRA's base model\"\"\"\n\n    lora_path: str = Field(..., description = \"Path to the LoRA adapter\")\n    base_model: str = Field(..., description = \"Base model identifier\")\n\n\nclass VisionCheckResponse(BaseModel):\n    \"\"\"Response for checking if a model is a vision model\"\"\"\n\n    model_name: str = Field(..., description = \"Model identifier\")\n    is_vision: bool = Field(..., description = \"Whether the model is a vision model\")\n\n\nclass EmbeddingCheckResponse(BaseModel):\n    \"\"\"Response for checking if a model is an embedding model\"\"\"\n\n    model_name: str = Field(..., description = \"Model identifier\")\n    is_embedding: bool = Field(\n        ..., description = \"Whether the model is an embedding/sentence-transformer model\"\n    )\n"
  },
  {
    "path": "studio/backend/models/training.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPydantic schemas for Training API\n\"\"\"\n\nfrom pydantic import BaseModel, Field, model_validator\nfrom typing import Any, Optional, List, Dict, Literal\n\n\nclass TrainingStartRequest(BaseModel):\n    \"\"\"Request schema for starting training\"\"\"\n\n    # Model parameters\n    model_name: str = Field(\n        ..., description = \"Model identifier (e.g., 'unsloth/llama-3-8b-bnb-4bit')\"\n    )\n    training_type: str = Field(\n        ..., description = \"Training type: 'LoRA/QLoRA' or 'Full Finetuning'\"\n    )\n    hf_token: Optional[str] = Field(None, description = \"HuggingFace token\")\n    load_in_4bit: bool = Field(True, description = \"Load model in 4-bit quantization\")\n    max_seq_length: int = Field(2048, description = \"Maximum sequence length\")\n    trust_remote_code: bool = Field(\n        False,\n        description = \"Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust.\",\n    )\n\n    # Dataset parameters\n    hf_dataset: Optional[str] = Field(\n        None, description = \"HuggingFace dataset identifier\"\n    )\n    local_datasets: List[str] = Field(\n        default_factory = list, description = \"List of local dataset paths\"\n    )\n    local_eval_datasets: List[str] = Field(\n        default_factory = list, description = \"List of local eval dataset paths\"\n    )\n    format_type: str = Field(..., description = \"Dataset format type\")\n    subset: Optional[str] = None\n    train_split: Optional[str] = Field(\"train\", description = \"Training split name\")\n    eval_split: Optional[str] = Field(\n        None, description = \"Eval split name. None = auto-detect\"\n    )\n    eval_steps: float = Field(\n        0.00, description = \"Fraction of total steps between evals (0-1)\"\n    )\n    dataset_slice_start: Optional[int] = Field(\n        None, description = \"Inclusive start row index for dataset slicing\"\n    )\n    dataset_slice_end: Optional[int] = Field(\n        None, description = \"Inclusive end row index for dataset slicing\"\n    )\n\n    @model_validator(mode = \"before\")\n    @classmethod\n    def _compat_split(cls, values: Any) -> Any:\n        \"\"\"Accept legacy 'split' field as alias for 'train_split'.\"\"\"\n        if isinstance(values, dict) and \"split\" in values:\n            values.setdefault(\"train_split\", values.pop(\"split\"))\n        return values\n\n    custom_format_mapping: Optional[Dict[str, Any]] = Field(\n        None,\n        description = (\n            \"User-provided column-to-role mapping, e.g. {'image': 'image', 'caption': 'text'} \"\n            \"for VLM or {'instruction': 'user', 'output': 'assistant'} for LLM. 
\"\n            \"Enhanced format includes __system_prompt, __user_template, \"\n            \"__assistant_template, __label_mapping metadata keys.\"\n        ),\n    )\n    # Training parameters\n    num_epochs: int = Field(1, description = \"Number of training epochs\")\n    learning_rate: str = Field(\"2e-4\", description = \"Learning rate\")\n    batch_size: int = Field(1, description = \"Batch size\")\n    gradient_accumulation_steps: int = Field(\n        1, description = \"Gradient accumulation steps\"\n    )\n    warmup_steps: Optional[int] = Field(None, description = \"Warmup steps\")\n    warmup_ratio: Optional[float] = Field(None, description = \"Warmup ratio\")\n    max_steps: Optional[int] = Field(None, description = \"Maximum training steps\")\n    save_steps: int = Field(100, description = \"Steps between checkpoints\")\n    weight_decay: float = Field(0.01, description = \"Weight decay\")\n    random_seed: int = Field(42, description = \"Random seed\")\n    packing: bool = Field(False, description = \"Enable sequence packing\")\n    optim: str = Field(\"adamw_8bit\", description = \"Optimizer\")\n    lr_scheduler_type: str = Field(\"linear\", description = \"Learning rate scheduler type\")\n\n    # LoRA parameters\n    use_lora: bool = Field(True, description = \"Use LoRA (derived from training_type)\")\n    lora_r: int = Field(16, description = \"LoRA rank\")\n    lora_alpha: int = Field(16, description = \"LoRA alpha\")\n    lora_dropout: float = Field(0.0, description = \"LoRA dropout\")\n    target_modules: List[str] = Field(\n        default_factory = list, description = \"Target modules for LoRA\"\n    )\n    gradient_checkpointing: str = Field(\n        \"\", description = \"Gradient checkpointing setting\"\n    )\n    use_rslora: bool = Field(False, description = \"Use RSLoRA\")\n    use_loftq: bool = Field(False, description = \"Use LoftQ\")\n    train_on_completions: bool = Field(False, description = \"Train on completions only\")\n\n    # Vision-specific LoRA parameters\n    finetune_vision_layers: bool = Field(False, description = \"Finetune vision layers\")\n    finetune_language_layers: bool = Field(\n        False, description = \"Finetune language layers\"\n    )\n    finetune_attention_modules: bool = Field(\n        False, description = \"Finetune attention modules\"\n    )\n    finetune_mlp_modules: bool = Field(False, description = \"Finetune MLP modules\")\n    is_dataset_image: bool = Field(\n        False, description = \"Whether the dataset contains image data\"\n    )\n    is_dataset_audio: bool = Field(\n        False, description = \"Whether the dataset contains audio data\"\n    )\n    is_embedding: bool = Field(\n        False, description = \"Whether model is an embedding/sentence-transformer model\"\n    )\n\n    # Logging parameters\n    enable_wandb: bool = Field(False, description = \"Enable Weights & Biases logging\")\n    wandb_token: Optional[str] = Field(None, description = \"W&B token\")\n    wandb_project: Optional[str] = Field(None, description = \"W&B project name\")\n    enable_tensorboard: bool = Field(False, description = \"Enable TensorBoard logging\")\n    tensorboard_dir: Optional[str] = Field(None, description = \"TensorBoard directory\")\n\n\nclass TrainingJobResponse(BaseModel):\n    \"\"\"Immediate response when training is initiated\"\"\"\n\n    job_id: str = Field(..., description = \"Unique training job identifier\")\n    status: Literal[\"queued\", \"error\"] = Field(..., description = \"Initial job status\")\n    
message: str = Field(..., description = \"Human-readable status message\")\n    error: Optional[str] = Field(None, description = \"Error details if status is 'error'\")\n\n\nclass TrainingStatus(BaseModel):\n    \"\"\"Current training job status - works for streaming or polling\"\"\"\n\n    job_id: str = Field(..., description = \"Training job identifier\")\n    phase: Literal[\n        \"idle\",\n        \"loading_model\",\n        \"loading_dataset\",\n        \"configuring\",\n        \"training\",\n        \"completed\",\n        \"error\",\n        \"stopped\",\n    ] = Field(..., description = \"Current phase of training pipeline\")\n    is_training_running: bool = Field(\n        ..., description = \"True if training loop is actively running\"\n    )\n    eval_enabled: bool = Field(\n        False,\n        description = \"True if evaluation dataset is configured for this training run\",\n    )\n    message: str = Field(..., description = \"Human-readable status message\")\n    error: Optional[str] = Field(None, description = \"Error details if phase is 'error'\")\n    details: Optional[dict] = Field(\n        None, description = \"Phase-specific info, e.g. {'model_size': '8B'}\"\n    )\n    metric_history: Optional[dict] = Field(\n        None,\n        description = \"Full metric history arrays for chart recovery after SSE reconnection. \"\n        \"Keys: 'steps', 'loss', 'lr', 'grad_norm', 'grad_norm_steps' — each a list of numeric values.\",\n    )\n\n\nclass TrainingProgress(BaseModel):\n    \"\"\"Training progress metrics - for streaming or polling\"\"\"\n\n    job_id: str = Field(..., description = \"Training job identifier\")\n    step: int = Field(..., description = \"Current training step\")\n    total_steps: int = Field(..., description = \"Total training steps\")\n    loss: float = Field(..., description = \"Current loss value\")\n    learning_rate: float = Field(..., description = \"Current learning rate\")\n    progress_percent: float = Field(\n        ..., description = \"Progress percentage (0.0 to 100.0)\"\n    )\n    epoch: Optional[float] = Field(None, description = \"Current epoch\")\n    elapsed_seconds: Optional[float] = Field(\n        None, description = \"Time elapsed since training started\"\n    )\n    eta_seconds: Optional[float] = Field(None, description = \"Estimated time remaining\")\n    grad_norm: Optional[float] = Field(\n        None, description = \"L2 norm of gradients, computed before gradient clipping\"\n    )\n    num_tokens: Optional[int] = Field(\n        None, description = \"Total number of tokens processed so far\"\n    )\n    eval_loss: Optional[float] = Field(\n        None, description = \"Eval loss from the most recent evaluation step\"\n    )\n"
  },
  {
    "path": "studio/backend/models/users.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Pydantic models for authentication tokens.\n\nThis module defines the Token response model used by auth routes.\n\"\"\"\n\nfrom pydantic import BaseModel, Field\n\n\nclass Token(BaseModel):\n    \"\"\"Authentication response model for session credentials.\"\"\"\n\n    access_token: str = Field(\n        ..., description = \"Session access credential used for authenticated API requests\"\n    )\n    refresh_token: str = Field(\n        ...,\n        description = \"Session refresh credential used to renew an expired access credential\",\n    )\n    token_type: str = Field(\n        ..., description = \"Credential type for the Authorization header, always 'bearer'\"\n    )\n    must_change_password: bool = Field(\n        ..., description = \"True when the user must change the seeded default password\"\n    )\n"
  },
  {
    "path": "studio/backend/plugins/__init__.py",
    "content": ""
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/pyproject.toml",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n[build-system]\nrequires = [\"setuptools>=68\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"data-designer-unstructured-seed\"\nversion = \"0.1.0\"\ndescription = \"Local Data Designer unstructured seed reader plugin\"\nrequires-python = \">=3.11\"\ndependencies = [\n  \"data-designer-engine>=0.5.1,<0.6\",\n  \"pandas>=2,<3\",\n]\n\n[project.entry-points.\"data_designer.plugins\"]\nunstructured = \"data_designer_unstructured_seed.plugin:unstructured_seed_plugin\"\n\n[tool.setuptools]\npackage-dir = {\"\" = \"src\"}\n\n[tool.setuptools.packages.find]\nwhere = [\"src\"]\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom .chunking import (\n    DEFAULT_CHUNK_OVERLAP,\n    DEFAULT_CHUNK_SIZE,\n    build_unstructured_preview_rows,\n    materialize_unstructured_seed_dataset,\n    resolve_chunking,\n)\nfrom .config import UnstructuredSeedSource\nfrom .impl import UnstructuredSeedReader\nfrom .plugin import unstructured_seed_plugin\n\n__all__ = [\n    \"DEFAULT_CHUNK_OVERLAP\",\n    \"DEFAULT_CHUNK_SIZE\",\n    \"build_unstructured_preview_rows\",\n    \"materialize_unstructured_seed_dataset\",\n    \"resolve_chunking\",\n    \"UnstructuredSeedSource\",\n    \"UnstructuredSeedReader\",\n    \"unstructured_seed_plugin\",\n]\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/chunking.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport hashlib\nimport re\nfrom pathlib import Path\nfrom typing import Any\n\nfrom utils.paths import ensure_dir, unstructured_seed_cache_root\n\nDEFAULT_CHUNK_SIZE = 1200\nDEFAULT_CHUNK_OVERLAP = 200\nMAX_CHUNK_SIZE = 20000\n_MIN_BREAK_RATIO = 0.6\n_CACHE_DIR = unstructured_seed_cache_root()\n\n\ndef resolve_chunking(\n    chunk_size: Any,\n    chunk_overlap: Any,\n) -> tuple[int, int]:\n    size = _to_int(chunk_size, DEFAULT_CHUNK_SIZE)\n    size = max(1, min(size, MAX_CHUNK_SIZE))\n    overlap = _to_int(chunk_overlap, DEFAULT_CHUNK_OVERLAP)\n    overlap = max(0, min(overlap, max(0, size - 1)))\n    return size, overlap\n\n\ndef build_unstructured_preview_rows(\n    *,\n    source_path: Path,\n    preview_size: int,\n    chunk_size: Any,\n    chunk_overlap: Any,\n) -> list[dict[str, str]]:\n    parquet_path, rows = materialize_unstructured_seed_dataset(\n        source_path = source_path,\n        chunk_size = chunk_size,\n        chunk_overlap = chunk_overlap,\n    )\n    count = max(0, int(preview_size))\n    if rows:\n        return rows[:count]\n\n    try:\n        import pandas as pd\n    except ImportError as exc:  # pragma: no cover\n        raise RuntimeError(\n            f\"pandas is required for unstructured seed processing: {exc}\"\n        ) from exc\n\n    dataframe = pd.read_parquet(parquet_path).head(count)\n    return [\n        {\"chunk_text\": str(value.get(\"chunk_text\", \"\")).strip()}\n        for value in dataframe.to_dict(orient = \"records\")\n        if str(value.get(\"chunk_text\", \"\")).strip()\n    ]\n\n\ndef materialize_unstructured_seed_dataset(\n    *,\n    source_path: Path,\n    chunk_size: Any,\n    chunk_overlap: Any,\n) -> tuple[Path, list[dict[str, str]]]:\n    resolved = source_path.expanduser().resolve()\n    if not resolved.is_file():\n        raise FileNotFoundError(f\"unstructured seed file not found: {resolved}\")\n\n    size, overlap = resolve_chunking(chunk_size, chunk_overlap)\n    key = _compute_cache_key(\n        source_path = resolved,\n        chunk_size = size,\n        chunk_overlap = overlap,\n    )\n    parquet_path = _CACHE_DIR / f\"{key}.parquet\"\n    if parquet_path.exists():\n        return parquet_path, []\n\n    text = load_unstructured_text_file(resolved)\n    chunks = split_text_into_chunks(\n        text = text,\n        chunk_size = size,\n        chunk_overlap = overlap,\n    )\n    if not chunks:\n        raise ValueError(\"No text found in unstructured seed source.\")\n\n    rows = [{\"chunk_text\": chunk} for chunk in chunks]\n    ensure_dir(_CACHE_DIR)\n    try:\n        import pandas as pd\n    except ImportError as exc:  # pragma: no cover\n        raise RuntimeError(\n            f\"pandas is required for unstructured seed processing: {exc}\"\n        ) from exc\n\n    tmp_path = _CACHE_DIR / f\"{key}.tmp.parquet\"\n    pd.DataFrame(rows).to_parquet(tmp_path, index = False)\n    tmp_path.replace(parquet_path)\n    return parquet_path, rows\n\n\ndef load_unstructured_text_file(path: Path) -> str:\n    ext = path.suffix.lower()\n    if ext not in {\".txt\", \".md\"}:\n        raise ValueError(f\"Unsupported unstructured seed file type: {ext}\")\n\n    raw = path.read_text(encoding = \"utf-8\", errors = \"ignore\")\n    return normalize_unstructured_text(raw)\n\n\ndef normalize_unstructured_text(text: str) -> str:\n   
 normalized = text.replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\")\n    return re.sub(r\"\\n{3,}\", \"\\n\\n\", normalized).strip()\n\n\ndef split_text_into_chunks(\n    *,\n    text: str,\n    chunk_size: int,\n    chunk_overlap: int,\n) -> list[str]:\n    if not text:\n        return []\n    if chunk_size <= 0:\n        return [text]\n\n    chunks: list[str] = []\n    start = 0\n    min_break_index = int(chunk_size * _MIN_BREAK_RATIO)\n    text_len = len(text)\n    while start < text_len:\n        end = min(text_len, start + chunk_size)\n        if end < text_len:\n            window = text[start:end]\n            cut = _find_break_index(window, min_break_index)\n            if cut is not None and cut > 0:\n                end = start + cut\n\n        if end <= start:\n            end = min(text_len, start + chunk_size)\n\n        chunk = text[start:end].strip()\n        if chunk:\n            chunks.append(chunk)\n        if end >= text_len:\n            break\n\n        next_start = end - chunk_overlap\n        if next_start <= start:\n            next_start = end\n        start = max(0, next_start)\n\n    return chunks\n\n\ndef _find_break_index(window: str, min_index: int) -> int | None:\n    breakpoints = [\"\\n\\n\", \"\\n\", \" \"]\n    for token in breakpoints:\n        idx = window.rfind(token)\n        if idx >= min_index:\n            return idx + len(token)\n    return None\n\n\ndef _to_int(value: Any, fallback: int) -> int:\n    if isinstance(value, bool):\n        return fallback\n    try:\n        parsed = int(str(value).strip())\n    except (TypeError, ValueError):\n        return fallback\n    return parsed\n\n\ndef _compute_cache_key(\n    *,\n    source_path: Path,\n    chunk_size: int,\n    chunk_overlap: int,\n) -> str:\n    stat = source_path.stat()\n    payload = \"|\".join(\n        [\n            str(source_path),\n            str(stat.st_size),\n            str(stat.st_mtime_ns),\n            str(chunk_size),\n            str(chunk_overlap),\n        ]\n    ).encode(\"utf-8\")\n    return hashlib.sha256(payload).hexdigest()\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/config.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Literal\n\nfrom pydantic import Field, field_validator\n\nfrom data_designer.config.seed_source import SeedSource\n\nfrom .chunking import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE, resolve_chunking\n\n\nclass UnstructuredSeedSource(SeedSource):\n    seed_type: Literal[\"unstructured\"] = \"unstructured\"\n    path: str = Field(..., min_length = 1)\n    chunk_size: int = DEFAULT_CHUNK_SIZE\n    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP\n\n    @field_validator(\"path\", mode = \"after\")\n    @classmethod\n    def _validate_path(cls, value: str) -> str:\n        path = Path(value).expanduser()\n        if not path.is_file():\n            raise ValueError(f\"Unstructured seed path is not a file: {path}\")\n        return value\n\n    @field_validator(\"chunk_size\", mode = \"after\")\n    @classmethod\n    def _validate_chunk_size(cls, value: int) -> int:\n        size, _ = resolve_chunking(value, 0)\n        return size\n\n    @field_validator(\"chunk_overlap\", mode = \"after\")\n    @classmethod\n    def _validate_chunk_overlap(cls, value: int, info) -> int:\n        size = info.data.get(\"chunk_size\", cls.model_fields[\"chunk_size\"].default)\n        _, overlap = resolve_chunking(size, value)\n        return overlap\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/impl.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nimport data_designer.lazy_heavy_imports as lazy\nfrom data_designer.engine.resources.seed_reader import SeedReader\n\nfrom .chunking import materialize_unstructured_seed_dataset\nfrom .config import UnstructuredSeedSource\n\n\nclass UnstructuredSeedReader(SeedReader[UnstructuredSeedSource]):\n    def create_duckdb_connection(self):\n        return lazy.duckdb.connect()\n\n    def get_dataset_uri(self) -> str:\n        path, _ = materialize_unstructured_seed_dataset(\n            source_path = Path(self.source.path),\n            chunk_size = self.source.chunk_size,\n            chunk_overlap = self.source.chunk_overlap,\n        )\n        return str(path)\n"
  },
  {
    "path": "studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/plugin.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom data_designer.plugins.plugin import Plugin, PluginType\n\nunstructured_seed_plugin = Plugin(\n    impl_qualified_name = \"data_designer_unstructured_seed.impl.UnstructuredSeedReader\",\n    config_qualified_name = \"data_designer_unstructured_seed.config.UnstructuredSeedSource\",\n    plugin_type = PluginType.SEED_READER,\n)\n"
  },
  {
    "path": "studio/backend/requirements/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/requirements/base.txt",
    "content": "# Core unsloth packages\nunsloth-zoo\nunsloth\n"
  },
  {
    "path": "studio/backend/requirements/extras-no-deps.txt",
    "content": "# Audio extras (installed with --no-deps --no-cache-dir)\ndescript-audio-codec\ndescript-audiotools\njulius\ntorchcodec\nsnac\n\n# TRL and related packages\ntrl==0.23.1\ngit+https://github.com/meta-pytorch/OpenEnv.git\n# executorch>=1.0.1               # 41.5 MB - no imports in unsloth/zoo/studio\ntorch-c-dlpack-ext\nsentence_transformers==5.2.0\ntransformers==4.57.6\n"
  },
  {
    "path": "studio/backend/requirements/extras.txt",
    "content": "# OpenEnv dependencies\ntomli\ntomli-w\n\n# ExecuTorch dependencies\nruamel.yaml\n# coremltools                    # 10.2 MB - Apple CoreML, no imports in unsloth/zoo/studio\nexpecttest\nflatbuffers\nhydra-core\nhypothesis\nkgb\nparameterized\npytest<9.0\npytest-json-report\npytest-rerunfailures==15.1\npytest-xdist\n# Also needed by sentence_transformers (installed with --no-deps in extras-no-deps.txt)\nscikit-learn==1.7.1\n\n# Additional extras\npybind11\nlangid\njiwer\nomegaconf\neinx\npyloudnorm\nopenai-whisper\nuroman                           # 4.0 MB - used for Outetts.\nMeCab                            # 19.9 MB - used for Outetts.\ninflect                          # number-to-words, required by OuteTTS\nloguru\nflatten_dict\nffmpy\nrandomname\nargbind\ntiktoken\nftfy\nimportlib-resources\nlibrosa\nmarkdown2\nmatplotlib\npystoi\nsoundfile\ntensorboard\ntorch-stoi\nevaluate\ntimm\ntransformers-cfg\nopen_spiel\naddict\neasydict\neinops\ntabulate\nfastmcp>=3.0.2\nopenai>=2.7.2\nwebsockets>=15.0.1\n"
  },
  {
    "path": "studio/backend/requirements/overrides.txt",
    "content": "# Torch AO overrides (installed with --force-reinstall --no-cache-dir)\ntorchao==0.14.0\npytorch_tokenizers\n\n# Kernel packages\nkernels\n"
  },
  {
    "path": "studio/backend/requirements/single-env/constraints.txt",
    "content": "# Single-env pins for unsloth + studio + data-designer\n# Keep compatible with unsloth transformers bounds.\ntransformers==4.57.6\ntrl==0.23.1\nhuggingface-hub==0.36.2\n\n# Studio stack\ndatasets==4.3.0\npyarrow==23.0.1\n\n# FastMCP/OpenEnv compat\nfastmcp>=3.0.2\nmcp>=1.24,<2\nwebsockets>=15.0.1\n\npandas==2.3.3\n"
  },
  {
    "path": "studio/backend/requirements/single-env/data-designer-deps.txt",
    "content": "# Data Designer runtime deps installed explicitly (single-env mode).\n# DuckDB 1.5 removed Relation.record_batch(); keep <1.5 until upstream ships the fix.\nanyascii<1,>=0.3.3\nduckdb<1.5,>=1.1.3\nfaker<21,>=20.1.0\nhttpx<1,>=0.27.2\nhttpx-retries<1,>=0.4.2\njson-repair<1,>=0.48.0\njsonpath-rust-bindings<2,>=1.0\njsonschema<5,>=4.0.0\nlitellm<1.80.12,>=1.73.6\nlxml<7,>=6.0.2\nmarko<3,>=2.1.2\nnetworkx<4,>=3.0\npython-json-logger<4,>=3\nruff<1,>=0.14.10\nscipy<2,>=1.11.0\nsqlfluff<4,>=3.2.0\ntiktoken<1,>=0.8.0\n"
  },
  {
    "path": "studio/backend/requirements/single-env/data-designer.txt",
    "content": "# Install Data Designer in same env as Unsloth.\ndata-designer==0.5.2\ndata-designer-config==0.5.2\ndata-designer-engine==0.5.2\nprompt-toolkit>=3,<4\n"
  },
  {
    "path": "studio/backend/requirements/single-env/patch_metadata.py",
    "content": "#!/usr/bin/env python3\n# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Relax strict metadata pins so pip check matches known working single-env stack.\n\nWhy:\n- data-designer pins huggingface-hub>=1.0.1 and pyarrow<20.\n- unsloth/transformers pins huggingface-hub<1.\n- studio datasets pins pyarrow>=21.\n\nRuntime works in this app with hub 0.36.x + pyarrow 23.x, but metadata conflicts.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport importlib.metadata as im\nimport re\nfrom pathlib import Path\n\nTARGETS = (\n    \"data-designer\",\n    \"data-designer-engine\",\n    \"data-designer-config\",\n)\n\nPATCHES: tuple[tuple[re.Pattern[str], str], ...] = (\n    (\n        re.compile(r\"^Requires-Dist: huggingface-hub<2,>=1\\.0\\.1$\", re.MULTILINE),\n        \"Requires-Dist: huggingface-hub<2,>=0.34.0\",\n    ),\n    (\n        re.compile(r\"^Requires-Dist: pyarrow<20,>=19\\.0\\.1$\", re.MULTILINE),\n        \"Requires-Dist: pyarrow>=21.0.0\",\n    ),\n)\n\n\ndef metadata_path(dist_name: str) -> Path | None:\n    try:\n        dist = im.distribution(dist_name)\n    except im.PackageNotFoundError:\n        return None\n    for f in dist.files or []:\n        sf = str(f)\n        if sf.endswith(\".dist-info/METADATA\"):\n            return Path(dist.locate_file(f))\n    return None\n\n\ndef patch_file(path: Path) -> bool:\n    original = path.read_text(encoding = \"utf-8\")\n    updated = original\n    for pattern, repl in PATCHES:\n        updated = pattern.sub(repl, updated)\n    if updated == original:\n        return False\n    path.write_text(updated, encoding = \"utf-8\")\n    return True\n\n\ndef main() -> int:\n    changed = 0\n    checked = 0\n    for name in TARGETS:\n        p = metadata_path(name)\n        if p is None:\n            continue\n        checked += 1\n        if patch_file(p):\n            changed += 1\n    print(f\"single-env metadata patch: checked={checked}, changed={changed}\")\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "studio/backend/requirements/studio.txt",
    "content": "# Studio UI backend dependencies\ntyper\nfastapi\nuvicorn\npydantic\nmatplotlib\npandas\nnest_asyncio\ndatasets==4.3.0\npyjwt\neasydict\naddict\n# gradio>=4.0.0                  # 148 MB - Studio uses React + FastAPI, not Gradio\nhuggingface-hub==0.36.2\nstructlog>=24.1.0\ndiceware\nddgs\n"
  },
  {
    "path": "studio/backend/requirements/triton-kernels.txt",
    "content": "# Triton kernels (installed with --no-deps, from source)\ntriton_kernels @ git+https://github.com/triton-lang/triton.git@release/3.6.x#subdirectory=python/triton_kernels\n"
  },
  {
    "path": "studio/backend/routes/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/routes/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nAPI Routes\n\"\"\"\n\nfrom routes.training import router as training_router\nfrom routes.models import router as models_router\nfrom routes.inference import router as inference_router\nfrom routes.datasets import router as datasets_router\nfrom routes.auth import router as auth_router\nfrom routes.data_recipe import router as data_recipe_router\nfrom routes.export import router as export_router\n\n__all__ = [\n    \"training_router\",\n    \"models_router\",\n    \"inference_router\",\n    \"datasets_router\",\n    \"auth_router\",\n    \"data_recipe_router\",\n    \"export_router\",\n]\n"
  },
  {
    "path": "studio/backend/routes/auth.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nAuthentication API routes\n\"\"\"\n\nfrom fastapi import APIRouter, Depends, HTTPException, status\n\nfrom models.auth import (\n    AuthLoginRequest,\n    RefreshTokenRequest,\n    AuthStatusResponse,\n    ChangePasswordRequest,\n)\nfrom models.users import Token\nfrom auth import storage, hashing\nfrom auth.authentication import (\n    create_access_token,\n    create_refresh_token,\n    get_current_subject,\n    get_current_subject_allow_password_change,\n    refresh_access_token,\n)\n\nrouter = APIRouter()\n\n\n@router.get(\"/status\", response_model = AuthStatusResponse)\nasync def auth_status() -> AuthStatusResponse:\n    \"\"\"\n    Check whether auth has already been initialized.\n\n    - initialized = False -> frontend should wait for the seeded admin bootstrap.\n    - initialized = True  -> frontend should show login or force the first password change.\n    \"\"\"\n    return AuthStatusResponse(\n        initialized = storage.is_initialized(),\n        default_username = storage.DEFAULT_ADMIN_USERNAME,\n        requires_password_change = storage.requires_password_change(\n            storage.DEFAULT_ADMIN_USERNAME\n        )\n        if storage.is_initialized()\n        else True,\n    )\n\n\n@router.post(\"/login\", response_model = Token)\nasync def login(payload: AuthLoginRequest) -> Token:\n    \"\"\"\n    Login with username/password and receive access + refresh tokens.\n    \"\"\"\n    record = storage.get_user_and_secret(payload.username)\n    if record is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.\",\n        )\n\n    salt, pwd_hash, _jwt_secret, must_change_password = record\n    if not hashing.verify_password(payload.password, salt, pwd_hash):\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Incorrect password. 
Run 'unsloth studio reset-password' in your terminal to reset it.\",\n        )\n\n    access_token = create_access_token(subject = payload.username)\n    refresh_token = create_refresh_token(subject = payload.username)\n    return Token(\n        access_token = access_token,\n        refresh_token = refresh_token,\n        token_type = \"bearer\",\n        must_change_password = must_change_password,\n    )\n\n\n@router.post(\"/refresh\", response_model = Token)\nasync def refresh(payload: RefreshTokenRequest) -> Token:\n    \"\"\"\n    Exchange a valid refresh token for a new access token.\n\n    The refresh token itself is reusable until it expires (7 days).\n    \"\"\"\n    new_access_token, username = refresh_access_token(payload.refresh_token)\n    if new_access_token is None or username is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Invalid or expired refresh token\",\n        )\n\n    return Token(\n        access_token = new_access_token,\n        refresh_token = payload.refresh_token,\n        token_type = \"bearer\",\n        must_change_password = storage.requires_password_change(username),\n    )\n\n\n@router.post(\"/change-password\", response_model = Token)\nasync def change_password(\n    payload: ChangePasswordRequest,\n    current_subject: str = Depends(get_current_subject_allow_password_change),\n) -> Token:\n    \"\"\"Allow the authenticated user to replace the default password.\"\"\"\n    record = storage.get_user_and_secret(current_subject)\n    if record is None:\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"User session is invalid\",\n        )\n\n    salt, pwd_hash, _jwt_secret, _must_change_password = record\n    if not hashing.verify_password(payload.current_password, salt, pwd_hash):\n        raise HTTPException(\n            status_code = status.HTTP_401_UNAUTHORIZED,\n            detail = \"Current password is incorrect\",\n        )\n    if payload.current_password == payload.new_password:\n        raise HTTPException(\n            status_code = status.HTTP_400_BAD_REQUEST,\n            detail = \"New password must be different from the current password\",\n        )\n\n    storage.update_password(current_subject, payload.new_password)\n    storage.revoke_user_refresh_tokens(current_subject)\n    access_token = create_access_token(subject = current_subject)\n    refresh_token = create_refresh_token(subject = current_subject)\n    return Token(\n        access_token = access_token,\n        refresh_token = refresh_token,\n        token_type = \"bearer\",\n        must_change_password = False,\n    )\n"
  },
  {
    "path": "studio/backend/routes/data_recipe/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Data Recipe route package.\"\"\"\n\nfrom __future__ import annotations\n\nimport sys\nfrom pathlib import Path\n\nfrom fastapi import APIRouter, Depends\n\nfrom auth.authentication import get_current_subject\n\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\nfrom .jobs import router as jobs_router\nfrom .mcp import router as mcp_router\nfrom .seed import router as seed_router\nfrom .validate import router as validate_router\n\nrouter = APIRouter(dependencies = [Depends(get_current_subject)])\nrouter.include_router(seed_router)\nrouter.include_router(validate_router)\nrouter.include_router(jobs_router)\nrouter.include_router(mcp_router)\n\n__all__ = [\"router\"]\n"
  },
  {
    "path": "studio/backend/routes/data_recipe/jobs.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Job lifecycle endpoints for data recipe.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom fastapi import APIRouter, HTTPException, Query, Request\nfrom fastapi.responses import JSONResponse, StreamingResponse\nfrom pydantic import ValidationError\n\nfrom core.data_recipe.huggingface import (\n    RecipeDatasetPublishError,\n    publish_recipe_dataset,\n)\nfrom core.data_recipe.jobs import get_job_manager\nfrom models.data_recipe import (\n    JobCreateResponse,\n    PublishDatasetRequest,\n    PublishDatasetResponse,\n    RecipePayload,\n)\n\nrouter = APIRouter()\n\n\ndef _normalize_run_name(value: Any) -> str | None:\n    if value is None:\n        return None\n    if not isinstance(value, str):\n        raise HTTPException(\n            status_code = 400, detail = \"invalid run_name: must be a string\"\n        )\n    trimmed = value.strip()\n    if not trimmed:\n        return None\n    return trimmed[:120]\n\n\n@router.post(\"/jobs\", response_class = JSONResponse, response_model = JobCreateResponse)\ndef create_job(payload: RecipePayload):\n    recipe = payload.recipe\n    if not recipe.get(\"columns\"):\n        raise HTTPException(status_code = 400, detail = \"Recipe must include columns.\")\n\n    run: dict[str, Any] = payload.run or {}\n    run.pop(\"artifact_path\", None)\n    run.pop(\"dataset_name\", None)\n    execution_type = str(run.get(\"execution_type\") or \"full\").strip().lower()\n    if execution_type not in {\"preview\", \"full\"}:\n        raise HTTPException(\n            status_code = 400,\n            detail = \"invalid execution_type: must be 'preview' or 'full'\",\n        )\n    run[\"execution_type\"] = execution_type\n    run[\"run_name\"] = _normalize_run_name(run.get(\"run_name\"))\n    run_config_raw = run.get(\"run_config\")\n    if run_config_raw is not None:\n        try:\n            from data_designer.config.run_config import RunConfig\n\n            RunConfig.model_validate(run_config_raw)\n        except (ImportError, ValidationError, TypeError, ValueError) as exc:\n            raise HTTPException(\n                status_code = 400, detail = f\"invalid run_config: {exc}\"\n            ) from exc\n\n    mgr = get_job_manager()\n    try:\n        job_id = mgr.start(recipe = recipe, run = run)\n    except RuntimeError as exc:\n        raise HTTPException(status_code = 409, detail = str(exc)) from exc\n    except ValueError as exc:\n        raise HTTPException(status_code = 400, detail = str(exc)) from exc\n\n    return {\"job_id\": job_id}\n\n\n@router.get(\"/jobs/{job_id}/status\")\ndef job_status(job_id: str):\n    mgr = get_job_manager()\n    state = mgr.get_status(job_id)\n    if state is None:\n        raise HTTPException(status_code = 404, detail = \"job not found\")\n    return state\n\n\n@router.get(\"/jobs/current\")\ndef current_job():\n    mgr = get_job_manager()\n    state = mgr.get_current_status()\n    if state is None:\n        raise HTTPException(status_code = 404, detail = \"no job\")\n    return state\n\n\n@router.post(\"/jobs/{job_id}/cancel\")\ndef cancel_job(job_id: str):\n    mgr = get_job_manager()\n    ok = mgr.cancel(job_id)\n    if not ok:\n        raise HTTPException(status_code = 404, detail = \"job not found\")\n    return mgr.get_status(job_id)\n\n\n@router.get(\"/jobs/{job_id}/analysis\")\ndef job_analysis(job_id: str):\n    
mgr = get_job_manager()\n    analysis = mgr.get_analysis(job_id)\n    if analysis is None:\n        raise HTTPException(status_code = 404, detail = \"analysis not ready\")\n    return analysis\n\n\n@router.get(\"/jobs/{job_id}/dataset\")\ndef job_dataset(\n    job_id: str,\n    limit: int = Query(default = 20, ge = 1, le = 500),\n    offset: int = Query(default = 0, ge = 0),\n):\n    mgr = get_job_manager()\n    result = mgr.get_dataset(job_id, limit = limit, offset = offset)\n    if result is None:\n        raise HTTPException(status_code = 404, detail = \"dataset not ready\")\n    if \"error\" in result:\n        raise HTTPException(status_code = 422, detail = result[\"error\"])\n    return {\n        \"dataset\": result[\"dataset\"],\n        \"total\": result[\"total\"],\n        \"limit\": limit,\n        \"offset\": offset,\n    }\n\n\n@router.post(\n    \"/jobs/{job_id}/publish\",\n    response_class = JSONResponse,\n    response_model = PublishDatasetResponse,\n)\ndef publish_job_dataset(job_id: str, payload: PublishDatasetRequest):\n    repo_id = payload.repo_id.strip()\n    description = payload.description.strip()\n    hf_token = payload.hf_token.strip() if isinstance(payload.hf_token, str) else None\n    artifact_path = (\n        payload.artifact_path.strip()\n        if isinstance(payload.artifact_path, str)\n        else None\n    )\n\n    if not repo_id:\n        raise HTTPException(status_code = 400, detail = \"repo_id is required\")\n    if not description:\n        raise HTTPException(status_code = 400, detail = \"description is required\")\n\n    mgr = get_job_manager()\n    status = mgr.get_status(job_id)\n    if status is not None:\n        if (\n            status.get(\"status\") != \"completed\"\n            or status.get(\"execution_type\") != \"full\"\n        ):\n            raise HTTPException(\n                status_code = 409,\n                detail = \"Only completed full runs can be published.\",\n            )\n        status_artifact = status.get(\"artifact_path\")\n        if isinstance(status_artifact, str) and status_artifact.strip():\n            artifact_path = status_artifact.strip()\n\n    if not artifact_path:\n        raise HTTPException(\n            status_code = 400,\n            detail = \"This execution does not have publishable dataset artifacts.\",\n        )\n\n    try:\n        url = publish_recipe_dataset(\n            artifact_path = artifact_path,\n            repo_id = repo_id,\n            description = description,\n            hf_token = hf_token or None,\n            private = payload.private,\n        )\n    except RecipeDatasetPublishError as exc:\n        raise HTTPException(status_code = 400, detail = str(exc)) from exc\n    except Exception as exc:\n        raise HTTPException(status_code = 500, detail = str(exc)) from exc\n\n    return {\n        \"success\": True,\n        \"url\": url,\n        \"message\": f\"Published dataset to {repo_id}.\",\n    }\n\n\n@router.get(\"/jobs/{job_id}/events\")\nasync def job_events(request: Request, job_id: str):\n    mgr = get_job_manager()\n    last_id = request.headers.get(\"last-event-id\")\n    after_seq: int | None = None\n    if last_id:\n        try:\n            after_seq = int(str(last_id).strip())\n        except (TypeError, ValueError):\n            after_seq = None\n\n    after_q = request.query_params.get(\"after\")\n    if after_q:\n        try:\n            after_seq = int(str(after_q).strip())\n        except (TypeError, ValueError):\n            pass\n\n    sub = 
mgr.subscribe(job_id, after_seq = after_seq)\n    if sub is None:\n        raise HTTPException(status_code = 404, detail = \"job not found\")\n\n    async def gen():\n        try:\n            for event in sub.replay:\n                yield sub.format_sse(event)\n\n            while True:\n                if await request.is_disconnected():\n                    break\n                event = await sub.next_event(timeout_sec = 1.0)\n                if event is None:\n                    continue\n                yield sub.format_sse(event)\n        finally:\n            mgr.unsubscribe(sub)\n\n    return StreamingResponse(gen(), media_type = \"text/event-stream\")\n"
  },
  {
    "path": "studio/backend/routes/data_recipe/mcp.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"MCP helper endpoints for data recipe.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\n\nfrom fastapi import APIRouter\n\nfrom core.data_recipe.service import build_mcp_providers\nfrom models.data_recipe import (\n    McpToolsListRequest,\n    McpToolsListResponse,\n    McpToolsProviderResult,\n)\n\nrouter = APIRouter()\n\n\n@router.post(\"/mcp/tools\", response_model = McpToolsListResponse)\ndef list_mcp_tools(payload: McpToolsListRequest) -> McpToolsListResponse:\n    try:\n        from data_designer.engine.mcp import io as mcp_io\n    except ImportError as exc:\n        return McpToolsListResponse(\n            providers = [\n                McpToolsProviderResult(\n                    name = \"\",\n                    error = f\"MCP dependencies unavailable: {exc}\",\n                )\n            ]\n        )\n\n    providers: list[McpToolsProviderResult] = []\n    tool_to_providers: dict[str, list[str]] = defaultdict(list)\n\n    for provider_payload in payload.mcp_providers:\n        provider_name = str(provider_payload.get(\"name\", \"\")).strip()\n        built = build_mcp_providers({\"mcp_providers\": [provider_payload]})\n        if len(built) != 1:\n            providers.append(\n                McpToolsProviderResult(\n                    name = provider_name,\n                    error = \"Unsupported MCP provider config.\",\n                )\n            )\n            continue\n\n        provider = built[0]\n        try:\n            tools = mcp_io.list_tools(provider, timeout_sec = payload.timeout_sec)\n            tool_names = sorted(\n                {tool.name for tool in tools if getattr(tool, \"name\", \"\")}\n            )\n            for tool_name in tool_names:\n                tool_to_providers[tool_name].append(provider.name)\n            providers.append(\n                McpToolsProviderResult(\n                    name = provider.name,\n                    tools = tool_names,\n                )\n            )\n        except Exception as exc:\n            providers.append(\n                McpToolsProviderResult(\n                    name = provider.name or provider_name,\n                    error = str(exc).strip() or \"Failed to load tools.\",\n                )\n            )\n\n    duplicate_tools = {\n        tool_name: provider_names\n        for tool_name, provider_names in sorted(tool_to_providers.items())\n        if len(provider_names) > 1\n    }\n\n    return McpToolsListResponse(\n        providers = providers,\n        duplicate_tools = duplicate_tools,\n    )\n"
  },
  {
    "path": "studio/backend/routes/data_recipe/seed.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Seed inspect endpoints for data recipe.\"\"\"\n\nfrom __future__ import annotations\n\nimport base64\nimport binascii\nfrom itertools import islice\nfrom pathlib import Path\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom fastapi import APIRouter, HTTPException\nfrom data_designer_unstructured_seed.chunking import (\n    build_unstructured_preview_rows,\n    resolve_chunking,\n)\nfrom core.data_recipe.jsonable import to_preview_jsonable\nfrom utils.paths import ensure_dir, seed_uploads_root\n\nfrom models.data_recipe import (\n    SeedInspectRequest,\n    SeedInspectResponse,\n    SeedInspectUploadRequest,\n)\n\nrouter = APIRouter()\n\nDATA_EXTS = (\".parquet\", \".jsonl\", \".json\", \".csv\")\nDEFAULT_SPLIT = \"train\"\nLOCAL_UPLOAD_EXTS = {\".csv\", \".json\", \".jsonl\"}\nUNSTRUCTURED_UPLOAD_EXTS = {\".txt\", \".md\"}\nSEED_UPLOAD_DIR = seed_uploads_root()\n\n\ndef _serialize_preview_value(value: Any) -> Any:\n    return to_preview_jsonable(value)\n\n\ndef _serialize_preview_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:\n    return [\n        {str(key): _serialize_preview_value(value) for key, value in row.items()}\n        for row in rows\n    ]\n\n\ndef _normalize_optional_text(value: str | None) -> str | None:\n    if value is None:\n        return None\n    trimmed = value.strip()\n    return trimmed if trimmed else None\n\n\ndef _list_hf_data_files(*, dataset_name: str, token: str | None) -> list[str]:\n    try:\n        from huggingface_hub import HfApi\n        from huggingface_hub.utils import HfHubHTTPError\n    except ImportError:\n        return []\n    try:\n        api = HfApi()\n        repo_files = api.list_repo_files(dataset_name, repo_type = \"dataset\", token = token)\n        return [file for file in repo_files if file.lower().endswith(DATA_EXTS)]\n    except (HfHubHTTPError, OSError, ValueError):\n        return []\n\n\ndef _select_best_file(data_files: list[str], split: str = DEFAULT_SPLIT) -> str | None:\n    if not data_files:\n        return None\n    split_lower = split.lower()\n\n    def score(path: str) -> tuple[int, int]:\n        name = path.lower()\n        if f\"/{split_lower}/\" in name:\n            return (0, len(path))\n        if (\n            f\"_{split_lower}.\" in name\n            or f\"-{split_lower}.\" in name\n            or f\"/{split_lower}.\" in name\n            or f\"/{split_lower}_\" in name\n            or f\"/{split_lower}-\" in name\n        ):\n            return (1, len(path))\n        return (2, len(path))\n\n    return sorted(data_files, key = score)[0]\n\n\ndef _resolve_seed_hf_path(\n    dataset_name: str, data_files: list[str], split: str = DEFAULT_SPLIT\n) -> str | None:\n    selected = _select_best_file(data_files, split)\n    if not selected:\n        return None\n\n    ext = Path(selected).suffix.lower()\n    if ext not in DATA_EXTS:\n        return f\"datasets/{dataset_name}/{selected}\"\n\n    parent = Path(selected).parent.as_posix()\n    if not parent or parent == \".\":\n        return f\"datasets/{dataset_name}/**/*{ext}\"\n    return f\"datasets/{dataset_name}/{parent}/**/*{ext}\"\n\n\ndef _build_stream_load_kwargs(\n    *,\n    dataset_name: str,\n    split: str,\n    subset: str | None,\n    token: str | None,\n    data_file: str | None = None,\n) -> dict[str, Any]:\n    kwargs: dict[str, Any] = {\n        \"path\": 
dataset_name,\n        \"split\": split,\n        \"streaming\": True,\n        \"trust_remote_code\": False,\n    }\n    if data_file:\n        kwargs[\"data_files\"] = [data_file]\n    if subset:\n        kwargs[\"name\"] = subset\n    if token:\n        kwargs[\"token\"] = token\n    return kwargs\n\n\ndef _load_preview_rows(\n    *,\n    load_dataset_fn,\n    load_kwargs: dict[str, Any],\n    preview_size: int,\n) -> list[dict[str, Any]]:\n    streamed_ds = load_dataset_fn(**load_kwargs)\n    return [row for row in islice(streamed_ds, preview_size)]\n\n\ndef _extract_columns(rows: list[dict[str, Any]]) -> list[str]:\n    columns_seen: dict[str, None] = {}\n    for row in rows:\n        for key in row.keys():\n            columns_seen[str(key)] = None\n    return list(columns_seen.keys())\n\n\ndef _sanitize_filename(filename: str) -> str:\n    name = Path(filename).name.strip().replace(\"\\x00\", \"\")\n    if not name:\n        return \"seed_upload\"\n    return name\n\n\ndef _decode_base64_payload(content_base64: str) -> bytes:\n    raw = content_base64.strip()\n    if \",\" in raw and raw.lower().startswith(\"data:\"):\n        raw = raw.split(\",\", 1)[1]\n    try:\n        return base64.b64decode(raw, validate = True)\n    except binascii.Error as exc:\n        raise HTTPException(status_code = 400, detail = \"invalid base64 payload\") from exc\n\n\ndef _read_preview_rows_from_local_file(\n    path: Path, preview_size: int\n) -> list[dict[str, Any]]:\n    try:\n        import pandas as pd\n    except ImportError as exc:\n        raise HTTPException(\n            status_code = 500, detail = f\"seed inspect dependencies unavailable: {exc}\"\n        ) from exc\n\n    ext = path.suffix.lower()\n    try:\n        if ext == \".csv\":\n            df = pd.read_csv(path, nrows = preview_size)\n        elif ext == \".jsonl\":\n            df = pd.read_json(path, lines = True).head(preview_size)\n        elif ext == \".json\":\n            try:\n                df = pd.read_json(path).head(preview_size)\n            except ValueError:\n                df = pd.read_json(path, lines = True).head(preview_size)\n        else:\n            raise HTTPException(status_code = 422, detail = f\"unsupported file type: {ext}\")\n    except HTTPException:\n        raise\n    except (ValueError, OSError) as exc:\n        raise HTTPException(\n            status_code = 422, detail = f\"seed inspect failed: {exc}\"\n        ) from exc\n\n    rows = df.to_dict(orient = \"records\")\n    return _serialize_preview_rows(rows)\n\n\ndef _read_preview_rows_from_unstructured_file(\n    *,\n    path: Path,\n    preview_size: int,\n    chunk_size: int | None,\n    chunk_overlap: int | None,\n) -> list[dict[str, Any]]:\n    size, overlap = resolve_chunking(chunk_size, chunk_overlap)\n    try:\n        rows = build_unstructured_preview_rows(\n            source_path = path,\n            preview_size = preview_size,\n            chunk_size = size,\n            chunk_overlap = overlap,\n        )\n    except (FileNotFoundError, RuntimeError, ValueError, OSError) as exc:\n        raise HTTPException(\n            status_code = 422, detail = f\"seed inspect failed: {exc}\"\n        ) from exc\n    return _serialize_preview_rows(rows)\n\n\n@router.post(\"/seed/inspect\", response_model = SeedInspectResponse)\ndef inspect_seed_dataset(payload: SeedInspectRequest) -> SeedInspectResponse:\n    dataset_name = payload.dataset_name.strip()\n    if not dataset_name or dataset_name.count(\"/\") < 1:\n        raise HTTPException(\n 
           status_code = 400,\n            detail = \"dataset_name must be a Hugging Face repo id like org/repo\",\n        )\n\n    try:\n        from datasets import load_dataset\n    except ImportError as exc:\n        raise HTTPException(\n            status_code = 500, detail = f\"seed inspect dependencies unavailable: {exc}\"\n        ) from exc\n\n    split = _normalize_optional_text(payload.split) or DEFAULT_SPLIT\n    subset = _normalize_optional_text(payload.subset)\n    token = _normalize_optional_text(payload.hf_token)\n    preview_size = int(payload.preview_size)\n\n    preview_rows: list[dict[str, Any]] = []\n    data_files = _list_hf_data_files(dataset_name = dataset_name, token = token)\n\n    selected_file = _select_best_file(data_files, split)\n    if selected_file:\n        try:\n            single_file_kwargs = _build_stream_load_kwargs(\n                dataset_name = dataset_name,\n                split = split,\n                subset = subset,\n                token = token,\n                data_file = selected_file,\n            )\n            preview_rows = _load_preview_rows(\n                load_dataset_fn = load_dataset,\n                load_kwargs = single_file_kwargs,\n                preview_size = preview_size,\n            )\n        except (ValueError, OSError, RuntimeError):\n            preview_rows = []\n\n    if not preview_rows:\n        try:\n            split_kwargs = _build_stream_load_kwargs(\n                dataset_name = dataset_name,\n                split = split,\n                subset = subset,\n                token = token,\n            )\n            preview_rows = _load_preview_rows(\n                load_dataset_fn = load_dataset,\n                load_kwargs = split_kwargs,\n                preview_size = preview_size,\n            )\n        except (ValueError, OSError, RuntimeError) as exc:\n            raise HTTPException(\n                status_code = 422, detail = f\"seed inspect failed: {exc}\"\n            ) from exc\n\n    if not preview_rows:\n        raise HTTPException(\n            status_code = 422, detail = \"dataset appears empty or unreadable\"\n        )\n    preview_rows = _serialize_preview_rows(preview_rows)\n    columns = _extract_columns(preview_rows)\n\n    if not data_files:\n        resolved_path = f\"datasets/{dataset_name}/**/*.parquet\"\n    else:\n        resolved_path = _resolve_seed_hf_path(dataset_name, data_files, split)\n        if not resolved_path:\n            raise HTTPException(\n                status_code = 422, detail = \"unable to resolve seed dataset path\"\n            )\n\n    return SeedInspectResponse(\n        dataset_name = dataset_name,\n        resolved_path = resolved_path,\n        columns = columns,\n        preview_rows = preview_rows,\n        split = split,\n        subset = subset,\n    )\n\n\n@router.post(\"/seed/inspect-upload\", response_model = SeedInspectResponse)\ndef inspect_seed_upload(payload: SeedInspectUploadRequest) -> SeedInspectResponse:\n    seed_source_type = _normalize_optional_text(payload.seed_source_type) or \"local\"\n    filename = _sanitize_filename(payload.filename)\n    ext = Path(filename).suffix.lower()\n    if seed_source_type == \"unstructured\":\n        if ext not in UNSTRUCTURED_UPLOAD_EXTS:\n            allowed = \", \".join(sorted(UNSTRUCTURED_UPLOAD_EXTS))\n            raise HTTPException(\n                status_code = 400,\n                detail = f\"unsupported file type: {ext}. 
allowed: {allowed}\",\n            )\n    else:\n        if ext not in LOCAL_UPLOAD_EXTS:\n            allowed = \", \".join(sorted(LOCAL_UPLOAD_EXTS))\n            raise HTTPException(\n                status_code = 400,\n                detail = f\"unsupported file type: {ext}. allowed: {allowed}\",\n            )\n\n    file_bytes = _decode_base64_payload(payload.content_base64)\n    if not file_bytes:\n        raise HTTPException(status_code = 400, detail = \"empty upload payload\")\n    max_size_bytes = 50 * 1024 * 1024\n    if len(file_bytes) > max_size_bytes:\n        raise HTTPException(status_code = 413, detail = \"file too large (max 50MB)\")\n\n    ensure_dir(SEED_UPLOAD_DIR)\n    stored_name = f\"{uuid4().hex}_{filename}\"\n    stored_path = SEED_UPLOAD_DIR / stored_name\n    stored_path.write_bytes(file_bytes)\n\n    if seed_source_type == \"unstructured\":\n        preview_rows = _read_preview_rows_from_unstructured_file(\n            path = stored_path,\n            preview_size = int(payload.preview_size),\n            chunk_size = payload.unstructured_chunk_size,\n            chunk_overlap = payload.unstructured_chunk_overlap,\n        )\n    else:\n        preview_rows = _read_preview_rows_from_local_file(\n            stored_path,\n            int(payload.preview_size),\n        )\n    if not preview_rows:\n        raise HTTPException(\n            status_code = 422, detail = \"dataset appears empty or unreadable\"\n        )\n    columns = _extract_columns(preview_rows)\n\n    return SeedInspectResponse(\n        dataset_name = filename,\n        resolved_path = str(stored_path),\n        columns = columns,\n        preview_rows = preview_rows,\n        split = None,\n        subset = None,\n    )\n"
  },
  {
    "path": "studio/backend/routes/data_recipe/validate.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Validation endpoints for data recipe.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom fastapi import APIRouter, HTTPException\n\nfrom core.data_recipe.service import (\n    build_config_builder,\n    create_data_designer,\n    validate_recipe,\n)\nfrom models.data_recipe import RecipePayload, ValidateError, ValidateResponse\n\nrouter = APIRouter()\n\n\ndef _collect_validation_errors(recipe: dict[str, Any]) -> list[ValidateError]:\n    try:\n        from data_designer.engine.compiler import (\n            _add_internal_row_id_column_if_needed,\n            _get_allowed_references,\n            _resolve_and_add_seed_columns,\n        )\n        from data_designer.engine.validation import (\n            ViolationLevel,\n            validate_data_designer_config,\n        )\n    except ImportError:\n        return []\n\n    try:\n        builder = build_config_builder(recipe)\n        designer = create_data_designer(recipe)\n        resource_provider = designer._create_resource_provider(  # type: ignore[attr-defined]\n            \"validate-configuration\",\n            builder,\n        )\n        config = builder.build()\n        _resolve_and_add_seed_columns(config, resource_provider.seed_reader)\n        _add_internal_row_id_column_if_needed(config)\n        violations = validate_data_designer_config(\n            columns = config.columns,\n            processor_configs = config.processors or [],\n            allowed_references = _get_allowed_references(config),\n        )\n    except (TypeError, ValueError, AttributeError):\n        return []\n\n    errors: list[ValidateError] = []\n    for violation in violations:\n        if violation.level != ViolationLevel.ERROR:\n            continue\n        code = getattr(violation.type, \"value\", None)\n        path = violation.column if violation.column else None\n        message = str(violation.message).strip() or \"Validation failed.\"\n        errors.append(\n            ValidateError(\n                message = message,\n                path = path,\n                code = code,\n            )\n        )\n    return errors\n\n\n@router.post(\"/validate\", response_model = ValidateResponse)\ndef validate(payload: RecipePayload) -> ValidateResponse:\n    recipe = payload.recipe\n    if not recipe.get(\"columns\"):\n        return ValidateResponse(\n            valid = False,\n            errors = [ValidateError(message = \"Recipe must include columns.\")],\n        )\n\n    try:\n        validate_recipe(recipe)\n    except RuntimeError as exc:\n        raise HTTPException(status_code = 503, detail = str(exc)) from exc\n    except Exception as exc:\n        detail = str(exc).strip() or \"Validation failed.\"\n        parsed_errors = _collect_validation_errors(recipe)\n        return ValidateResponse(\n            valid = False,\n            errors = parsed_errors or [ValidateError(message = detail)],\n            raw_detail = detail,\n        )\n\n    return ValidateResponse(valid = True)\n"
  },
  {
    "path": "studio/backend/routes/datasets.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nDatasets API routes\n\"\"\"\n\nimport base64\nimport io\nimport json\nimport sys\nfrom pathlib import Path\nfrom uuid import uuid4\nfrom fastapi import APIRouter, Depends, HTTPException, UploadFile\nimport structlog\nfrom loggers import get_logger\n\n# Add backend directory to path\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\n# Import dataset utilities\nfrom utils.datasets import check_dataset_format\nfrom auth.authentication import get_current_subject\n\nrouter = APIRouter()\nlogger = get_logger(__name__)\n\n\nfrom models.datasets import (\n    AiAssistMappingRequest,\n    AiAssistMappingResponse,\n    CheckFormatRequest,\n    CheckFormatResponse,\n    LocalDatasetItem,\n    LocalDatasetsResponse,\n    UploadDatasetResponse,\n)\nfrom utils.paths import (\n    dataset_uploads_root,\n    ensure_dir,\n    recipe_datasets_root,\n    resolve_dataset_path,\n)\n\n\ndef _serialize_preview_value(value):\n    \"\"\"make it json safe for client preview ⊂(◉‿◉)つ\"\"\"\n    if value is None or isinstance(value, (str, int, float, bool)):\n        return value\n\n    try:\n        from PIL.Image import Image as PILImage\n\n        if isinstance(value, PILImage):\n            buffer = io.BytesIO()\n            value.convert(\"RGB\").save(buffer, format = \"JPEG\", quality = 85)\n            return {\n                \"type\": \"image\",\n                \"mime\": \"image/jpeg\",\n                \"width\": value.width,\n                \"height\": value.height,\n                \"data\": base64.b64encode(buffer.getvalue()).decode(\"ascii\"),\n            }\n    except Exception:\n        pass\n\n    if isinstance(value, dict):\n        return {str(key): _serialize_preview_value(item) for key, item in value.items()}\n\n    if isinstance(value, (list, tuple)):\n        return [_serialize_preview_value(item) for item in value]\n\n    return str(value)\n\n\ndef _serialize_preview_rows(rows):\n    return [\n        {str(key): _serialize_preview_value(value) for key, value in dict(row).items()}\n        for row in rows\n    ]\n\n\n# --- Endpoints ---\n\n# Recognized data-file extensions for the single-file fallback approach.\n# Tabular formats are preferred over archives for Tier 1 preview because\n# archives (e.g. 
images.zip) may be loaded as ImageFolder datasets with\n# synthetic columns (image/label) that don't match the real dataset schema.\n_TABULAR_EXTS = (\".parquet\", \".json\", \".jsonl\", \".csv\", \".tsv\", \".arrow\")\n_ARCHIVE_EXTS = (\".tar\", \".tar.gz\", \".tgz\", \".gz\", \".zst\", \".zip\", \".txt\")\nDATA_EXTS = _TABULAR_EXTS + _ARCHIVE_EXTS\nLOCAL_FILE_EXTS = (\".json\", \".jsonl\", \".csv\", \".parquet\")\nLOCAL_UPLOAD_EXTS = {\".csv\", \".json\", \".jsonl\", \".parquet\"}\nLOCAL_DATASETS_ROOT = recipe_datasets_root()\nDATASET_UPLOAD_DIR = dataset_uploads_root()\n\n\ndef _safe_read_metadata(path: Path) -> dict | None:\n    try:\n        payload = json.loads(path.read_text(encoding = \"utf-8\"))\n    except (OSError, ValueError, TypeError):\n        return None\n    if not isinstance(payload, dict):\n        return None\n    return payload\n\n\ndef _safe_read_rows_from_metadata(payload: dict | None) -> int | None:\n    if not payload:\n        return None\n    for key in (\"actual_num_records\", \"target_num_records\"):\n        value = payload.get(key)\n        if isinstance(value, int):\n            return value\n    return None\n\n\ndef _safe_read_metadata_summary(payload: dict | None) -> dict | None:\n    if not payload:\n        return None\n\n    actual_num_records = (\n        payload.get(\"actual_num_records\")\n        if isinstance(payload.get(\"actual_num_records\"), int)\n        else None\n    )\n    target_num_records = (\n        payload.get(\"target_num_records\")\n        if isinstance(payload.get(\"target_num_records\"), int)\n        else actual_num_records\n    )\n\n    columns: list[str] | None = None\n    schema = payload.get(\"schema\")\n    if isinstance(schema, dict):\n        columns = [str(key) for key in schema.keys()]\n    if not columns:\n        stats = payload.get(\"column_statistics\")\n        if isinstance(stats, list):\n            derived = [\n                str(item.get(\"column_name\"))\n                for item in stats\n                if isinstance(item, dict) and item.get(\"column_name\")\n            ]\n            columns = derived or None\n\n    parquet_files_count = None\n    file_paths = payload.get(\"file_paths\")\n    if isinstance(file_paths, dict):\n        parquet_files = file_paths.get(\"parquet-files\")\n        if isinstance(parquet_files, list):\n            parquet_files_count = len(parquet_files)\n\n    total_num_batches = (\n        payload.get(\"total_num_batches\")\n        if isinstance(payload.get(\"total_num_batches\"), int)\n        else parquet_files_count\n    )\n    num_completed_batches = (\n        payload.get(\"num_completed_batches\")\n        if isinstance(payload.get(\"num_completed_batches\"), int)\n        else total_num_batches\n    )\n\n    return {\n        \"actual_num_records\": actual_num_records,\n        \"target_num_records\": target_num_records,\n        \"total_num_batches\": total_num_batches,\n        \"num_completed_batches\": num_completed_batches,\n        \"columns\": columns,\n    }\n\n\ndef _build_local_dataset_items() -> list[LocalDatasetItem]:\n    if not LOCAL_DATASETS_ROOT.exists():\n        return []\n\n    items: list[LocalDatasetItem] = []\n    for entry in LOCAL_DATASETS_ROOT.iterdir():\n        if not entry.is_dir() or not entry.name.startswith(\"recipe_\"):\n            continue\n        parquet_dir = entry / \"parquet-files\"\n        if not parquet_dir.exists() or not any(parquet_dir.glob(\"*.parquet\")):\n            continue\n\n        rows = None\n        
metadata_summary = None\n        metadata_path = entry / \"metadata.json\"\n        if metadata_path.exists():\n            metadata_payload = _safe_read_metadata(metadata_path)\n            rows = _safe_read_rows_from_metadata(metadata_payload)\n            metadata_summary = _safe_read_metadata_summary(metadata_payload)\n\n        try:\n            updated_at = entry.stat().st_mtime\n        except OSError:\n            updated_at = None\n\n        items.append(\n            LocalDatasetItem(\n                id = entry.name,\n                label = entry.name,\n                path = str(parquet_dir.resolve()),\n                rows = rows,\n                updated_at = updated_at,\n                metadata = metadata_summary,\n            )\n        )\n\n    items.sort(key = lambda item: item.updated_at or 0, reverse = True)\n    return items\n\n\ndef _load_local_preview_slice(\n    *, dataset_path: Path, train_split: str, preview_size: int\n):\n    from datasets import load_dataset\n\n    if dataset_path.is_dir():\n        parquet_dir = (\n            dataset_path / \"parquet-files\"\n            if (dataset_path / \"parquet-files\").exists()\n            else dataset_path\n        )\n        parquet_files = sorted(parquet_dir.glob(\"*.parquet\"))\n        if parquet_files:\n            dataset = load_dataset(\n                \"parquet\",\n                data_files = [str(path) for path in parquet_files],\n                split = train_split,\n            )\n            total_rows = len(dataset)\n            preview_slice = dataset.select(range(min(preview_size, total_rows)))\n            return preview_slice, total_rows\n        else:\n            candidate_files: list[Path] = []\n            for ext in LOCAL_FILE_EXTS:\n                candidate_files.extend(sorted(dataset_path.glob(f\"*{ext}\")))\n            if not candidate_files:\n                raise HTTPException(\n                    status_code = 400,\n                    detail = \"Unsupported local dataset directory (expected parquet/json/jsonl/csv files)\",\n                )\n            dataset_path = candidate_files[0]\n\n    if dataset_path.suffix in [\".json\", \".jsonl\"]:\n        dataset = load_dataset(\"json\", data_files = str(dataset_path), split = train_split)\n    elif dataset_path.suffix == \".csv\":\n        dataset = load_dataset(\"csv\", data_files = str(dataset_path), split = train_split)\n    elif dataset_path.suffix == \".parquet\":\n        dataset = load_dataset(\n            \"parquet\", data_files = str(dataset_path), split = train_split\n        )\n    else:\n        raise HTTPException(\n            status_code = 400, detail = f\"Unsupported file format: {dataset_path.suffix}\"\n        )\n\n    total_rows = len(dataset)\n    preview_slice = dataset.select(range(min(preview_size, total_rows)))\n    return preview_slice, total_rows\n\n\ndef _sanitize_filename(filename: str) -> str:\n    name = Path(filename).name.strip().replace(\"\\x00\", \"\")\n    if not name:\n        return \"dataset_upload\"\n    return name\n\n\n@router.post(\"/upload\", response_model = UploadDatasetResponse)\nasync def upload_dataset(\n    file: UploadFile,\n    current_subject: str = Depends(get_current_subject),\n) -> UploadDatasetResponse:\n    filename = _sanitize_filename(file.filename or \"dataset_upload\")\n    ext = Path(filename).suffix.lower()\n    if ext not in LOCAL_UPLOAD_EXTS:\n        allowed = \", \".join(sorted(LOCAL_UPLOAD_EXTS))\n        raise HTTPException(\n            status_code = 400,\n         
   detail = f\"Unsupported file type: {ext}. Allowed: {allowed}\",\n        )\n\n    ensure_dir(DATASET_UPLOAD_DIR)\n    stem = Path(filename).stem\n    stored_name = f\"{uuid4().hex}_{stem}{ext}\"\n    stored_path = DATASET_UPLOAD_DIR / stored_name\n\n    # Stream file to disk in chunks to avoid holding entire file in memory\n    with open(stored_path, \"wb\") as f:\n        while chunk := await file.read(1024 * 1024):\n            f.write(chunk)\n\n    if stored_path.stat().st_size == 0:\n        stored_path.unlink(missing_ok = True)\n        raise HTTPException(status_code = 400, detail = \"Empty upload payload\")\n\n    return UploadDatasetResponse(filename = filename, stored_path = str(stored_path))\n\n\n@router.get(\"/local\", response_model = LocalDatasetsResponse)\ndef list_local_datasets(\n    current_subject: str = Depends(get_current_subject),\n) -> LocalDatasetsResponse:\n    return LocalDatasetsResponse(datasets = _build_local_dataset_items())\n\n\n@router.post(\"/check-format\", response_model = CheckFormatResponse)\ndef check_format(\n    request: CheckFormatRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Check if a dataset requires manual column mapping.\n\n    Strategy for HuggingFace datasets:\n      1. list_repo_files → pick the first data file → load_dataset(data_files=[…])\n         Avoids resolving thousands of files; typically ~2-4 s.\n      2. Full streaming load_dataset as a last-resort fallback.\n\n    Local files are loaded directly.\n\n    Using a plain `def` (not async) so FastAPI runs this in a thread-pool,\n    preventing any blocking IO from freezing the event loop.\n    \"\"\"\n    try:\n        from itertools import islice\n        from datasets import Dataset, load_dataset\n        from utils.datasets import format_dataset\n\n        PREVIEW_SIZE = 10\n\n        logger.info(f\"Checking format for dataset: {request.dataset_name}\")\n\n        dataset_path = resolve_dataset_path(request.dataset_name)\n        total_rows = None\n\n        if dataset_path.exists():\n            # ── Local file ──────────────────────────────────────────\n            train_split = request.train_split or \"train\"\n            preview_slice, total_rows = _load_local_preview_slice(\n                dataset_path = dataset_path,\n                train_split = train_split,\n                preview_size = PREVIEW_SIZE,\n            )\n        else:\n            # ── HuggingFace dataset ─────────────────────────────────\n            # Tier 1: list_repo_files → load only the first data file\n            preview_slice = None\n\n            try:\n                from huggingface_hub import HfApi\n\n                api = HfApi()\n                repo_files = api.list_repo_files(\n                    request.dataset_name,\n                    repo_type = \"dataset\",\n                    token = request.hf_token or None,\n                )\n                data_files = [\n                    f for f in repo_files if any(f.endswith(ext) for ext in DATA_EXTS)\n                ]\n\n                # Prefer tabular formats over archives (e.g. 
images.zip → ImageFolder\n                # with synthetic image/label columns that don't match the real schema).\n                tabular_files = [\n                    f\n                    for f in data_files\n                    if any(f.endswith(ext) for ext in _TABULAR_EXTS)\n                ]\n                candidates = tabular_files or data_files\n\n                # When a subset is specified, narrow to files whose name matches\n                # (e.g. subset=\"testmini\" → prefer \"testmini.parquet\").\n                if request.subset and candidates:\n                    subset_matches = [\n                        f for f in candidates if request.subset in Path(f).stem\n                    ]\n                    if subset_matches:\n                        candidates = subset_matches\n\n                if candidates:\n                    first_file = candidates[0]\n                    logger.info(f\"Tier 1: loading single file {first_file}\")\n                    load_kwargs = {\n                        \"path\": request.dataset_name,\n                        \"data_files\": [first_file],\n                        \"split\": \"train\",\n                        \"streaming\": True,\n                    }\n                    if request.hf_token:\n                        load_kwargs[\"token\"] = request.hf_token\n\n                    streamed_ds = load_dataset(**load_kwargs)\n                    rows = list(islice(streamed_ds, PREVIEW_SIZE))\n                    if rows:\n                        preview_slice = Dataset.from_list(rows)\n            except Exception as e:\n                logger.warning(f\"Tier 1 (single-file) failed: {e}\")\n\n            if preview_slice is None:\n                # Tier 2: full streaming (resolves all files — slow for large repos)\n                logger.info(\"Tier 2: falling back to full streaming load_dataset\")\n                load_kwargs = {\n                    \"path\": request.dataset_name,\n                    \"split\": request.train_split,\n                    \"streaming\": True,\n                }\n                if request.subset:\n                    load_kwargs[\"name\"] = request.subset\n                if request.hf_token:\n                    load_kwargs[\"token\"] = request.hf_token\n\n                streamed_ds = load_dataset(**load_kwargs)\n\n                rows = list(islice(streamed_ds, PREVIEW_SIZE))\n                if not rows:\n                    raise HTTPException(\n                        status_code = 400,\n                        detail = \"Dataset appears to be empty or could not be streamed\",\n                    )\n\n                preview_slice = Dataset.from_list(rows)\n            total_rows = None\n\n        # Run lightweight format check on the preview slice\n        result = check_dataset_format(preview_slice, is_vlm = request.is_vlm)\n\n        logger.info(\n            f\"Format check result: requires_mapping={result['requires_manual_mapping']}, format={result['detected_format']}, is_image={result.get('is_image', False)}\"\n        )\n\n        # Generate preview samples\n        preview_samples = None\n        if not result[\"requires_manual_mapping\"]:\n            if result.get(\"suggested_mapping\"):\n                # Heuristic-detected: show raw data so columns match the API response.\n                # Processing (column stripping) happens at training time, not preview.\n                preview_samples = _serialize_preview_rows(preview_slice)\n            else:\n                try:\n    
                format_result = format_dataset(\n                        preview_slice,\n                        format_type = \"auto\",\n                        num_proc = 1,  # Only 10 preview rows — no need for multiprocessing\n                    )\n                    processed = format_result[\"dataset\"]\n                    preview_samples = _serialize_preview_rows(processed)\n                except Exception as e:\n                    logger.warning(\n                        f\"Processed preview generation failed (non-fatal): {e}\"\n                    )\n                    preview_samples = _serialize_preview_rows(preview_slice)\n        else:\n            preview_samples = _serialize_preview_rows(preview_slice)\n\n        # Collect warnings: from check_dataset_format + URL-based image detection\n        warning = result.get(\"warning\")\n        image_col = result.get(\"detected_image_column\")\n        if image_col and image_col in (result.get(\"columns\") or []):\n            try:\n                sample_val = preview_slice[0][image_col]\n                if isinstance(sample_val, str) and sample_val.startswith(\n                    (\"http://\", \"https://\")\n                ):\n                    url_warning = (\n                        \"This dataset contains image URLs instead of embedded images. \"\n                        \"Images will be downloaded during training, which may be slow for large datasets.\"\n                    )\n                    logger.info(f\"URL-based image column detected: {image_col}\")\n                    warning = f\"{warning} {url_warning}\" if warning else url_warning\n            except Exception:\n                pass\n\n        return CheckFormatResponse(\n            requires_manual_mapping = result[\"requires_manual_mapping\"],\n            detected_format = result[\"detected_format\"],\n            columns = result[\"columns\"],\n            is_image = result.get(\"is_image\", False),\n            is_audio = result.get(\"is_audio\", False),\n            multimodal_columns = result.get(\"multimodal_columns\"),\n            suggested_mapping = result.get(\"suggested_mapping\"),\n            detected_image_column = result.get(\"detected_image_column\"),\n            detected_audio_column = result.get(\"detected_audio_column\"),\n            detected_text_column = result.get(\"detected_text_column\"),\n            detected_speaker_column = result.get(\"detected_speaker_column\"),\n            preview_samples = preview_samples,\n            total_rows = total_rows,\n            warning = warning,\n        )\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error checking dataset format: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to check dataset format: {str(e)}\"\n        )\n\n\n@router.post(\"/ai-assist-mapping\", response_model = AiAssistMappingResponse)\ndef ai_assist_mapping(\n    request: AiAssistMappingRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Run LLM-assisted dataset conversion advisor (user-triggered).\n\n    Multi-pass analysis using a 7B helper model:\n      Pass 1: Classify dataset type from HF card + samples\n      Pass 2: Generate conversion strategy (system prompt, templates)\n      Pass 3: Validate conversion quality\n\n    Falls back to simple column classification if the advisor fails.\n    \"\"\"\n    try:\n        from utils.datasets.llm_assist import 
llm_conversion_advisor\n\n        # Truncate sample values for the LLM prompt\n        truncated = [\n            {col: str(s.get(col, \"\"))[:200] for col in request.columns}\n            for s in request.samples[:5]\n        ]\n\n        result = llm_conversion_advisor(\n            column_names = request.columns,\n            samples = truncated,\n            dataset_name = request.dataset_name,\n            hf_token = request.hf_token,\n            model_name = request.model_name,\n            model_type = request.model_type,\n        )\n\n        if result and result.get(\"success\"):\n            return AiAssistMappingResponse(\n                success = True,\n                suggested_mapping = result.get(\"suggested_mapping\"),\n                system_prompt = result.get(\"system_prompt\"),\n                user_template = result.get(\"user_template\"),\n                assistant_template = result.get(\"assistant_template\"),\n                label_mapping = result.get(\"label_mapping\"),\n                dataset_type = result.get(\"dataset_type\"),\n                is_conversational = result.get(\"is_conversational\"),\n                user_notification = result.get(\"user_notification\"),\n            )\n\n        return AiAssistMappingResponse(\n            success = False,\n            warning = \"AI could not determine column roles. Please assign them manually.\",\n        )\n\n    except Exception as e:\n        logger.error(f\"AI assist mapping failed: {e}\", exc_info = True)\n        raise HTTPException(status_code = 500, detail = f\"AI assist failed: {str(e)}\")\n"
  },
  {
    "path": "studio/backend/routes/export.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nExport API routes: checkpoint discovery and model export operations.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\nfrom fastapi import APIRouter, Depends, HTTPException, Query\nimport structlog\nfrom loggers import get_logger\n\n# Add backend directory to path\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\n# Auth\nfrom auth.authentication import get_current_subject\n\n# Import backend functions\ntry:\n    from core.export import get_export_backend\nexcept ImportError:\n    parent_backend = backend_path.parent / \"backend\"\n    if str(parent_backend) not in sys.path:\n        sys.path.insert(0, str(parent_backend))\n    from core.export import get_export_backend\n\n# Import Pydantic models\nfrom models import (\n    LoadCheckpointRequest,\n    ExportStatusResponse,\n    ExportOperationResponse,\n    ExportMergedModelRequest,\n    ExportBaseModelRequest,\n    ExportGGUFRequest,\n    ExportLoRAAdapterRequest,\n)\n\nrouter = APIRouter()\nlogger = get_logger(__name__)\n\n\n@router.post(\"/load-checkpoint\", response_model = ExportOperationResponse)\nasync def load_checkpoint(\n    request: LoadCheckpointRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Load a checkpoint into the export backend.\n\n    Wraps ExportBackend.load_checkpoint.\n    \"\"\"\n    try:\n        # Version switching is handled automatically by the subprocess-based\n        # export backend — no need for ensure_transformers_version() here.\n\n        # Free GPU memory: shut down any running inference/training subprocesses\n        # before loading the export checkpoint (they'd compete for VRAM).\n        try:\n            from core.inference import get_inference_backend\n\n            inf = get_inference_backend()\n            if inf.active_model_name:\n                logger.info(\n                    \"Unloading inference model '%s' to free GPU memory for export\",\n                    inf.active_model_name,\n                )\n                inf._shutdown_subprocess()\n                inf.active_model_name = None\n                inf.models.clear()\n        except Exception as e:\n            logger.warning(\"Could not unload inference model: %s\", e)\n\n        try:\n            from core.training import get_training_backend\n\n            trn = get_training_backend()\n            if trn.is_training_active():\n                logger.info(\"Stopping active training to free GPU memory for export\")\n                trn.stop_training()\n                # Wait for training subprocess to actually exit before proceeding,\n                # otherwise it may still hold GPU memory when export tries to load.\n                for _ in range(60):  # up to 30s\n                    if not trn.is_training_active():\n                        break\n                    import time\n\n                    time.sleep(0.5)\n                else:\n                    logger.warning(\n                        \"Training subprocess did not exit within 30s, proceeding anyway\"\n                    )\n        except Exception as e:\n            logger.warning(\"Could not stop training: %s\", e)\n\n        backend = get_export_backend()\n        success, message = backend.load_checkpoint(\n            checkpoint_path = request.checkpoint_path,\n       
     max_seq_length = request.max_seq_length,\n            load_in_4bit = request.load_in_4bit,\n            trust_remote_code = request.trust_remote_code,\n        )\n\n        if not success:\n            raise HTTPException(status_code = 400, detail = message)\n\n        return ExportOperationResponse(success = True, message = message)\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error loading checkpoint: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to load checkpoint: {str(e)}\",\n        )\n\n\n@router.post(\"/cleanup\", response_model = ExportOperationResponse)\nasync def cleanup_export_memory(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Cleanup export-related models from memory (GPU/CPU).\n\n    Wraps ExportBackend.cleanup_memory.\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        success = backend.cleanup_memory()\n\n        if not success:\n            raise HTTPException(\n                status_code = 500,\n                detail = \"Memory cleanup failed. See server logs for details.\",\n            )\n\n        return ExportOperationResponse(\n            success = True,\n            message = \"Memory cleanup completed successfully\",\n        )\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error during export memory cleanup: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to cleanup export memory: {str(e)}\",\n        )\n\n\n@router.get(\"/status\", response_model = ExportStatusResponse)\nasync def get_export_status(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get current export backend status (loaded checkpoint, model type, PEFT flag).\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        return ExportStatusResponse(\n            current_checkpoint = backend.current_checkpoint,\n            is_vision = bool(getattr(backend, \"is_vision\", False)),\n            is_peft = bool(getattr(backend, \"is_peft\", False)),\n        )\n    except Exception as e:\n        logger.error(f\"Error getting export status: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to get export status: {str(e)}\",\n        )\n\n\n@router.post(\"/export/merged\", response_model = ExportOperationResponse)\nasync def export_merged_model(\n    request: ExportMergedModelRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Export a merged PEFT model (e.g., 16-bit or 4-bit) and optionally push to Hub.\n\n    Wraps ExportBackend.export_merged_model.\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        success, message = backend.export_merged_model(\n            save_directory = request.save_directory,\n            format_type = request.format_type,\n            push_to_hub = request.push_to_hub,\n            repo_id = request.repo_id,\n            hf_token = request.hf_token,\n            private = request.private,\n        )\n\n        if not success:\n            raise HTTPException(status_code = 400, detail = message)\n\n        return ExportOperationResponse(success = True, message = message)\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error exporting merged model: {e}\", exc_info = True)\n     
   raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to export merged model: {str(e)}\",\n        )\n\n\n@router.post(\"/export/base\", response_model = ExportOperationResponse)\nasync def export_base_model(\n    request: ExportBaseModelRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Export a non-PEFT base model and optionally push to Hub.\n\n    Wraps ExportBackend.export_base_model.\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        success, message = backend.export_base_model(\n            save_directory = request.save_directory,\n            push_to_hub = request.push_to_hub,\n            repo_id = request.repo_id,\n            hf_token = request.hf_token,\n            private = request.private,\n            base_model_id = request.base_model_id,\n        )\n\n        if not success:\n            raise HTTPException(status_code = 400, detail = message)\n\n        return ExportOperationResponse(success = True, message = message)\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error exporting base model: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to export base model: {str(e)}\",\n        )\n\n\n@router.post(\"/export/gguf\", response_model = ExportOperationResponse)\nasync def export_gguf(\n    request: ExportGGUFRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Export the current model to GGUF format and optionally push to Hub.\n\n    Wraps ExportBackend.export_gguf.\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        success, message = backend.export_gguf(\n            save_directory = request.save_directory,\n            quantization_method = request.quantization_method,\n            push_to_hub = request.push_to_hub,\n            repo_id = request.repo_id,\n            hf_token = request.hf_token,\n        )\n\n        if not success:\n            raise HTTPException(status_code = 400, detail = message)\n\n        return ExportOperationResponse(success = True, message = message)\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error exporting GGUF model: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to export GGUF model: {str(e)}\",\n        )\n\n\n@router.post(\"/export/lora\", response_model = ExportOperationResponse)\nasync def export_lora_adapter(\n    request: ExportLoRAAdapterRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Export only the LoRA adapter (if the loaded model is PEFT).\n\n    Wraps ExportBackend.export_lora_adapter.\n    \"\"\"\n    try:\n        backend = get_export_backend()\n        success, message = backend.export_lora_adapter(\n            save_directory = request.save_directory,\n            push_to_hub = request.push_to_hub,\n            repo_id = request.repo_id,\n            hf_token = request.hf_token,\n            private = request.private,\n        )\n\n        if not success:\n            raise HTTPException(status_code = 400, detail = message)\n\n        return ExportOperationResponse(success = True, message = message)\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error exporting LoRA adapter: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 
500,\n            detail = f\"Failed to export LoRA adapter: {str(e)}\",\n        )\n"
  },
  {
    "path": "studio/backend/routes/inference.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference API routes for model loading and text generation.\n\"\"\"\n\nimport sys\nimport time\nimport uuid\nfrom pathlib import Path\nfrom fastapi import APIRouter, Depends, HTTPException, Request\nfrom fastapi.responses import StreamingResponse, JSONResponse\nfrom typing import Optional\nimport json\nimport structlog\nfrom loggers import get_logger\nimport asyncio\nimport threading\n\n\n# Add backend directory to path\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\n# Import backend functions\ntry:\n    from core.inference import get_inference_backend\n    from core.inference.llama_cpp import LlamaCppBackend\n    from utils.models import ModelConfig\n    from utils.inference import load_inference_config\n    from utils.models.model_config import load_model_defaults\nexcept ImportError:\n    parent_backend = backend_path.parent / \"backend\"\n    if str(parent_backend) not in sys.path:\n        sys.path.insert(0, str(parent_backend))\n    from core.inference import get_inference_backend\n    from core.inference.llama_cpp import LlamaCppBackend\n    from utils.models import ModelConfig\n    from utils.inference import load_inference_config\n    from utils.models.model_config import load_model_defaults\n\nfrom models.inference import (\n    LoadRequest,\n    UnloadRequest,\n    GenerateRequest,\n    LoadResponse,\n    UnloadResponse,\n    InferenceStatusResponse,\n    ChatCompletionRequest,\n    ChatCompletionChunk,\n    ChatCompletion,\n    ChunkChoice,\n    ChoiceDelta,\n    CompletionChoice,\n    CompletionMessage,\n    ValidateModelRequest,\n    ValidateModelResponse,\n)\nfrom auth.authentication import get_current_subject\n\nimport io\nimport wave\nimport base64\nimport numpy as np\n\nrouter = APIRouter()\nlogger = get_logger(__name__)\n\n\n# GGUF inference backend (llama-server)\n_llama_cpp_backend = LlamaCppBackend()\n\n\ndef get_llama_cpp_backend() -> LlamaCppBackend:\n    return _llama_cpp_backend\n\n\n@router.post(\"/load\", response_model = LoadResponse)\nasync def load_model(\n    request: LoadRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Load a model for inference.\n\n    The model_path should be a clean identifier from GET /models/list.\n    Returns inference configuration parameters (temperature, top_p, top_k, min_p)\n    from the model's YAML config, falling back to default.yaml for missing values.\n\n    GGUF models are loaded via llama-server (llama.cpp) instead of Unsloth.\n    \"\"\"\n    try:\n        # Version switching is handled automatically by the subprocess-based\n        # inference backend — no need for ensure_transformers_version() here.\n\n        # ── Already-loaded check: skip reload if the exact model is active ──\n        backend = get_inference_backend()\n        llama_backend = get_llama_cpp_backend()\n\n        if request.gguf_variant:\n            if (\n                llama_backend.is_loaded\n                and llama_backend.hf_variant\n                and llama_backend.hf_variant.lower() == request.gguf_variant.lower()\n                and llama_backend.model_identifier\n                and llama_backend.model_identifier.lower() == request.model_path.lower()\n            ):\n                logger.info(\n                    f\"Model already loaded (GGUF): 
{request.model_path} variant={request.gguf_variant}, skipping reload\"\n                )\n                inference_config = load_inference_config(llama_backend.model_identifier)\n                from utils.models import is_audio_input_type\n\n                _gguf_audio = (\n                    llama_backend._audio_type\n                    if hasattr(llama_backend, \"_audio_type\")\n                    else None\n                )\n                _gguf_is_audio = getattr(llama_backend, \"_is_audio\", False)\n                return LoadResponse(\n                    status = \"already_loaded\",\n                    model = llama_backend.model_identifier,\n                    display_name = llama_backend.model_identifier,\n                    is_vision = llama_backend._is_vision,\n                    is_lora = False,\n                    is_gguf = True,\n                    is_audio = _gguf_is_audio,\n                    audio_type = _gguf_audio,\n                    has_audio_input = is_audio_input_type(_gguf_audio)\n                    if _gguf_audio\n                    else False,\n                    inference = inference_config,\n                    context_length = llama_backend.context_length,\n                    supports_reasoning = llama_backend.supports_reasoning,\n                    chat_template = llama_backend.chat_template,\n                )\n        else:\n            if (\n                backend.active_model_name\n                and backend.active_model_name.lower() == request.model_path.lower()\n            ):\n                logger.info(\n                    f\"Model already loaded (Unsloth): {request.model_path}, skipping reload\"\n                )\n                inference_config = load_inference_config(backend.active_model_name)\n                _model_info = backend.models.get(backend.active_model_name, {})\n                _chat_template = None\n                try:\n                    _tpl_info = _model_info.get(\"chat_template_info\", {})\n                    _chat_template = _tpl_info.get(\"template\")\n                except Exception as e:\n                    logger.warning(\n                        f\"Could not retrieve chat template for {backend.active_model_name}: {e}\"\n                    )\n                return LoadResponse(\n                    status = \"already_loaded\",\n                    model = backend.active_model_name,\n                    display_name = backend.active_model_name,\n                    is_vision = _model_info.get(\"is_vision\", False),\n                    is_lora = _model_info.get(\"is_lora\", False),\n                    is_gguf = False,\n                    is_audio = _model_info.get(\"is_audio\", False),\n                    audio_type = _model_info.get(\"audio_type\"),\n                    has_audio_input = _model_info.get(\"has_audio_input\", False),\n                    inference = inference_config,\n                    chat_template = _chat_template,\n                )\n\n        # Create config using clean factory method\n        # is_lora is auto-detected from adapter_config.json on disk/HF\n        config = ModelConfig.from_identifier(\n            model_id = request.model_path,\n            hf_token = request.hf_token,\n            gguf_variant = request.gguf_variant,\n        )\n\n        if not config:\n            raise HTTPException(\n                status_code = 400,\n                detail = f\"Invalid model identifier: {request.model_path}\",\n            )\n\n        # ── GGUF path: load via 
llama-server ──────────────────────\n        if config.is_gguf:\n            llama_backend = get_llama_cpp_backend()\n            unsloth_backend = get_inference_backend()\n\n            # Unload any active Unsloth model first to free VRAM\n            if unsloth_backend.active_model_name:\n                logger.info(\n                    f\"Unloading Unsloth model '{unsloth_backend.active_model_name}' before loading GGUF\"\n                )\n                unsloth_backend.unload_model(unsloth_backend.active_model_name)\n\n            # Route to HF mode or local mode based on config\n            # Run in a thread so the event loop stays free for progress\n            # polling and other requests during the (potentially long)\n            # GGUF download + llama-server startup.\n            if config.gguf_hf_repo:\n                # HF mode: download via huggingface_hub then start llama-server\n                success = await asyncio.to_thread(\n                    llama_backend.load_model,\n                    hf_repo = config.gguf_hf_repo,\n                    hf_variant = config.gguf_variant,\n                    hf_token = request.hf_token,\n                    model_identifier = config.identifier,\n                    is_vision = config.is_vision,\n                    n_ctx = request.max_seq_length,\n                    chat_template_override = request.chat_template_override,\n                    cache_type_kv = request.cache_type_kv,\n                )\n            else:\n                # Local mode: llama-server loads via -m <path>\n                success = await asyncio.to_thread(\n                    llama_backend.load_model,\n                    gguf_path = config.gguf_file,\n                    mmproj_path = config.gguf_mmproj_file,\n                    model_identifier = config.identifier,\n                    is_vision = config.is_vision,\n                    n_ctx = request.max_seq_length,\n                    chat_template_override = request.chat_template_override,\n                    cache_type_kv = request.cache_type_kv,\n                )\n\n            if not success:\n                raise HTTPException(\n                    status_code = 500,\n                    detail = f\"Failed to load GGUF model: {config.display_name}\",\n                )\n\n            logger.info(f\"Loaded GGUF model via llama-server: {config.identifier}\")\n\n            # Detect TTS audio by probing the loaded model's vocabulary\n            from utils.models import is_audio_input_type\n\n            _gguf_audio = llama_backend.detect_audio_type()\n            _gguf_is_audio = _gguf_audio in (\"snac\", \"bicodec\", \"dac\")\n            llama_backend._is_audio = _gguf_is_audio\n            llama_backend._audio_type = _gguf_audio\n            if _gguf_is_audio:\n                logger.info(f\"GGUF model detected as audio: audio_type={_gguf_audio}\")\n                await asyncio.to_thread(llama_backend.init_audio_codec, _gguf_audio)\n\n            inference_config = load_inference_config(config.identifier)\n\n            return LoadResponse(\n                status = \"loaded\",\n                model = config.identifier,\n                display_name = config.display_name,\n                is_vision = config.is_vision,\n                is_lora = False,\n                is_gguf = True,\n                is_audio = _gguf_is_audio,\n                audio_type = _gguf_audio,\n                has_audio_input = is_audio_input_type(_gguf_audio),\n                inference = inference_config,\n  
              context_length = llama_backend.context_length,\n                supports_reasoning = llama_backend.supports_reasoning,\n                supports_tools = llama_backend.supports_tools,\n                cache_type_kv = llama_backend.cache_type_kv,\n                chat_template = llama_backend.chat_template,\n            )\n\n        # ── Standard path: load via Unsloth/transformers ──────────\n        backend = get_inference_backend()\n\n        # Unload any active GGUF model first\n        llama_backend = get_llama_cpp_backend()\n        if llama_backend.is_loaded:\n            logger.info(\"Unloading GGUF model before loading Unsloth model\")\n            llama_backend.unload_model()\n\n        # Shut down any export subprocess to free VRAM\n        try:\n            from core.export import get_export_backend\n\n            exp_backend = get_export_backend()\n            if exp_backend.current_checkpoint:\n                logger.info(\n                    \"Shutting down export subprocess to free GPU memory for inference\"\n                )\n                exp_backend._shutdown_subprocess()\n                exp_backend.current_checkpoint = None\n                exp_backend.is_vision = False\n                exp_backend.is_peft = False\n        except Exception as e:\n            logger.warning(\"Could not shut down export subprocess: %s\", e)\n\n        # Auto-detect quantization for LoRA adapters from adapter_config.json\n        # The training pipeline patches this file with \"unsloth_training_method\"\n        # which is 'qlora' or 'lora'. Only LoRA (16-bit) needs load_in_4bit=False.\n        load_in_4bit = request.load_in_4bit\n        if config.is_lora and config.path:\n            import json\n            from pathlib import Path\n\n            adapter_cfg_path = Path(config.path) / \"adapter_config.json\"\n            if adapter_cfg_path.exists():\n                try:\n                    with open(adapter_cfg_path) as f:\n                        adapter_cfg = json.load(f)\n                    training_method = adapter_cfg.get(\"unsloth_training_method\")\n                    if training_method == \"lora\" and load_in_4bit:\n                        logger.info(\n                            f\"adapter_config.json says unsloth_training_method='lora' — \"\n                            f\"setting load_in_4bit=False to match 16-bit training\"\n                        )\n                        load_in_4bit = False\n                    elif training_method == \"qlora\" and not load_in_4bit:\n                        logger.info(\n                            f\"adapter_config.json says unsloth_training_method='qlora' — \"\n                            f\"setting load_in_4bit=True to match QLoRA training\"\n                        )\n                        load_in_4bit = True\n                    elif training_method:\n                        logger.info(\n                            f\"Training method: {training_method}, load_in_4bit={load_in_4bit}\"\n                        )\n                    else:\n                        # No unsloth_training_method — fallback to base model name\n                        if (\n                            config.base_model\n                            and \"-bnb-4bit\" not in config.base_model.lower()\n                            and load_in_4bit\n                        ):\n                            logger.info(\n                                f\"No unsloth_training_method in adapter_config.json. 
\"\n                                f\"Base model '{config.base_model}' has no -bnb-4bit suffix — \"\n                                f\"setting load_in_4bit=False\"\n                            )\n                            load_in_4bit = False\n                except Exception as e:\n                    logger.warning(f\"Could not read adapter_config.json: {e}\")\n\n        # Load the model in a thread so the event loop stays free\n        # for download progress polling and other requests.\n        success = await asyncio.to_thread(\n            backend.load_model,\n            config = config,\n            max_seq_length = request.max_seq_length,\n            load_in_4bit = load_in_4bit,\n            hf_token = request.hf_token,\n            trust_remote_code = request.trust_remote_code,\n        )\n\n        if not success:\n            # Check if YAML says this model needs trust_remote_code\n            if not request.trust_remote_code:\n                model_defaults = load_model_defaults(config.identifier)\n                yaml_trust = model_defaults.get(\"inference\", {}).get(\n                    \"trust_remote_code\", False\n                )\n                if yaml_trust:\n                    raise HTTPException(\n                        status_code = 400,\n                        detail = (\n                            f\"Model '{config.display_name}' requires trust_remote_code to be enabled. \"\n                            f\"Please enable 'Trust remote code' in Chat Settings and try again.\"\n                        ),\n                    )\n            raise HTTPException(\n                status_code = 500, detail = f\"Failed to load model: {config.display_name}\"\n            )\n\n        logger.info(f\"Loaded model: {config.identifier}\")\n\n        # Load inference configuration parameters\n        inference_config = load_inference_config(config.identifier)\n\n        # Get chat template from tokenizer\n        _chat_template = None\n        try:\n            _model_info = backend.models.get(config.identifier, {})\n            _tpl_info = _model_info.get(\"chat_template_info\", {})\n            _chat_template = _tpl_info.get(\"template\")\n        except Exception:\n            pass\n\n        return LoadResponse(\n            status = \"loaded\",\n            model = config.identifier,\n            display_name = config.display_name,\n            is_vision = config.is_vision,\n            is_lora = config.is_lora,\n            is_gguf = False,\n            is_audio = config.is_audio,\n            audio_type = config.audio_type,\n            has_audio_input = config.has_audio_input,\n            inference = inference_config,\n            chat_template = _chat_template,\n        )\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error loading model: {e}\", exc_info = True)\n        msg = str(e)\n        # Surface a friendlier message for models that Unsloth cannot load\n        not_supported_hints = [\n            \"No config file found\",\n            \"not yet supported\",\n            \"is not supported\",\n            \"does not support\",\n        ]\n        if any(h.lower() in msg.lower() for h in not_supported_hints):\n            msg = f\"This model is not supported yet. Try a different model. 
(Original error: {msg})\"\n        raise HTTPException(status_code = 500, detail = f\"Failed to load model: {msg}\")\n\n\n@router.post(\"/validate\", response_model = ValidateModelResponse)\nasync def validate_model(\n    request: ValidateModelRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Lightweight validation endpoint for model identifiers.\n\n    This checks that ModelConfig.from_identifier() can resolve the given\n    model_path, but it does NOT actually load model weights into GPU memory.\n    \"\"\"\n    try:\n        config = ModelConfig.from_identifier(\n            model_id = request.model_path,\n            hf_token = request.hf_token,\n            gguf_variant = request.gguf_variant,\n        )\n\n        if not config:\n            raise HTTPException(\n                status_code = 400,\n                detail = f\"Invalid model identifier: {request.model_path}\",\n            )\n\n        return ValidateModelResponse(\n            valid = True,\n            message = \"Model identifier is valid.\",\n            identifier = config.identifier,\n            display_name = getattr(config, \"display_name\", config.identifier),\n            is_gguf = getattr(config, \"is_gguf\", False),\n            is_lora = getattr(config, \"is_lora\", False),\n            is_vision = getattr(config, \"is_vision\", False),\n        )\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(\n            f\"Error validating model identifier '{request.model_path}': {e}\",\n            exc_info = True,\n        )\n        raise HTTPException(\n            status_code = 400,\n            detail = f\"Invalid model: {str(e)}\",\n        )\n\n\n@router.post(\"/unload\", response_model = UnloadResponse)\nasync def unload_model(\n    request: UnloadRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Unload a model from memory.\n    Routes to the correct backend (llama-server for GGUF, Unsloth otherwise).\n    \"\"\"\n    try:\n        # Check if the GGUF backend has this model loaded or is loading it\n        llama_backend = get_llama_cpp_backend()\n        if llama_backend.is_active and (\n            llama_backend.model_identifier == request.model_path\n            or not llama_backend.is_loaded\n        ):\n            llama_backend.unload_model()\n            logger.info(f\"Unloaded GGUF model: {request.model_path}\")\n            return UnloadResponse(status = \"unloaded\", model = request.model_path)\n\n        # Otherwise, unload from Unsloth backend\n        backend = get_inference_backend()\n        backend.unload_model(request.model_path)\n        logger.info(f\"Unloaded model: {request.model_path}\")\n        return UnloadResponse(status = \"unloaded\", model = request.model_path)\n\n    except Exception as e:\n        logger.error(f\"Error unloading model: {e}\", exc_info = True)\n        raise HTTPException(status_code = 500, detail = f\"Failed to unload model: {str(e)}\")\n\n\n@router.post(\"/generate/stream\")\nasync def generate_stream(\n    request: GenerateRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Generate a chat response with Server-Sent Events (SSE) streaming.\n\n    For vision models, provide image_base64 with the base64-encoded image.\n    \"\"\"\n    backend = get_inference_backend()\n\n    if not backend.active_model_name:\n        raise HTTPException(\n            status_code = 400, detail = \"No model loaded. 
Call POST /inference/load first.\"\n        )\n\n    # Decode image if provided (for vision models)\n    image = None\n    if request.image_base64:\n        try:\n            import base64\n            from PIL import Image\n            from io import BytesIO\n\n            # Check if current model supports vision\n            model_info = backend.models.get(backend.active_model_name, {})\n            if not model_info.get(\"is_vision\"):\n                raise HTTPException(\n                    status_code = 400,\n                    detail = \"Image provided but current model is text-only. Load a vision model.\",\n                )\n\n            image_data = base64.b64decode(request.image_base64)\n            image = Image.open(BytesIO(image_data))\n            image = backend.resize_image(image)\n\n        except HTTPException:\n            raise\n        except Exception as e:\n            raise HTTPException(\n                status_code = 400, detail = f\"Failed to decode image: {str(e)}\"\n            )\n\n    async def stream():\n        try:\n            for chunk in backend.generate_chat_response(\n                messages = request.messages,\n                system_prompt = request.system_prompt,\n                image = image,\n                temperature = request.temperature,\n                top_p = request.top_p,\n                top_k = request.top_k,\n                max_new_tokens = request.max_new_tokens,\n                repetition_penalty = request.repetition_penalty,\n            ):\n                yield f\"data: {json.dumps({'content': chunk})}\\n\\n\"\n            yield \"data: [DONE]\\n\\n\"\n\n        except Exception as e:\n            backend.reset_generation_state()\n            logger.error(f\"Error during generation: {e}\", exc_info = True)\n            yield f\"data: {json.dumps({'error': 'An internal error occurred'})}\\n\\n\"\n\n    return StreamingResponse(\n        stream(),\n        media_type = \"text/event-stream\",\n        headers = {\n            \"Cache-Control\": \"no-cache\",\n            \"Connection\": \"keep-alive\",\n        },\n    )\n\n\n@router.get(\"/status\", response_model = InferenceStatusResponse)\nasync def get_status(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get current inference backend status.\n    Reports whichever backend (Unsloth or llama-server) is currently active.\n    \"\"\"\n    try:\n        llama_backend = get_llama_cpp_backend()\n\n        # If a GGUF model is loaded via llama-server, report that\n        if llama_backend.is_loaded:\n            _model_id = llama_backend.model_identifier\n            _inference_cfg = load_inference_config(_model_id) if _model_id else None\n            return InferenceStatusResponse(\n                active_model = _model_id,\n                is_vision = llama_backend.is_vision,\n                is_gguf = True,\n                gguf_variant = llama_backend.hf_variant,\n                is_audio = getattr(llama_backend, \"_is_audio\", False),\n                audio_type = getattr(llama_backend, \"_audio_type\", None),\n                loading = [],\n                loaded = [_model_id],\n                inference = _inference_cfg,\n                supports_reasoning = llama_backend.supports_reasoning,\n                supports_tools = llama_backend.supports_tools,\n                context_length = llama_backend.context_length,\n            )\n\n        # Otherwise, report Unsloth backend status\n        backend = get_inference_backend()\n\n        
is_vision = False\n        is_audio = False\n        audio_type = None\n        has_audio_input = False\n        if backend.active_model_name:\n            model_info = backend.models.get(backend.active_model_name, {})\n            is_vision = model_info.get(\"is_vision\", False)\n            is_audio = model_info.get(\"is_audio\", False)\n            audio_type = model_info.get(\"audio_type\")\n            has_audio_input = model_info.get(\"has_audio_input\", False)\n\n        # gpt-oss safetensors models support reasoning via harmony channels\n        supports_reasoning = False\n        if backend.active_model_name and hasattr(backend, \"_is_gpt_oss_model\"):\n            supports_reasoning = backend._is_gpt_oss_model()\n\n        return InferenceStatusResponse(\n            active_model = backend.active_model_name,\n            is_vision = is_vision,\n            is_gguf = False,\n            is_audio = is_audio,\n            audio_type = audio_type,\n            has_audio_input = has_audio_input,\n            loading = list(getattr(backend, \"loading_models\", set())),\n            loaded = list(backend.models.keys()),\n            supports_reasoning = supports_reasoning,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error getting status: {e}\", exc_info = True)\n        raise HTTPException(status_code = 500, detail = f\"Failed to get status: {str(e)}\")\n\n\n# =====================================================================\n# Audio (TTS) Generation  (/audio/generate)\n# =====================================================================\n\n\n@router.post(\"/audio/generate\")\nasync def generate_audio(\n    payload: ChatCompletionRequest,\n    request: Request,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Generate audio (TTS) from the latest user message.\n    Returns a JSON response with base64-encoded WAV audio.\n    Works with both GGUF (llama-server) and Unsloth/transformers backends.\n    \"\"\"\n    import base64\n\n    # Extract text from the last user message\n    _, chat_messages, _ = _extract_content_parts(payload.messages)\n    if not chat_messages:\n        raise HTTPException(status_code = 400, detail = \"No messages provided.\")\n    last_user_msg = next(\n        (m for m in reversed(chat_messages) if m[\"role\"] == \"user\"), None\n    )\n    if not last_user_msg:\n        raise HTTPException(status_code = 400, detail = \"No user message found.\")\n    text = last_user_msg[\"content\"]\n\n    # Pick backend — both return (wav_bytes, sample_rate)\n    llama_backend = get_llama_cpp_backend()\n    if llama_backend.is_loaded and getattr(llama_backend, \"_is_audio\", False):\n        model_name = llama_backend.model_identifier\n        gen = lambda: llama_backend.generate_audio_response(\n            text = text,\n            audio_type = llama_backend._audio_type,\n            temperature = payload.temperature,\n            top_p = payload.top_p,\n            top_k = payload.top_k,\n            min_p = payload.min_p,\n            max_new_tokens = payload.max_tokens or 2048,\n            repetition_penalty = payload.repetition_penalty,\n        )\n    else:\n        backend = get_inference_backend()\n        if not backend.active_model_name:\n            raise HTTPException(status_code = 400, detail = \"No model loaded.\")\n        model_info = backend.models.get(backend.active_model_name, {})\n        if not model_info.get(\"is_audio\"):\n            raise HTTPException(\n                status_code = 400, detail 
= \"Active model is not an audio model.\"\n            )\n        model_name = backend.active_model_name\n        gen = lambda: backend.generate_audio_response(\n            text = text,\n            temperature = payload.temperature,\n            top_p = payload.top_p,\n            top_k = payload.top_k,\n            min_p = payload.min_p,\n            max_new_tokens = payload.max_tokens or 2048,\n            repetition_penalty = payload.repetition_penalty,\n            use_adapter = payload.use_adapter,\n        )\n\n    try:\n        wav_bytes, sample_rate = await asyncio.get_event_loop().run_in_executor(\n            None, gen\n        )\n    except Exception as e:\n        logger.error(f\"Audio generation error: {e}\", exc_info = True)\n        raise HTTPException(status_code = 500, detail = str(e))\n\n    audio_b64 = base64.b64encode(wav_bytes).decode(\"ascii\")\n    return JSONResponse(\n        content = {\n            \"id\": f\"chatcmpl-{uuid.uuid4().hex[:12]}\",\n            \"object\": \"chat.completion.audio\",\n            \"model\": model_name,\n            \"audio\": {\"data\": audio_b64, \"format\": \"wav\", \"sample_rate\": sample_rate},\n            \"choices\": [\n                {\n                    \"index\": 0,\n                    \"message\": {\n                        \"role\": \"assistant\",\n                        \"content\": f'[Generated audio from: \"{text[:100]}\"]',\n                    },\n                    \"finish_reason\": \"stop\",\n                }\n            ],\n        }\n    )\n\n\n# =====================================================================\n# OpenAI-Compatible Chat Completions  (/chat/completions)\n# =====================================================================\n\n\ndef _decode_audio_base64(b64: str) -> np.ndarray:\n    \"\"\"Decode base64 audio (any format) → float32 numpy array at 16kHz.\"\"\"\n    import torch\n    import torchaudio\n    import tempfile\n    import os\n    from utils.paths import ensure_dir, tmp_root\n\n    raw = base64.b64decode(b64)\n    # torchaudio.load needs a file path or file-like object with format hint\n    # Write to a temp file so torchaudio can auto-detect the format\n    with tempfile.NamedTemporaryFile(\n        suffix = \".audio\",\n        delete = False,\n        dir = str(ensure_dir(tmp_root())),\n    ) as tmp:\n        tmp.write(raw)\n        tmp_path = tmp.name\n    try:\n        waveform, sr = torchaudio.load(tmp_path)\n    finally:\n        os.unlink(tmp_path)\n\n    # Convert to mono if stereo\n    if waveform.shape[0] > 1:\n        waveform = waveform.mean(dim = 0, keepdim = True)\n\n    # Resample to 16kHz if needed\n    if sr != 16000:\n        resampler = torchaudio.transforms.Resample(orig_freq = sr, new_freq = 16000)\n        waveform = resampler(waveform)\n\n    return waveform.squeeze(0).numpy()\n\n\ndef _extract_content_parts(\n    messages: list,\n) -> tuple[str, list[dict], \"Optional[str]\"]:\n    \"\"\"\n    Parse OpenAI-format messages into components the inference backend expects.\n\n    Handles both plain-string ``content`` and multimodal content-part arrays\n    (``[{type: \"text\", ...}, {type: \"image_url\", ...}]``).\n\n    Returns:\n        system_prompt:  The system message text (empty string if none provided).\n        chat_messages:  Non-system messages with content flattened to strings.\n        image_base64:   Base64 data of the *first* image found, or ``None``.\n    \"\"\"\n    system_prompt = \"\"\n    chat_messages: list[dict] = []\n    
first_image_b64: Optional[str] = None\n\n    for msg in messages:\n        # ── System messages → extract as system_prompt ────────\n        if msg.role == \"system\":\n            if isinstance(msg.content, str):\n                system_prompt = msg.content\n            elif isinstance(msg.content, list):\n                # Unlikely but handle: join text parts\n                system_prompt = \"\\n\".join(\n                    p.text for p in msg.content if p.type == \"text\"\n                )\n            continue\n\n        # ── User / assistant messages ─────────────────────────\n        if isinstance(msg.content, str):\n            # Plain string content — pass through\n            chat_messages.append({\"role\": msg.role, \"content\": msg.content})\n        elif isinstance(msg.content, list):\n            # Multimodal content parts\n            text_parts: list[str] = []\n            for part in msg.content:\n                if part.type == \"text\":\n                    text_parts.append(part.text)\n                elif part.type == \"image_url\" and first_image_b64 is None:\n                    url = part.image_url.url\n                    if url.startswith(\"data:\"):\n                        # data:image/png;base64,<DATA> → extract <DATA>\n                        first_image_b64 = url.split(\",\", 1)[1] if \",\" in url else None\n                    else:\n                        logger.warning(\n                            f\"Remote image URLs not yet supported: {url[:80]}...\"\n                        )\n            combined_text = \"\\n\".join(text_parts) if text_parts else \"\"\n            chat_messages.append({\"role\": msg.role, \"content\": combined_text})\n\n    return system_prompt, chat_messages, first_image_b64\n\n\n@router.post(\"/chat/completions\")\nasync def openai_chat_completions(\n    payload: ChatCompletionRequest,\n    request: Request,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    OpenAI-compatible chat completions endpoint.\n\n    Supports multimodal messages: ``content`` may be a plain string or a\n    list of content parts (``text`` / ``image_url``).\n\n    Streaming (default):  returns SSE chunks matching OpenAI's format.\n    Non-streaming:        returns a single ChatCompletion JSON object.\n\n    Automatically routes to the correct backend:\n    - GGUF models → llama-server via LlamaCppBackend\n    - Other models → Unsloth/transformers via InferenceBackend\n    \"\"\"\n    llama_backend = get_llama_cpp_backend()\n    using_gguf = llama_backend.is_loaded\n\n    # ── Determine which backend is active ─────────────────────\n    if using_gguf:\n        model_name = llama_backend.model_identifier or payload.model\n        if getattr(llama_backend, \"_is_audio\", False):\n            return await generate_audio(payload, request)\n    else:\n        backend = get_inference_backend()\n        if not backend.active_model_name:\n            raise HTTPException(\n                status_code = 400,\n                detail = \"No model loaded. 
Call POST /inference/load first.\",\n            )\n        model_name = backend.active_model_name or payload.model\n\n        # ── Audio TTS path: auto-route to audio generation ────\n        # (Whisper is ASR not TTS — handled below in audio input path)\n        model_info = backend.models.get(backend.active_model_name, {})\n        if model_info.get(\"is_audio\") and model_info.get(\"audio_type\") != \"whisper\":\n            return await generate_audio(payload, request)\n\n        # ── Whisper without audio: return clear error ──\n        if model_info.get(\"audio_type\") == \"whisper\" and not payload.audio_base64:\n            raise HTTPException(\n                status_code = 400,\n                detail = \"Whisper models require audio input. Please upload an audio file.\",\n            )\n\n        # ── Audio INPUT path: decode WAV and route to audio input generation ──\n        if payload.audio_base64 and model_info.get(\"has_audio_input\"):\n            audio_array = _decode_audio_base64(payload.audio_base64)\n            system_prompt, chat_messages, _ = _extract_content_parts(payload.messages)\n            cancel_event = threading.Event()\n            completion_id = f\"chatcmpl-{uuid.uuid4().hex[:12]}\"\n            created = int(time.time())\n\n            def audio_input_generate():\n                if model_info.get(\"audio_type\") == \"whisper\":\n                    return backend.generate_whisper_response(\n                        audio_array = audio_array,\n                        cancel_event = cancel_event,\n                    )\n                return backend.generate_audio_input_response(\n                    messages = chat_messages,\n                    system_prompt = system_prompt,\n                    audio_array = audio_array,\n                    temperature = payload.temperature,\n                    top_p = payload.top_p,\n                    top_k = payload.top_k,\n                    min_p = payload.min_p,\n                    max_new_tokens = payload.max_tokens or 2048,\n                    repetition_penalty = payload.repetition_penalty,\n                    cancel_event = cancel_event,\n                )\n\n            if payload.stream:\n\n                async def audio_input_stream():\n                    try:\n                        first_chunk = ChatCompletionChunk(\n                            id = completion_id,\n                            created = created,\n                            model = model_name,\n                            choices = [\n                                ChunkChoice(\n                                    delta = ChoiceDelta(role = \"assistant\"),\n                                    finish_reason = None,\n                                )\n                            ],\n                        )\n                        yield f\"data: {first_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                        for chunk_text in audio_input_generate():\n                            if await request.is_disconnected():\n                                cancel_event.set()\n                                return\n                            if chunk_text:\n                                chunk = ChatCompletionChunk(\n                                    id = completion_id,\n                                    created = created,\n                                    model = model_name,\n                                    choices = [\n                                        ChunkChoice(\n                              
              delta = ChoiceDelta(content = chunk_text),\n                                            finish_reason = None,\n                                        )\n                                    ],\n                                )\n                                yield f\"data: {chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                        final_chunk = ChatCompletionChunk(\n                            id = completion_id,\n                            created = created,\n                            model = model_name,\n                            choices = [\n                                ChunkChoice(delta = ChoiceDelta(), finish_reason = \"stop\")\n                            ],\n                        )\n                        yield f\"data: {final_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n                        yield \"data: [DONE]\\n\\n\"\n                    except asyncio.CancelledError:\n                        cancel_event.set()\n                        raise\n                    except Exception as e:\n                        logger.error(\n                            f\"Error during audio input streaming: {e}\", exc_info = True\n                        )\n                        yield f\"data: {json.dumps({'error': {'message': 'An internal error occurred', 'type': 'server_error'}})}\\n\\n\"\n\n                return StreamingResponse(\n                    audio_input_stream(),\n                    media_type = \"text/event-stream\",\n                    headers = {\n                        \"Cache-Control\": \"no-cache\",\n                        \"Connection\": \"keep-alive\",\n                        \"X-Accel-Buffering\": \"no\",\n                    },\n                )\n            else:\n                full_text = \"\".join(audio_input_generate())\n                response = ChatCompletion(\n                    id = completion_id,\n                    created = created,\n                    model = model_name,\n                    choices = [\n                        CompletionChoice(\n                            message = CompletionMessage(content = full_text),\n                            finish_reason = \"stop\",\n                        )\n                    ],\n                )\n                return JSONResponse(content = response.model_dump())\n\n    # ── Parse messages (handles multimodal content parts) ─────\n    system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts(\n        payload.messages\n    )\n\n    if not chat_messages:\n        raise HTTPException(\n            status_code = 400,\n            detail = \"At least one non-system message is required.\",\n        )\n\n    # ── GGUF path: proxy to llama-server /v1/chat/completions ──\n    if using_gguf:\n        # Reject images if this GGUF model doesn't support vision\n        image_b64 = extracted_image_b64 or payload.image_base64\n        if image_b64 and not llama_backend.is_vision:\n            raise HTTPException(\n                status_code = 400,\n                detail = \"Image provided but current GGUF model does not support vision.\",\n            )\n\n        # Convert image to PNG for llama-server (stb_image has limited format support)\n        if image_b64:\n            try:\n                import base64 as _b64\n                from io import BytesIO as _BytesIO\n                from PIL import Image as _Image\n\n                raw = _b64.b64decode(image_b64)\n                img = _Image.open(_BytesIO(raw))\n                
if img.mode == \"RGBA\":\n                    img = img.convert(\"RGB\")\n                buf = _BytesIO()\n                img.save(buf, format = \"PNG\")\n                image_b64 = _b64.b64encode(buf.getvalue()).decode(\"ascii\")\n            except Exception as e:\n                raise HTTPException(\n                    status_code = 400, detail = f\"Failed to process image: {e}\"\n                )\n\n        # Build message list with system prompt prepended\n        gguf_messages = []\n        if system_prompt:\n            gguf_messages.append({\"role\": \"system\", \"content\": system_prompt})\n        gguf_messages.extend(chat_messages)\n\n        cancel_event = threading.Event()\n\n        completion_id = f\"chatcmpl-{uuid.uuid4().hex[:12]}\"\n        created = int(time.time())\n\n        # ── Tool-calling path (agentic loop) ──────────────────\n        use_tools = (\n            payload.enable_tools and llama_backend.supports_tools and not image_b64\n        )\n\n        if use_tools:\n            from core.inference.tools import ALL_TOOLS\n\n            if payload.enabled_tools is not None:\n                tools_to_use = [\n                    t\n                    for t in ALL_TOOLS\n                    if t[\"function\"][\"name\"] in payload.enabled_tools\n                ]\n            else:\n                tools_to_use = ALL_TOOLS\n\n            def gguf_generate_with_tools():\n                return llama_backend.generate_chat_completion_with_tools(\n                    messages = gguf_messages,\n                    tools = tools_to_use,\n                    temperature = payload.temperature,\n                    top_p = payload.top_p,\n                    top_k = payload.top_k,\n                    min_p = payload.min_p,\n                    max_tokens = payload.max_tokens,\n                    repetition_penalty = payload.repetition_penalty,\n                    presence_penalty = payload.presence_penalty,\n                    cancel_event = cancel_event,\n                    enable_thinking = payload.enable_thinking,\n                    auto_heal_tool_calls = payload.auto_heal_tool_calls\n                    if payload.auto_heal_tool_calls is not None\n                    else True,\n                    max_tool_iterations = payload.max_tool_calls_per_message\n                    if payload.max_tool_calls_per_message is not None\n                    else 10,\n                    tool_call_timeout = payload.tool_call_timeout\n                    if payload.tool_call_timeout is not None\n                    else 300,\n                    session_id = payload.session_id,\n                )\n\n            _tool_sentinel = object()\n\n            async def gguf_tool_stream():\n                try:\n                    first_chunk = ChatCompletionChunk(\n                        id = completion_id,\n                        created = created,\n                        model = model_name,\n                        choices = [\n                            ChunkChoice(\n                                delta = ChoiceDelta(role = \"assistant\"),\n                                finish_reason = None,\n                            )\n                        ],\n                    )\n                    yield f\"data: {first_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                    # Iterate the synchronous generator in a thread so\n                    # the event loop stays free for disconnect detection.\n                    gen = gguf_generate_with_tools()\n     
               prev_text = \"\"\n                    while True:\n                        if await request.is_disconnected():\n                            cancel_event.set()\n                            return\n\n                        event = await asyncio.to_thread(next, gen, _tool_sentinel)\n                        if event is _tool_sentinel:\n                            break\n\n                        if event[\"type\"] == \"status\":\n                            # Emit tool status as a custom SSE event\n                            status_data = json.dumps(\n                                {\n                                    \"type\": \"tool_status\",\n                                    \"content\": event[\"text\"],\n                                }\n                            )\n                            yield f\"data: {status_data}\\n\\n\"\n                            continue\n\n                        if event[\"type\"] in (\"tool_start\", \"tool_end\"):\n                            yield f\"data: {json.dumps(event)}\\n\\n\"\n                            continue\n\n                        # \"content\" type -- cumulative text\n                        cumulative = event.get(\"text\", \"\")\n                        new_text = cumulative[len(prev_text) :]\n                        prev_text = cumulative\n                        if not new_text:\n                            continue\n                        chunk = ChatCompletionChunk(\n                            id = completion_id,\n                            created = created,\n                            model = model_name,\n                            choices = [\n                                ChunkChoice(\n                                    delta = ChoiceDelta(content = new_text),\n                                    finish_reason = None,\n                                )\n                            ],\n                        )\n                        yield f\"data: {chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                    final_chunk = ChatCompletionChunk(\n                        id = completion_id,\n                        created = created,\n                        model = model_name,\n                        choices = [\n                            ChunkChoice(\n                                delta = ChoiceDelta(),\n                                finish_reason = \"stop\",\n                            )\n                        ],\n                    )\n                    yield f\"data: {final_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n                    yield \"data: [DONE]\\n\\n\"\n\n                except asyncio.CancelledError:\n                    cancel_event.set()\n                    raise\n                except Exception as e:\n                    import traceback\n\n                    tb = traceback.format_exc()\n                    logger.error(f\"Error during GGUF tool streaming: {e}\\n{tb}\")\n                    error_chunk = {\n                        \"error\": {\n                            \"message\": \"An internal error occurred\",\n                            \"type\": \"server_error\",\n                        },\n                    }\n                    yield f\"data: {json.dumps(error_chunk)}\\n\\n\"\n\n            return StreamingResponse(\n                gguf_tool_stream(),\n                media_type = \"text/event-stream\",\n                headers = {\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": 
\"keep-alive\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n\n        # ── Standard GGUF path (no tools) ─────────────────────\n\n        def gguf_generate():\n            return llama_backend.generate_chat_completion(\n                messages = gguf_messages,\n                image_b64 = image_b64,\n                temperature = payload.temperature,\n                top_p = payload.top_p,\n                top_k = payload.top_k,\n                min_p = payload.min_p,\n                max_tokens = payload.max_tokens,\n                repetition_penalty = payload.repetition_penalty,\n                presence_penalty = payload.presence_penalty,\n                cancel_event = cancel_event,\n                enable_thinking = payload.enable_thinking,\n            )\n\n        _gguf_sentinel = object()\n\n        if payload.stream:\n\n            async def gguf_stream_chunks():\n                try:\n                    # First chunk: role\n                    first_chunk = ChatCompletionChunk(\n                        id = completion_id,\n                        created = created,\n                        model = model_name,\n                        choices = [\n                            ChunkChoice(\n                                delta = ChoiceDelta(role = \"assistant\"),\n                                finish_reason = None,\n                            )\n                        ],\n                    )\n                    yield f\"data: {first_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                    # Iterate the synchronous generator in a thread so\n                    # the event loop stays free for disconnect detection.\n                    gen = gguf_generate()\n                    prev_text = \"\"\n                    while True:\n                        if await request.is_disconnected():\n                            cancel_event.set()\n                            return\n                        cumulative = await asyncio.to_thread(next, gen, _gguf_sentinel)\n                        if cumulative is _gguf_sentinel:\n                            break\n                        new_text = cumulative[len(prev_text) :]\n                        prev_text = cumulative\n                        if not new_text:\n                            continue\n                        chunk = ChatCompletionChunk(\n                            id = completion_id,\n                            created = created,\n                            model = model_name,\n                            choices = [\n                                ChunkChoice(\n                                    delta = ChoiceDelta(content = new_text),\n                                    finish_reason = None,\n                                )\n                            ],\n                        )\n                        yield f\"data: {chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                    # Final chunk\n                    final_chunk = ChatCompletionChunk(\n                        id = completion_id,\n                        created = created,\n                        model = model_name,\n                        choices = [\n                            ChunkChoice(\n                                delta = ChoiceDelta(),\n                                finish_reason = \"stop\",\n                            )\n                        ],\n                    )\n                    yield f\"data: {final_chunk.model_dump_json(exclude_none = 
True)}\\n\\n\"\n                    yield \"data: [DONE]\\n\\n\"\n\n                except asyncio.CancelledError:\n                    cancel_event.set()\n                    raise\n                except Exception as e:\n                    logger.error(f\"Error during GGUF streaming: {e}\", exc_info = True)\n                    error_chunk = {\n                        \"error\": {\n                            \"message\": \"An internal error occurred\",\n                            \"type\": \"server_error\",\n                        },\n                    }\n                    yield f\"data: {json.dumps(error_chunk)}\\n\\n\"\n\n            return StreamingResponse(\n                gguf_stream_chunks(),\n                media_type = \"text/event-stream\",\n                headers = {\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"keep-alive\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n        else:\n            try:\n                full_text = \"\"\n                for token in gguf_generate():\n                    full_text = token\n\n                response = ChatCompletion(\n                    id = completion_id,\n                    created = created,\n                    model = model_name,\n                    choices = [\n                        CompletionChoice(\n                            message = CompletionMessage(content = full_text),\n                            finish_reason = \"stop\",\n                        )\n                    ],\n                )\n                return JSONResponse(content = response.model_dump())\n\n            except Exception as e:\n                logger.error(f\"Error during GGUF completion: {e}\", exc_info = True)\n                raise HTTPException(status_code = 500, detail = str(e))\n\n    # ── Standard Unsloth path ─────────────────────────────────\n\n    # Decode image (from content parts OR legacy field)\n    image_b64 = extracted_image_b64 or payload.image_base64\n    image = None\n\n    if image_b64:\n        try:\n            import base64\n            from PIL import Image\n            from io import BytesIO\n\n            model_info = backend.models.get(backend.active_model_name, {})\n            if not model_info.get(\"is_vision\"):\n                raise HTTPException(\n                    status_code = 400,\n                    detail = \"Image provided but current model is text-only. 
Load a vision model.\",\n                )\n\n            image_data = base64.b64decode(image_b64)\n            image = Image.open(BytesIO(image_data))\n            image = backend.resize_image(image)\n\n        except HTTPException:\n            raise\n        except Exception as e:\n            raise HTTPException(status_code = 400, detail = f\"Failed to decode image: {e}\")\n\n    # Shared generation kwargs\n    gen_kwargs = dict(\n        messages = chat_messages,\n        system_prompt = system_prompt,\n        image = image,\n        temperature = payload.temperature,\n        top_p = payload.top_p,\n        top_k = payload.top_k,\n        min_p = payload.min_p,\n        max_new_tokens = payload.max_tokens or 2048,\n        repetition_penalty = payload.repetition_penalty,\n    )\n\n    # Choose generation path (adapter-controlled or standard)\n    cancel_event = threading.Event()\n\n    if payload.use_adapter is not None:\n\n        def generate():\n            return backend.generate_with_adapter_control(\n                use_adapter = payload.use_adapter,\n                cancel_event = cancel_event,\n                **gen_kwargs,\n            )\n    else:\n\n        def generate():\n            return backend.generate_chat_response(\n                cancel_event = cancel_event, **gen_kwargs\n            )\n\n    completion_id = f\"chatcmpl-{uuid.uuid4().hex[:12]}\"\n    created = int(time.time())\n\n    # ── Streaming response ────────────────────────────────────────\n    if payload.stream:\n\n        async def stream_chunks():\n            try:\n                first_chunk = ChatCompletionChunk(\n                    id = completion_id,\n                    created = created,\n                    model = model_name,\n                    choices = [\n                        ChunkChoice(\n                            delta = ChoiceDelta(role = \"assistant\"),\n                            finish_reason = None,\n                        )\n                    ],\n                )\n                yield f\"data: {first_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                prev_text = \"\"\n                # Run sync generator in thread pool to avoid blocking\n                # the event loop. Critical for compare mode: two SSE\n                # requests arrive concurrently but the orchestrator\n                # serializes them via _gen_lock. 
Without run_in_executor\n                # the second request's blocking lock acquisition would\n                # freeze the entire event loop, stalling both streams.\n                _DONE = object()  # sentinel for generator exhaustion\n                loop = asyncio.get_event_loop()\n                gen = generate()\n                while True:\n                    # next(gen, _DONE) returns _DONE instead of raising\n                    # StopIteration — StopIteration cannot propagate\n                    # through asyncio futures (Python limitation).\n                    cumulative = await loop.run_in_executor(None, next, gen, _DONE)\n                    if cumulative is _DONE:\n                        break\n                    if await request.is_disconnected():\n                        cancel_event.set()\n                        backend.reset_generation_state()\n                        return\n                    new_text = cumulative[len(prev_text) :]\n                    prev_text = cumulative\n                    if not new_text:\n                        continue\n                    chunk = ChatCompletionChunk(\n                        id = completion_id,\n                        created = created,\n                        model = model_name,\n                        choices = [\n                            ChunkChoice(\n                                delta = ChoiceDelta(content = new_text),\n                                finish_reason = None,\n                            )\n                        ],\n                    )\n                    yield f\"data: {chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n\n                final_chunk = ChatCompletionChunk(\n                    id = completion_id,\n                    created = created,\n                    model = model_name,\n                    choices = [\n                        ChunkChoice(\n                            delta = ChoiceDelta(),\n                            finish_reason = \"stop\",\n                        )\n                    ],\n                )\n                yield f\"data: {final_chunk.model_dump_json(exclude_none = True)}\\n\\n\"\n                yield \"data: [DONE]\\n\\n\"\n\n            except asyncio.CancelledError:\n                cancel_event.set()\n                backend.reset_generation_state()\n                raise\n            except Exception as e:\n                backend.reset_generation_state()\n                logger.error(f\"Error during OpenAI streaming: {e}\", exc_info = True)\n                error_chunk = {\n                    \"error\": {\n                        \"message\": \"An internal error occurred\",\n                        \"type\": \"server_error\",\n                    },\n                }\n                yield f\"data: {json.dumps(error_chunk)}\\n\\n\"\n\n        return StreamingResponse(\n            stream_chunks(),\n            media_type = \"text/event-stream\",\n            headers = {\n                \"Cache-Control\": \"no-cache\",\n                \"Connection\": \"keep-alive\",\n                \"X-Accel-Buffering\": \"no\",\n            },\n        )\n\n    # ── Non-streaming response ────────────────────────────────────\n    else:\n        try:\n            full_text = \"\"\n            for token in generate():\n                full_text = token\n\n            response = ChatCompletion(\n                id = completion_id,\n                created = created,\n                model = model_name,\n                choices = [\n               
     CompletionChoice(\n                        message = CompletionMessage(content = full_text),\n                        finish_reason = \"stop\",\n                    )\n                ],\n            )\n            return JSONResponse(content = response.model_dump())\n\n        except Exception as e:\n            backend.reset_generation_state()\n            logger.error(f\"Error during OpenAI completion: {e}\", exc_info = True)\n            raise HTTPException(status_code = 500, detail = str(e))\n\n\n# =====================================================================\n# OpenAI-Compatible Models Listing  (/models → /v1/models)\n# =====================================================================\n\n\n@router.get(\"/models\")\nasync def openai_list_models(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    OpenAI-compatible model listing endpoint.\n\n    Returns the currently loaded model in the format expected by\n    OpenAI-compatible clients (``GET /v1/models``).\n    \"\"\"\n    models = []\n\n    # Check GGUF backend\n    llama_backend = get_llama_cpp_backend()\n    if llama_backend.is_loaded:\n        models.append(\n            {\n                \"id\": llama_backend.model_identifier,\n                \"object\": \"model\",\n                \"owned_by\": \"local\",\n            }\n        )\n\n    # Check Unsloth backend\n    backend = get_inference_backend()\n    if backend.active_model_name:\n        models.append(\n            {\n                \"id\": backend.active_model_name,\n                \"object\": \"model\",\n                \"owned_by\": \"local\",\n            }\n        )\n\n    return {\"object\": \"list\", \"data\": models}\n"
  },
  {
    "path": "studio/backend/routes/models.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nModel Management API routes\n\"\"\"\n\nimport os\nimport sys\nfrom pathlib import Path\nfrom fastapi import APIRouter, Body, Depends, HTTPException, Query\nfrom typing import List, Optional\nimport structlog\nfrom loggers import get_logger\n\nimport re as _re\n\n_VALID_REPO_ID = _re.compile(r\"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$\")\n\n\ndef _is_valid_repo_id(repo_id: str) -> bool:\n    return bool(_VALID_REPO_ID.fullmatch(repo_id))\n\n\n# Add backend directory to path\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\nfrom auth.authentication import get_current_subject\n\n# Import backend functions\ntry:\n    from utils.models import (\n        scan_trained_loras,\n        scan_exported_models,\n        load_model_defaults,\n        get_base_model_from_lora,\n        is_vision_model,\n        is_embedding_model,\n        scan_checkpoints,\n        list_gguf_variants,\n        ModelConfig,\n    )\n    from utils.models.model_config import (\n        _pick_best_gguf,\n        _extract_quant_label,\n        is_audio_input_type,\n    )\n    from core.inference import get_inference_backend\n    from utils.paths import (\n        outputs_root,\n        exports_root,\n        resolve_output_dir,\n        resolve_export_dir,\n    )\nexcept ImportError:\n    # Fallback: try to import from parent directory\n    parent_backend = backend_path.parent / \"backend\"\n    if str(parent_backend) not in sys.path:\n        sys.path.insert(0, str(parent_backend))\n    from utils.models import (\n        scan_trained_loras,\n        scan_exported_models,\n        load_model_defaults,\n        get_base_model_from_lora,\n        is_vision_model,\n        is_embedding_model,\n        scan_checkpoints,\n        list_gguf_variants,\n        ModelConfig,\n    )\n    from utils.models.model_config import (\n        _pick_best_gguf,\n        _extract_quant_label,\n        is_audio_input_type,\n    )\n    from core.inference import get_inference_backend\n    from utils.paths import (\n        outputs_root,\n        exports_root,\n        resolve_output_dir,\n        resolve_export_dir,\n    )\n\nfrom models import (\n    CheckpointInfo,\n    CheckpointListResponse,\n    LocalModelInfo,\n    LocalModelListResponse,\n    ModelCheckpoints,\n    ModelDetails,\n    LoRAScanResponse,\n    LoRAInfo,\n    ModelListResponse,\n)\nfrom models.models import GgufVariantDetail, GgufVariantsResponse, ModelType\nfrom models.responses import (\n    LoRABaseModelResponse,\n    VisionCheckResponse,\n    EmbeddingCheckResponse,\n)\n\nrouter = APIRouter()\nlogger = get_logger(__name__)\n\n\ndef derive_model_type(\n    is_vision: bool, audio_type: Optional[str], is_embedding: bool = False\n) -> ModelType:\n    \"\"\"Collapse individual capability flags into a single model modality string.\"\"\"\n    if is_embedding:\n        return \"embeddings\"\n    if audio_type is not None:\n        return \"audio\"\n    if is_vision:\n        return \"vision\"\n    return \"text\"\n\n\ndef _resolve_hf_cache_dir() -> Path:\n    \"\"\"Resolve local HF cache root used by hub downloads.\"\"\"\n    try:\n        from huggingface_hub.constants import HF_HUB_CACHE\n\n        return Path(HF_HUB_CACHE)\n    except Exception:\n        return Path.home() / \".cache\" / \"huggingface\" / \"hub\"\n\n\ndef 
_scan_models_dir(models_dir: Path) -> List[LocalModelInfo]:\n    if not models_dir.exists() or not models_dir.is_dir():\n        return []\n\n    found: List[LocalModelInfo] = []\n    for child in models_dir.iterdir():\n        if not child.is_dir():\n            continue\n        has_model_files = (\n            (child / \"config.json\").exists()\n            or (child / \"adapter_config.json\").exists()\n            or any(child.glob(\"*.safetensors\"))\n            or any(child.glob(\"*.bin\"))\n            or any(child.glob(\"*.gguf\"))\n        )\n        if not has_model_files:\n            continue\n        try:\n            updated_at = child.stat().st_mtime\n        except OSError:\n            updated_at = None\n        found.append(\n            LocalModelInfo(\n                id = str(child),\n                display_name = child.name,\n                path = str(child),\n                source = \"models_dir\",\n                updated_at = updated_at,\n            ),\n        )\n    # Also scan for standalone .gguf files directly in the models directory\n    for gguf_file in models_dir.glob(\"*.gguf\"):\n        if gguf_file.is_file():\n            try:\n                updated_at = gguf_file.stat().st_mtime\n            except OSError:\n                updated_at = None\n            found.append(\n                LocalModelInfo(\n                    id = str(gguf_file),\n                    display_name = gguf_file.stem,\n                    path = str(gguf_file),\n                    source = \"models_dir\",\n                    updated_at = updated_at,\n                ),\n            )\n\n    return found\n\n\ndef _scan_hf_cache(cache_dir: Path) -> List[LocalModelInfo]:\n    if not cache_dir.exists() or not cache_dir.is_dir():\n        return []\n\n    found: List[LocalModelInfo] = []\n    for repo_dir in cache_dir.glob(\"models--*\"):\n        if not repo_dir.is_dir():\n            continue\n\n        repo_name = repo_dir.name[len(\"models--\") :]\n        if not repo_name:\n            continue\n        model_id = repo_name.replace(\"--\", \"/\")\n\n        try:\n            updated_at = repo_dir.stat().st_mtime\n        except OSError:\n            updated_at = None\n\n        found.append(\n            LocalModelInfo(\n                id = model_id,\n                model_id = model_id,\n                display_name = model_id.split(\"/\")[-1],\n                path = str(repo_dir),\n                source = \"hf_cache\",\n                updated_at = updated_at,\n            ),\n        )\n    return found\n\n\n@router.get(\"/local\", response_model = LocalModelListResponse)\nasync def list_local_models(\n    models_dir: str = Query(\n        default = \"./models\", description = \"Directory to scan for local model folders\"\n    ),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    List local model candidates from custom models dir and HF cache.\n    \"\"\"\n    # Validate models_dir against an allowlist of trusted directories.\n    # Only the trusted Path objects are used for filesystem access -- the\n    # user-supplied string is only used for matching, never for path construction.\n    hf_cache_dir = _resolve_hf_cache_dir()\n    allowed_roots = [Path(\"./models\").resolve(), hf_cache_dir]\n    try:\n        from utils.paths import studio_root, outputs_root\n\n        allowed_roots.extend([studio_root(), outputs_root()])\n    except Exception:\n        pass\n\n    requested = os.path.realpath(os.path.expanduser(models_dir))\n    
models_root = None\n    for root in allowed_roots:\n        root_str = os.path.realpath(str(root))\n        if requested == root_str or requested.startswith(root_str + os.sep):\n            models_root = root  # Use the trusted root, not the user-supplied path\n            break\n    if models_root is None:\n        raise HTTPException(\n            status_code = 403,\n            detail = \"Directory not allowed\",\n        )\n\n    try:\n        local_models = _scan_models_dir(models_root) + _scan_hf_cache(hf_cache_dir)\n\n        deduped: dict[str, LocalModelInfo] = {}\n        for model in local_models:\n            if model.id not in deduped:\n                deduped[model.id] = model\n\n        models = sorted(\n            deduped.values(),\n            key = lambda item: (item.updated_at or 0),\n            reverse = True,\n        )\n\n        return LocalModelListResponse(\n            models_dir = str(models_root),\n            hf_cache_dir = str(hf_cache_dir),\n            models = models,\n        )\n    except Exception as e:\n        logger.error(f\"Error listing local models: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to list local models: {str(e)}\",\n        )\n\n\n@router.get(\"/list\")\nasync def list_models(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    List available models (default models and loaded models).\n\n    This endpoint returns the default models and any currently loaded models.\n    \"\"\"\n    try:\n        inference_backend = get_inference_backend()\n\n        # Get default models\n        default_models = inference_backend.default_models\n\n        # Get loaded models\n        loaded_models = []\n        for model_name, model_data in inference_backend.models.items():\n            _is_vision = model_data.get(\"is_vision\", False)\n            _audio_type = model_data.get(\"audio_type\")\n            model_info = ModelDetails(\n                id = model_name,\n                name = model_name.split(\"/\")[-1] if \"/\" in model_name else model_name,\n                is_vision = _is_vision,\n                is_lora = model_data.get(\"is_lora\", False),\n                is_audio = model_data.get(\"is_audio\", False),\n                audio_type = _audio_type,\n                has_audio_input = model_data.get(\"has_audio_input\", False),\n                model_type = derive_model_type(_is_vision, _audio_type),\n            )\n            loaded_models.append(model_info)\n\n        # Include active GGUF model (loaded via llama-server)\n        from routes.inference import get_llama_cpp_backend\n\n        llama_backend = get_llama_cpp_backend()\n        if llama_backend.is_loaded and llama_backend.model_identifier:\n            loaded_models.append(\n                ModelDetails(\n                    id = llama_backend.model_identifier,\n                    name = llama_backend.model_identifier.split(\"/\")[-1],\n                    is_gguf = True,\n                    is_vision = llama_backend.is_vision,\n                    is_audio = getattr(llama_backend, \"_is_audio\", False),\n                    audio_type = getattr(llama_backend, \"_audio_type\", None),\n                )\n            )\n\n        # Combine default and loaded models\n        all_models = []\n        seen_ids = set()\n\n        # Add default models\n        for model_id in default_models:\n            if model_id not in seen_ids:\n                model_info = ModelDetails(\n       
             id = model_id,\n                    name = model_id.split(\"/\")[-1] if \"/\" in model_id else model_id,\n                    is_gguf = model_id.upper().endswith(\"-GGUF\"),\n                )\n                all_models.append(model_info)\n                seen_ids.add(model_id)\n\n        # Add loaded models\n        for model_info in loaded_models:\n            if model_info.id not in seen_ids:\n                all_models.append(model_info)\n                seen_ids.add(model_info.id)\n\n        return ModelListResponse(models = all_models, default_models = default_models)\n\n    except Exception as e:\n        logger.error(f\"Error listing models: {e}\", exc_info = True)\n        raise HTTPException(status_code = 500, detail = f\"Failed to list models: {str(e)}\")\n\n\ndef _get_max_position_embeddings(config) -> Optional[int]:\n    \"\"\"Extract max_position_embeddings from a model config, checking text_config fallback.\"\"\"\n    if hasattr(config, \"max_position_embeddings\"):\n        return config.max_position_embeddings\n    if hasattr(config, \"text_config\") and hasattr(\n        config.text_config, \"max_position_embeddings\"\n    ):\n        return config.text_config.max_position_embeddings\n    return None\n\n\ndef _get_model_size_bytes(\n    model_name: str, hf_token: Optional[str] = None\n) -> Optional[int]:\n    \"\"\"Get total size of model weight files from HF Hub.\"\"\"\n    try:\n        from huggingface_hub import HfApi\n\n        api = HfApi(token = hf_token)\n        info = api.repo_info(model_name, repo_type = \"model\", token = hf_token)\n        if not info.siblings:\n            return None\n\n        weight_exts = (\".safetensors\", \".bin\", \".pt\", \".pth\", \".gguf\")\n        total = 0\n        for sibling in info.siblings:\n            if sibling.rfilename and any(\n                sibling.rfilename.endswith(ext) for ext in weight_exts\n            ):\n                if sibling.size is not None:\n                    total += sibling.size\n\n        return total if total > 0 else None\n    except Exception as e:\n        logger.warning(f\"Could not get model size for {model_name}: {e}\")\n        return None\n\n\n@router.get(\"/config/{model_name:path}\")\nasync def get_model_config(\n    model_name: str,\n    hf_token: Optional[str] = Query(None),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get configuration for a specific model.\n\n    This endpoint wraps the backend load_model_defaults function.\n    \"\"\"\n    try:\n        from utils.models.model_config import is_local_path\n\n        if not is_local_path(model_name):\n            model_name = model_name.lower()\n\n        logger.info(f\"Getting model config for: {model_name}\")\n        from utils.models.model_config import detect_audio_type\n\n        # Load model defaults from backend\n        config_dict = load_model_defaults(model_name)\n\n        # Detect model capabilities (pass HF token for gated models)\n        is_vision = is_vision_model(model_name)\n        is_embedding = is_embedding_model(model_name, hf_token = hf_token)\n        audio_type = detect_audio_type(model_name, hf_token = hf_token)\n\n        # Check if it's a LoRA adapter\n        is_lora = False\n        base_model = None\n        max_position_embeddings = None\n        try:\n            model_config = ModelConfig.from_identifier(model_name)\n            is_lora = model_config.is_lora\n            base_model = model_config.base_model if is_lora else None\n            
max_position_embeddings = _get_max_position_embeddings(model_config)\n        except Exception:\n            pass\n\n        # Fallback: try AutoConfig directly if not found yet\n        if max_position_embeddings is None:\n            try:\n                from transformers import AutoConfig as _AutoConfig\n\n                _trust = model_name.lower().startswith(\"unsloth/\")\n                _ac = _AutoConfig.from_pretrained(\n                    model_name, trust_remote_code = _trust, token = hf_token\n                )\n                max_position_embeddings = _get_max_position_embeddings(_ac)\n            except Exception:\n                pass\n\n        logger.info(\n            f\"Model config result for {model_name}: is_vision={is_vision}, is_embedding={is_embedding}, audio_type={audio_type}, is_lora={is_lora}, max_position_embeddings={max_position_embeddings}\"\n        )\n        return ModelDetails(\n            id = model_name,\n            model_name = model_name,\n            config = config_dict,\n            is_vision = is_vision,\n            is_embedding = is_embedding,\n            is_lora = is_lora,\n            is_audio = audio_type is not None,\n            audio_type = audio_type,\n            has_audio_input = is_audio_input_type(audio_type),\n            model_type = derive_model_type(is_vision, audio_type, is_embedding),\n            base_model = base_model,\n            max_position_embeddings = max_position_embeddings,\n            model_size_bytes = _get_model_size_bytes(model_name, hf_token),\n        )\n\n    except Exception as e:\n        logger.error(f\"Error getting model config: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to get model config: {str(e)}\"\n        )\n\n\n@router.get(\"/loras\")\nasync def scan_loras(\n    outputs_dir: str = Query(\n        default = str(outputs_root()), description = \"Directory to scan for LoRA adapters\"\n    ),\n    exports_dir: str = Query(\n        default = str(exports_root()), description = \"Directory to scan for exported models\"\n    ),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Scan for trained LoRA adapters and exported models.\n\n    Returns both training outputs (from outputs_dir) and exported models\n    (from exports_dir) in a single list, distinguished by source field.\n    \"\"\"\n    try:\n        resolved_outputs_dir = str(resolve_output_dir(outputs_dir))\n        resolved_exports_dir = str(resolve_export_dir(exports_dir))\n        lora_list = []\n\n        # Scan training outputs\n        trained_loras = scan_trained_loras(outputs_dir = resolved_outputs_dir)\n        for display_name, adapter_path in trained_loras:\n            base_model = get_base_model_from_lora(adapter_path)\n            lora_list.append(\n                LoRAInfo(\n                    display_name = display_name,\n                    adapter_path = adapter_path,\n                    base_model = base_model,\n                    source = \"training\",\n                )\n            )\n\n        # Scan exported models (merged, LoRA, base — skips GGUF)\n        exported = scan_exported_models(exports_dir = resolved_exports_dir)\n        for display_name, model_path, export_type, base_model in exported:\n            lora_list.append(\n                LoRAInfo(\n                    display_name = display_name,\n                    adapter_path = model_path,\n                    base_model = base_model,\n                    source = 
\"exported\",\n                    export_type = export_type,\n                )\n            )\n\n        return LoRAScanResponse(loras = lora_list, outputs_dir = resolved_outputs_dir)\n\n    except Exception as e:\n        logger.error(f\"Error scanning LoRAs: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to scan LoRA adapters: {str(e)}\"\n        )\n\n\n@router.get(\"/loras/{lora_path:path}/base-model\", response_model = LoRABaseModelResponse)\nasync def get_lora_base_model(\n    lora_path: str,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get the base model for a LoRA adapter.\n\n    This endpoint wraps the backend get_base_model_from_lora function.\n    \"\"\"\n    try:\n        base_model = get_base_model_from_lora(lora_path)\n\n        if base_model is None:\n            raise HTTPException(\n                status_code = 404,\n                detail = f\"Could not determine base model for LoRA: {lora_path}\",\n            )\n\n        return LoRABaseModelResponse(\n            lora_path = lora_path,\n            base_model = base_model,\n        )\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error getting LoRA base model: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to get base model: {str(e)}\"\n        )\n\n\n@router.get(\"/check-vision/{model_name:path}\", response_model = VisionCheckResponse)\nasync def check_vision_model(\n    model_name: str,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Check if a model is a vision model.\n\n    This endpoint wraps the backend is_vision_model function.\n    \"\"\"\n    try:\n        logger.info(f\"Checking if vision model: {model_name}\")\n        is_vision = is_vision_model(model_name)\n\n        logger.info(f\"Vision check result for {model_name}: is_vision={is_vision}\")\n        return VisionCheckResponse(\n            model_name = model_name,\n            is_vision = is_vision,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error checking vision model: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to check vision model: {str(e)}\"\n        )\n\n\n@router.get(\"/check-embedding/{model_name:path}\", response_model = EmbeddingCheckResponse)\nasync def check_embedding_model(\n    model_name: str,\n    hf_token: Optional[str] = Query(None),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Check if a model is an embedding model.\n\n    This endpoint wraps the backend is_embedding_model function.\n    \"\"\"\n    try:\n        logger.info(f\"Checking if embedding model: {model_name}\")\n        is_embedding = is_embedding_model(model_name, hf_token = hf_token)\n\n        logger.info(\n            f\"Embedding check result for {model_name}: is_embedding={is_embedding}\"\n        )\n        return EmbeddingCheckResponse(\n            model_name = model_name,\n            is_embedding = is_embedding,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error checking embedding model: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to check embedding model: {str(e)}\"\n        )\n\n\n@router.get(\"/gguf-variants\", response_model = GgufVariantsResponse)\nasync def get_gguf_variants(\n    repo_id: str = Query(\n        ..., description = 
\"HuggingFace repo ID (e.g. 'unsloth/gemma-3-4b-it-GGUF')\"\n    ),\n    hf_token: Optional[str] = Query(\n        None, description = \"HuggingFace token for private repos\"\n    ),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    List available GGUF quantization variants for a HuggingFace repo.\n\n    Returns all available quantization variants (Q4_K_M, Q8_0, BF16, etc.)\n    with file sizes, whether the model supports vision, and the recommended\n    default variant.\n    \"\"\"\n    try:\n        variants, has_vision = list_gguf_variants(repo_id, hf_token = hf_token)\n\n        # Determine default variant\n        filenames = [v.filename for v in variants]\n        best = _pick_best_gguf(filenames)\n        default_variant = _extract_quant_label(best) if best else None\n\n        # Check which variants are fully downloaded in the HF cache.\n        # For split GGUFs, ALL shards must be present -- sum cached bytes\n        # per variant and compare against the expected total.\n        # HF cache dir uses the exact case from the repo_id at download time,\n        # which may differ from the canonical HF repo_id, so do a\n        # case-insensitive match.\n        cached_bytes_by_quant: dict[str, int] = {}\n        try:\n            import re as _re\n            from huggingface_hub import constants as hf_constants\n\n            # Sanitize repo_id: must be \"owner/name\" with safe chars only\n            if not _is_valid_repo_id(repo_id):\n                raise ValueError(f\"Invalid repo_id format: {repo_id}\")\n\n            cache_dir = Path(hf_constants.HF_HUB_CACHE)\n            target = f\"models--{repo_id.replace('/', '--')}\".lower()\n            for entry in cache_dir.iterdir():\n                if entry.name.lower() == target:\n                    snapshots = entry / \"snapshots\"\n                    if snapshots.is_dir():\n                        for snap in snapshots.iterdir():\n                            for f in snap.rglob(\"*.gguf\"):\n                                q = _extract_quant_label(f.name)\n                                cached_bytes_by_quant[q] = (\n                                    cached_bytes_by_quant.get(q, 0) + f.stat().st_size\n                                )\n                    break\n        except Exception:\n            pass\n\n        def _is_fully_downloaded(variant) -> bool:\n            cached = cached_bytes_by_quant.get(variant.quant, 0)\n            if cached == 0 or variant.size_bytes == 0:\n                return False\n            # Allow small rounding tolerance (symlinks vs real sizes)\n            return cached >= variant.size_bytes * 0.99\n\n        return GgufVariantsResponse(\n            repo_id = repo_id,\n            variants = [\n                GgufVariantDetail(\n                    filename = v.filename,\n                    quant = v.quant,\n                    size_bytes = v.size_bytes,\n                    downloaded = _is_fully_downloaded(v),\n                )\n                for v in variants\n            ],\n            has_vision = has_vision,\n            default_variant = default_variant,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error listing GGUF variants for '{repo_id}': {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to list GGUF variants: {str(e)}\",\n        )\n\n\n@router.get(\"/gguf-download-progress\")\nasync def get_gguf_download_progress(\n    repo_id: str = Query(..., description 
= \"HuggingFace repo ID\"),\n    variant: str = Query(\"\", description = \"Quantization variant (e.g. UD-TQ1_0)\"),\n    expected_bytes: int = Query(0, description = \"Expected total download size in bytes\"),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"Return download progress by checking cached GGUF files for a specific variant.\n\n    Tracks completed shard downloads in snapshots and in-progress downloads\n    in the blobs directory (incomplete files).\n    \"\"\"\n    try:\n        if not _is_valid_repo_id(repo_id):\n            return {\n                \"downloaded_bytes\": 0,\n                \"expected_bytes\": expected_bytes,\n                \"progress\": 0,\n            }\n\n        from huggingface_hub import constants as hf_constants\n\n        cache_dir = Path(hf_constants.HF_HUB_CACHE)\n        target = f\"models--{repo_id.replace('/', '--')}\".lower()\n        variant_lower = variant.lower().replace(\"-\", \"\").replace(\"_\", \"\")\n        downloaded_bytes = 0\n        in_progress_bytes = 0\n        for entry in cache_dir.iterdir():\n            if entry.name.lower() == target:\n                # Count completed .gguf files matching this variant in snapshots\n                for f in entry.rglob(\"*.gguf\"):\n                    fname = f.name.lower().replace(\"-\", \"\").replace(\"_\", \"\")\n                    if not variant_lower or variant_lower in fname:\n                        downloaded_bytes += f.stat().st_size\n                # Check blobs for in-progress downloads (.incomplete files)\n                blobs_dir = entry / \"blobs\"\n                if blobs_dir.is_dir():\n                    for f in blobs_dir.iterdir():\n                        if f.is_file() and f.name.endswith(\".incomplete\"):\n                            in_progress_bytes += f.stat().st_size\n                break\n\n        total_progress_bytes = downloaded_bytes + in_progress_bytes\n        progress = (\n            min(total_progress_bytes / expected_bytes, 0.99)\n            if expected_bytes > 0\n            else 0\n        )\n        # Only report 1.0 when all bytes are in completed files (not in-progress)\n        if expected_bytes > 0 and downloaded_bytes >= expected_bytes:\n            progress = 1.0\n        return {\n            \"downloaded_bytes\": total_progress_bytes,\n            \"expected_bytes\": expected_bytes,\n            \"progress\": round(progress, 3),\n        }\n    except Exception:\n        return {\"downloaded_bytes\": 0, \"expected_bytes\": expected_bytes, \"progress\": 0}\n\n\n@router.get(\"/download-progress\")\nasync def get_download_progress(\n    repo_id: str = Query(..., description = \"HuggingFace repo ID\"),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"Return download progress for any HuggingFace model repo.\n\n    Checks the local HF cache for completed blobs and in-progress\n    (.incomplete) downloads. 
Uses the HF API to determine the expected\n    total size on the first call, then caches it for subsequent polls.\n    \"\"\"\n    _empty = {\"downloaded_bytes\": 0, \"expected_bytes\": 0, \"progress\": 0}\n    try:\n        if not _is_valid_repo_id(repo_id):\n            return _empty\n\n        from huggingface_hub import constants as hf_constants\n\n        cache_dir = Path(hf_constants.HF_HUB_CACHE)\n        target = f\"models--{repo_id.replace('/', '--')}\".lower()\n        completed_bytes = 0\n        in_progress_bytes = 0\n\n        for entry in cache_dir.iterdir():\n            if entry.name.lower() != target:\n                continue\n            blobs_dir = entry / \"blobs\"\n            if not blobs_dir.is_dir():\n                break\n            for f in blobs_dir.iterdir():\n                if not f.is_file():\n                    continue\n                if f.name.endswith(\".incomplete\"):\n                    in_progress_bytes += f.stat().st_size\n                else:\n                    completed_bytes += f.stat().st_size\n            break\n\n        downloaded_bytes = completed_bytes + in_progress_bytes\n        if downloaded_bytes == 0:\n            return _empty\n\n        # Get expected size from HF API (cached per repo_id)\n        expected_bytes = _get_repo_size_cached(repo_id)\n        if expected_bytes <= 0:\n            # Cannot determine total; report bytes only, no percentage\n            return {\n                \"downloaded_bytes\": downloaded_bytes,\n                \"expected_bytes\": 0,\n                \"progress\": 0,\n            }\n\n        # Use 95% threshold for completion (blob deduplication can make\n        # completed_bytes differ slightly from expected_bytes).\n        # Do NOT use \"no .incomplete files\" as a completion signal --\n        # HF downloads files sequentially, so between files there are\n        # no .incomplete files even though the download is far from done.\n        if completed_bytes >= expected_bytes * 0.95:\n            progress = 1.0\n        else:\n            progress = min(downloaded_bytes / expected_bytes, 0.99)\n        return {\n            \"downloaded_bytes\": downloaded_bytes,\n            \"expected_bytes\": expected_bytes,\n            \"progress\": round(progress, 3),\n        }\n    except Exception as e:\n        logger.warning(f\"Error checking download progress for {repo_id}: {e}\")\n        return _empty\n\n\n_repo_size_cache: dict[str, int] = {}\n\n\ndef _get_repo_size_cached(repo_id: str) -> int:\n    if repo_id in _repo_size_cache:\n        return _repo_size_cache[repo_id]\n    try:\n        from huggingface_hub import model_info as hf_model_info\n\n        info = hf_model_info(repo_id, token = None, files_metadata = True)\n        total = sum(s.size for s in info.siblings if s.size)\n        _repo_size_cache[repo_id] = total\n        return total\n    except Exception as e:\n        logger.warning(f\"Failed to get repo size for {repo_id}: {e}\")\n        return 0\n\n\n@router.get(\"/cached-gguf\")\nasync def list_cached_gguf(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"List GGUF repos that have already been downloaded to the HF cache.\n\n    Uses scan_cache_dir() for proper repo IDs, then deduplicates by\n    lowercased key (HF cache dirs are lowercased but the canonical repo\n    ID preserves casing).\n    \"\"\"\n    try:\n        from huggingface_hub import scan_cache_dir\n\n        hf_cache = scan_cache_dir()\n        seen_lower: dict[str, dict] = {}\n        for 
repo_info in hf_cache.repos:\n            if repo_info.repo_type != \"model\":\n                continue\n            repo_id = repo_info.repo_id\n            if not repo_id.upper().endswith(\"-GGUF\"):\n                continue\n            # Check for actual .gguf files and sum sizes\n            total_size = 0\n            has_gguf = False\n            for revision in repo_info.revisions:\n                for f in revision.files:\n                    if f.file_name.endswith(\".gguf\"):\n                        has_gguf = True\n                        total_size += f.size_on_disk\n            if not has_gguf:\n                continue\n            # Deduplicate: keep the entry with the most data\n            key = repo_id.lower()\n            existing = seen_lower.get(key)\n            if existing is None or total_size > existing[\"size_bytes\"]:\n                seen_lower[key] = {\n                    \"repo_id\": repo_id,\n                    \"size_bytes\": total_size,\n                    \"cache_path\": str(repo_info.repo_path),\n                }\n        cached = sorted(seen_lower.values(), key = lambda c: c[\"repo_id\"])\n        return {\"cached\": cached}\n    except Exception as e:\n        logger.error(f\"Error listing cached GGUF repos: {e}\", exc_info = True)\n        return {\"cached\": []}\n\n\n@router.get(\"/cached-models\")\nasync def list_cached_models(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"List non-GGUF model repos that have been downloaded to the HF cache.\n\n    Only includes repos that actually contain model weight files\n    (.safetensors, .bin), not repos with only config/metadata.\n    \"\"\"\n    _WEIGHT_EXTENSIONS = (\".safetensors\", \".bin\")\n\n    try:\n        from huggingface_hub import scan_cache_dir\n\n        hf_cache = scan_cache_dir()\n        seen_lower: dict[str, dict] = {}\n        for repo_info in hf_cache.repos:\n            if repo_info.repo_type != \"model\":\n                continue\n            repo_id = repo_info.repo_id\n            if repo_id.upper().endswith(\"-GGUF\"):\n                continue\n            total_size = sum(\n                f.size_on_disk for rev in repo_info.revisions for f in rev.files\n            )\n            if total_size == 0:\n                continue\n            # Skip repos that only have config/metadata files (no weights)\n            has_weights = any(\n                f.file_name.endswith(_WEIGHT_EXTENSIONS)\n                for rev in repo_info.revisions\n                for f in rev.files\n            )\n            if not has_weights:\n                continue\n            key = repo_id.lower()\n            existing = seen_lower.get(key)\n            if existing is None or total_size > existing[\"size_bytes\"]:\n                seen_lower[key] = {\n                    \"repo_id\": repo_id,\n                    \"size_bytes\": total_size,\n                }\n        cached = sorted(seen_lower.values(), key = lambda c: c[\"repo_id\"])\n        return {\"cached\": cached}\n    except Exception as e:\n        logger.error(f\"Error listing cached models: {e}\", exc_info = True)\n        return {\"cached\": []}\n\n\n@router.delete(\"/delete-cached\")\nasync def delete_cached_model(\n    repo_id: str = Body(...),\n    variant: Optional[str] = Body(None),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"Delete a cached model repo (or a specific GGUF variant) from the HF cache.\n\n    When *variant* is provided, only the GGUF files matching that 
quant label\n    are removed (e.g. ``UD-Q4_K_XL``).  Otherwise the entire repo is deleted.\n    Refuses if the model is currently loaded for inference.\n    \"\"\"\n    if not _is_valid_repo_id(repo_id):\n        raise HTTPException(status_code = 400, detail = \"Invalid repo_id format\")\n\n    # Check if model is currently loaded\n    try:\n        from routes.inference import get_llama_cpp_backend\n\n        llama_backend = get_llama_cpp_backend()\n        if llama_backend.is_loaded and llama_backend.model_identifier:\n            loaded_id = llama_backend.model_identifier.lower()\n            if loaded_id == repo_id.lower() or loaded_id.startswith(repo_id.lower()):\n                raise HTTPException(\n                    status_code = 400,\n                    detail = \"Unload the model before deleting\",\n                )\n    except HTTPException:\n        raise\n    except Exception:\n        pass\n\n    try:\n        inference_backend = get_inference_backend()\n        if inference_backend.active_model_name:\n            active = inference_backend.active_model_name.lower()\n            if active == repo_id.lower() or active.startswith(repo_id.lower()):\n                raise HTTPException(\n                    status_code = 400,\n                    detail = \"Unload the model before deleting\",\n                )\n    except HTTPException:\n        raise\n    except Exception:\n        pass\n\n    try:\n        from huggingface_hub import scan_cache_dir\n\n        hf_cache = scan_cache_dir()\n        target_repo = None\n        for repo_info in hf_cache.repos:\n            if repo_info.repo_type != \"model\":\n                continue\n            if repo_info.repo_id.lower() == repo_id.lower():\n                target_repo = repo_info\n                break\n\n        if target_repo is None:\n            raise HTTPException(status_code = 404, detail = \"Model not found in cache\")\n\n        # ── Per-variant GGUF deletion ────────────────────────────\n        if variant:\n            deleted_bytes = 0\n            deleted_count = 0\n            for rev in target_repo.revisions:\n                for f in rev.files:\n                    if not f.file_name.endswith(\".gguf\"):\n                        continue\n                    quant = _extract_quant_label(f.file_name)\n                    if quant.lower() != variant.lower():\n                        continue\n                    # Delete the blob (actual data) and the snapshot symlink\n                    try:\n                        blob = Path(f.blob_path)\n                        snap = Path(f.file_path)\n                        size = blob.stat().st_size if blob.exists() else 0\n                        if snap.exists() or snap.is_symlink():\n                            snap.unlink()\n                        if blob.exists():\n                            blob.unlink()\n                        deleted_bytes += size\n                        deleted_count += 1\n                    except Exception as e:\n                        logger.warning(f\"Failed to delete {f.file_name}: {e}\")\n\n            if deleted_count == 0:\n                raise HTTPException(\n                    status_code = 404,\n                    detail = f\"Variant {variant} not found in cache for {repo_id}\",\n                )\n\n            freed_mb = deleted_bytes / (1024 * 1024)\n            logger.info(\n                f\"Deleted {deleted_count} file(s) for {repo_id} variant {variant}: \"\n                f\"{freed_mb:.1f} MB freed\"\n           
 )\n            return {\"status\": \"deleted\", \"repo_id\": repo_id, \"variant\": variant}\n\n        # ── Full repo deletion ───────────────────────────────────\n        revision_hashes = [rev.commit_hash for rev in target_repo.revisions]\n        if not revision_hashes:\n            raise HTTPException(status_code = 404, detail = \"No revisions found for model\")\n\n        delete_strategy = hf_cache.delete_revisions(*revision_hashes)\n        logger.info(\n            f\"Deleting cached model {repo_id}: \"\n            f\"{delete_strategy.expected_freed_size_str} will be freed\"\n        )\n        delete_strategy.execute()\n\n        return {\"status\": \"deleted\", \"repo_id\": repo_id}\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error deleting cached model {repo_id}: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to delete cached model: {str(e)}\",\n        )\n\n\n@router.get(\"/checkpoints\", response_model = CheckpointListResponse)\nasync def list_checkpoints(\n    outputs_dir: str = Query(\n        default = str(outputs_root()),\n        description = \"Directory to scan for checkpoints\",\n    ),\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    List available checkpoints in the outputs directory.\n\n    Scans the outputs folder for training runs and their checkpoints.\n    \"\"\"\n    try:\n        resolved_outputs_dir = str(resolve_output_dir(outputs_dir))\n        raw_models = scan_checkpoints(outputs_dir = resolved_outputs_dir)\n\n        models = [\n            ModelCheckpoints(\n                name = model_name,\n                checkpoints = [\n                    CheckpointInfo(display_name = display_name, path = path, loss = loss)\n                    for display_name, path, loss in checkpoints\n                ],\n                base_model = metadata.get(\"base_model\"),\n                peft_type = metadata.get(\"peft_type\"),\n                lora_rank = metadata.get(\"lora_rank\"),\n            )\n            for model_name, checkpoints, metadata in raw_models\n        ]\n\n        return CheckpointListResponse(\n            outputs_dir = resolved_outputs_dir,\n            models = models,\n        )\n    except Exception as e:\n        logger.error(f\"Error listing checkpoints: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to list checkpoints: {str(e)}\",\n        )\n"
  },
  {
    "path": "studio/backend/routes/training.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTraining API routes\n\"\"\"\n\nimport sys\nfrom pathlib import Path\nfrom fastapi import APIRouter, Depends, HTTPException, Request\nfrom fastapi.responses import StreamingResponse\nfrom typing import Dict, Optional, Any\nimport structlog\nfrom loggers import get_logger\nimport asyncio\nfrom datetime import datetime\n\n# Add backend directory to path\n# The backend code should be in the same directory structure\nbackend_path = Path(__file__).parent.parent.parent\nif str(backend_path) not in sys.path:\n    sys.path.insert(0, str(backend_path))\n\n# Import backend functions\ntry:\n    from core.training import get_training_backend\n    from utils.models.model_config import load_model_defaults\n    from utils.paths import resolve_dataset_path\nexcept ImportError:\n    # Fallback: try to import from parent directory\n    parent_backend = backend_path.parent / \"backend\"\n    if str(parent_backend) not in sys.path:\n        sys.path.insert(0, str(parent_backend))\n    from core.training import get_training_backend\n    from utils.models.model_config import load_model_defaults\n    from utils.paths import resolve_dataset_path\n\n# Auth\nfrom auth.authentication import get_current_subject\n\nfrom models import (\n    TrainingStartRequest,\n    TrainingJobResponse,\n    TrainingStatus,\n    TrainingProgress,\n)\nfrom models.responses import TrainingStopResponse, TrainingMetricsResponse\nfrom pydantic import BaseModel as PydanticBaseModel\n\n\nclass TrainingStopRequest(PydanticBaseModel):\n    save: bool = True\n\n\nrouter = APIRouter()\nlogger = get_logger(__name__)\n\n\ndef _validate_local_dataset_paths(\n    paths: list[str], label: str = \"Local dataset\"\n) -> list[str]:\n    \"\"\"Resolve and validate a list of local dataset paths. 
Returns validated absolute paths.\"\"\"\n    validated = []\n    missing = []\n    for dataset_path in paths:\n        dataset_file = resolve_dataset_path(dataset_path)\n        if not dataset_file.exists():\n            missing.append(f\"{dataset_path} (resolved: {dataset_file})\")\n            continue\n        logger.info(f\"Found {label.lower()} file: {dataset_file}\")\n        validated.append(str(dataset_file))\n\n    if missing:\n        missing_detail = \"; \".join(missing[:3])\n        raise HTTPException(\n            status_code = 400,\n            detail = f\"{label} not found: {missing_detail}\",\n        )\n    return validated\n\n\n@router.get(\"/hardware\")\nasync def get_hardware_utilization(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get a live snapshot of GPU hardware utilization.\n\n    Designed to be polled by the frontend during training.\n    Returns GPU utilization %, temperature, VRAM usage, and power draw\n    via nvidia-smi for maximum accuracy.\n    \"\"\"\n    from utils.hardware import get_gpu_utilization\n\n    return get_gpu_utilization()\n\n\n@router.post(\"/start\")\nasync def start_training(\n    request: TrainingStartRequest,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Start a training job.\n\n    This endpoint initiates training in the background and returns immediately.\n    Use the /status endpoint to check training progress.\n    \"\"\"\n    try:\n        logger.info(f\"Starting training job with model: {request.model_name}\")\n\n        # NOTE: No in-process ensure_transformers_version() call here.\n        # The subprocess (worker.py) activates the correct version in a\n        # fresh Python interpreter before importing any ML libraries.\n\n        backend = get_training_backend()\n\n        # Generate job ID and attach to backend for later status/progress calls\n        job_id = f\"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}\"\n        backend.current_job_id = job_id\n\n        # Check if training is already active\n        if backend.is_training_active():\n            existing_job_id: Optional[str] = getattr(backend, \"current_job_id\", \"\")\n            return TrainingJobResponse(\n                job_id = existing_job_id or job_id,\n                status = \"error\",\n                message = (\n                    \"Training is already in progress. 
\"\n                    \"Stop current training before starting a new one.\"\n                ),\n                error = \"Training already active\",\n            )\n\n        # Validate dataset paths if provided\n        if request.local_datasets:\n            request.local_datasets = _validate_local_dataset_paths(\n                request.local_datasets, \"Local dataset\"\n            )\n        if request.local_eval_datasets and request.eval_steps > 0:\n            request.local_eval_datasets = _validate_local_dataset_paths(\n                request.local_eval_datasets, \"Local eval dataset\"\n            )\n\n        # Convert request to kwargs for backend\n        training_kwargs = {\n            \"model_name\": request.model_name,\n            \"training_type\": request.training_type,\n            \"hf_token\": request.hf_token or \"\",\n            \"load_in_4bit\": request.load_in_4bit,\n            \"max_seq_length\": request.max_seq_length,\n            \"hf_dataset\": request.hf_dataset or \"\",\n            \"local_datasets\": request.local_datasets,\n            \"local_eval_datasets\": request.local_eval_datasets,\n            \"format_type\": request.format_type,\n            \"subset\": request.subset,\n            \"train_split\": request.train_split,\n            \"eval_split\": request.eval_split,\n            \"eval_steps\": request.eval_steps,\n            \"dataset_slice_start\": request.dataset_slice_start,\n            \"dataset_slice_end\": request.dataset_slice_end,\n            \"custom_format_mapping\": request.custom_format_mapping,\n            \"num_epochs\": request.num_epochs,\n            \"learning_rate\": request.learning_rate,\n            \"batch_size\": request.batch_size,\n            \"gradient_accumulation_steps\": request.gradient_accumulation_steps,\n            \"warmup_steps\": request.warmup_steps,\n            \"warmup_ratio\": request.warmup_ratio,\n            \"max_steps\": request.max_steps,\n            \"save_steps\": request.save_steps,\n            \"weight_decay\": request.weight_decay,\n            \"random_seed\": request.random_seed,\n            \"packing\": request.packing,\n            \"optim\": request.optim,\n            \"lr_scheduler_type\": request.lr_scheduler_type,\n            \"use_lora\": request.use_lora,\n            \"lora_r\": request.lora_r,\n            \"lora_alpha\": request.lora_alpha,\n            \"lora_dropout\": request.lora_dropout,\n            \"target_modules\": request.target_modules\n            if request.target_modules\n            else None,\n            \"gradient_checkpointing\": request.gradient_checkpointing.strip()\n            if request.gradient_checkpointing and request.gradient_checkpointing.strip()\n            else \"unsloth\",\n            \"use_rslora\": request.use_rslora,\n            \"use_loftq\": request.use_loftq,\n            \"train_on_completions\": request.train_on_completions,\n            \"finetune_vision_layers\": request.finetune_vision_layers,\n            \"finetune_language_layers\": request.finetune_language_layers,\n            \"finetune_attention_modules\": request.finetune_attention_modules,\n            \"finetune_mlp_modules\": request.finetune_mlp_modules,\n            \"is_dataset_image\": request.is_dataset_image,\n            \"is_dataset_audio\": request.is_dataset_audio,\n            \"is_embedding\": request.is_embedding,\n            \"enable_wandb\": request.enable_wandb,\n            \"wandb_token\": request.wandb_token or \"\",\n            
\"wandb_project\": request.wandb_project or \"\",\n            \"enable_tensorboard\": request.enable_tensorboard,\n            \"tensorboard_dir\": request.tensorboard_dir or \"\",\n            \"trust_remote_code\": request.trust_remote_code,\n        }\n\n        # Training page has no trust_remote_code toggle — the value comes from\n        # YAML model defaults applied when the user selects a model.  As a safety\n        # net, consult the YAML directly so models that need it always get it.\n        if not training_kwargs[\"trust_remote_code\"]:\n            model_defaults = load_model_defaults(request.model_name)\n            yaml_trust = model_defaults.get(\"training\", {}).get(\n                \"trust_remote_code\", False\n            )\n            if yaml_trust:\n                logger.info(\n                    f\"YAML config sets trust_remote_code=True for {request.model_name}\"\n                )\n                training_kwargs[\"trust_remote_code\"] = True\n\n        # Free GPU memory: shut down any running inference/export subprocesses\n        # before training starts (they'd compete for VRAM otherwise)\n        try:\n            from core.inference import get_inference_backend\n\n            inf_backend = get_inference_backend()\n            if inf_backend.active_model_name:\n                logger.info(\n                    \"Unloading inference model '%s' to free GPU memory for training\",\n                    inf_backend.active_model_name,\n                )\n                inf_backend._shutdown_subprocess()\n                inf_backend.active_model_name = None\n                inf_backend.models.clear()\n        except Exception as e:\n            logger.warning(\"Could not unload inference model: %s\", e)\n\n        try:\n            from core.export import get_export_backend\n\n            exp_backend = get_export_backend()\n            if exp_backend.current_checkpoint:\n                logger.info(\n                    \"Shutting down export subprocess to free GPU memory for training\"\n                )\n                exp_backend._shutdown_subprocess()\n                exp_backend.current_checkpoint = None\n                exp_backend.is_vision = False\n                exp_backend.is_peft = False\n        except Exception as e:\n            logger.warning(\"Could not shut down export subprocess: %s\", e)\n\n        # start_training now spawns a subprocess (non-blocking)\n        success = backend.start_training(**training_kwargs)\n\n        if not success:\n            progress_error = backend.trainer.training_progress.error\n            return TrainingJobResponse(\n                job_id = job_id,\n                status = \"error\",\n                message = progress_error or \"Failed to start training subprocess\",\n                error = progress_error or \"subprocess_start_failed\",\n            )\n\n        return TrainingJobResponse(\n            job_id = job_id,\n            status = \"queued\",\n            message = \"Training job queued and starting in subprocess\",\n            error = None,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error starting training: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to start training: {str(e)}\",\n        )\n\n\n@router.post(\"/stop\", response_model = TrainingStopResponse)\nasync def stop_training(\n    body: TrainingStopRequest = TrainingStopRequest(),\n    current_subject: str = 
Depends(get_current_subject),\n):\n    \"\"\"\n    Stop the currently running training job.\n\n    Body:\n        save (bool): If True (default), save the model at the current checkpoint.\n    \"\"\"\n    try:\n        backend = get_training_backend()\n        is_active = backend.is_training_active()\n        logger.info(\"Stop requested: save=%s is_active=%s\", body.save, is_active)\n\n        if not is_active:\n            return TrainingStopResponse(\n                status = \"idle\", message = \"No training job is currently running\"\n            )\n\n        # Call backend stop method\n        backend.stop_training(save = body.save)\n\n        return TrainingStopResponse(\n            status = \"stopped\",\n            message = \"Stop requested. Training will stop at the next safe step.\",\n        )\n\n    except Exception as e:\n        logger.error(f\"Error stopping training: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to stop training: {str(e)}\"\n        )\n\n\n@router.post(\"/reset\")\nasync def reset_training(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Reset training state so the user can return to configuration.\n    \"\"\"\n    try:\n        backend = get_training_backend()\n        is_active = backend.is_training_active()\n\n        if is_active:\n            if backend._cancel_requested:\n                # Cancel (save=False) was requested — force-terminate so we can reset immediately\n                logger.info(\n                    \"Force-terminating subprocess for immediate reset (cancel path)\"\n                )\n                backend.force_terminate()\n            else:\n                logger.warning(\n                    \"Rejected reset while training active: is_active=%s\", is_active\n                )\n                raise HTTPException(\n                    status_code = 409,\n                    detail = \"Training is still running. 
Stop training and wait for it to finish before resetting.\",\n                )\n\n        logger.info(\"Reset training state: clearing runtime + metric history\")\n        backend._should_stop = False  # Clear stop flag so status returns to idle\n        backend.trainer._update_progress(\n            is_training = False,\n            is_completed = False,\n            error = None,\n            status_message = \"Ready to train\",\n            step = 0,\n            loss = 0.0,\n            epoch = 0,\n            total_steps = 0,\n        )\n        backend.loss_history = []\n        backend.lr_history = []\n        backend.step_history = []\n        backend.grad_norm_history = []\n        backend.grad_norm_step_history = []\n        return {\"status\": \"ok\"}\n    except HTTPException:\n        raise\n    except Exception as e:\n        logger.error(f\"Error resetting training: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500,\n            detail = f\"Failed to reset training: {str(e)}\",\n        )\n\n\n@router.get(\"/status\")\nasync def get_training_status(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get the current training status.\n    \"\"\"\n    try:\n        backend = get_training_backend()\n        job_id: str = getattr(backend, \"current_job_id\", \"\") or \"\"\n\n        # Check if training is active\n        is_active = backend.is_training_active()\n\n        # Get progress info from trainer\n        try:\n            progress = backend.trainer.get_training_progress()\n        except Exception:\n            progress = None\n\n        status_message = (\n            getattr(progress, \"status_message\", None) if progress else None\n        ) or \"Ready to train\"\n        error_message = getattr(progress, \"error\", None) if progress else None\n\n        # Check if training was stopped by user\n        trainer_stopped = getattr(backend, \"_should_stop\", False)\n\n        # Derive high-level phase\n        if error_message:\n            phase = \"error\"\n        elif is_active:\n            msg_lower = status_message.lower()\n            if \"loading\" in msg_lower or \"importing\" in msg_lower:\n                phase = \"loading_model\"\n            elif any(\n                k in msg_lower for k in [\"preparing\", \"initializing\", \"configuring\"]\n            ):\n                phase = \"configuring\"\n            else:\n                phase = \"training\"\n        elif trainer_stopped:\n            phase = \"stopped\"\n        elif progress and getattr(progress, \"is_completed\", False):\n            phase = \"completed\"\n        else:\n            phase = \"idle\"\n\n        details = None\n        if progress:\n            details = {\n                \"epoch\": getattr(progress, \"epoch\", 0),\n                \"step\": getattr(progress, \"step\", 0),\n                \"total_steps\": getattr(progress, \"total_steps\", 0),\n                \"loss\": getattr(progress, \"loss\", 0.0),\n                \"learning_rate\": getattr(progress, \"learning_rate\", 0.0),\n            }\n\n        # Build metric history for chart recovery after SSE reconnection\n        metric_history = None\n        if backend.step_history:\n            metric_history = {\n                \"steps\": list(backend.step_history),\n                \"loss\": list(backend.loss_history),\n                \"lr\": list(backend.lr_history),\n                \"grad_norm\": list(getattr(backend, \"grad_norm_history\", [])),\n        
        \"grad_norm_steps\": list(getattr(backend, \"grad_norm_step_history\", [])),\n                \"eval_loss\": list(backend.eval_loss_history),\n                \"eval_steps\": list(backend.eval_step_history),\n            }\n\n        return TrainingStatus(\n            job_id = job_id,\n            phase = phase,\n            is_training_running = is_active,\n            eval_enabled = backend.eval_enabled,\n            message = status_message,\n            error = error_message,\n            details = details,\n            metric_history = metric_history,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error getting training status: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to get training status: {str(e)}\"\n        )\n\n\n@router.get(\"/metrics\", response_model = TrainingMetricsResponse)\nasync def get_training_metrics(\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Get training metrics (loss, learning rate, steps).\n    \"\"\"\n    try:\n        backend = get_training_backend()\n\n        # Get metrics from backend\n        loss_history = backend.loss_history\n        lr_history = backend.lr_history\n        step_history = backend.step_history\n        grad_norm_history = getattr(backend, \"grad_norm_history\", [])\n        grad_norm_step_history = getattr(backend, \"grad_norm_step_history\", [])\n\n        # Get current values\n        current_loss = loss_history[-1] if loss_history else None\n        current_lr = lr_history[-1] if lr_history else None\n        current_step = step_history[-1] if step_history else None\n\n        return TrainingMetricsResponse(\n            loss_history = loss_history,\n            lr_history = lr_history,\n            step_history = step_history,\n            grad_norm_history = grad_norm_history,\n            grad_norm_step_history = grad_norm_step_history,\n            current_loss = current_loss,\n            current_lr = current_lr,\n            current_step = current_step,\n        )\n\n    except Exception as e:\n        logger.error(f\"Error getting training metrics: {e}\", exc_info = True)\n        raise HTTPException(\n            status_code = 500, detail = f\"Failed to get training metrics: {str(e)}\"\n        )\n\n\n@router.get(\"/progress\")\nasync def stream_training_progress(\n    request: Request,\n    current_subject: str = Depends(get_current_subject),\n):\n    \"\"\"\n    Stream training progress updates using Server-Sent Events (SSE).\n\n    This endpoint provides real-time updates on training progress.\n    Supports reconnection via the SSE spec:\n      - Sends `id:` with each event so the browser tracks position.\n      - Sends `retry:` to control reconnection interval.\n      - Sends named `event:` types (progress, heartbeat, complete, error).\n      - Reads `Last-Event-ID` header on reconnect to replay missed steps.\n    \"\"\"\n    # Read Last-Event-ID header for reconnection resume\n    last_event_id = request.headers.get(\"last-event-id\")\n    resume_from_step: Optional[int] = None\n    if last_event_id is not None:\n        try:\n            resume_from_step = int(last_event_id)\n            logger.info(f\"SSE reconnect: resuming from step {resume_from_step}\")\n        except ValueError:\n            logger.warning(f\"Invalid Last-Event-ID: {last_event_id}\")\n\n    async def event_generator():\n        backend = get_training_backend()\n        job_id: str = getattr(backend, \"current_job_id\", \"\") 
or \"\"\n\n        # ── Helpers ──────────────────────────────────────────────\n        def build_progress(\n            step: int,\n            loss: float,\n            learning_rate: float,\n            total_steps: int,\n            epoch: Optional[float] = None,\n            progress: Optional[Any] = None,\n            grad_norm_override: Optional[float] = None,\n            eval_loss_override: Optional[float] = None,\n        ) -> TrainingProgress:\n            total = max(total_steps, 0)\n            if step < 0 or total == 0:\n                progress_percent = 0.0\n            else:\n                progress_percent = (\n                    float(step) / float(total) * 100.0 if total > 0 else 0.0\n                )\n\n            # Get actual values from progress object if available\n            elapsed_seconds = (\n                getattr(progress, \"elapsed_seconds\", None) if progress else None\n            )\n            eta_seconds = getattr(progress, \"eta_seconds\", None) if progress else None\n            grad_norm = grad_norm_override\n            if grad_norm is None and progress:\n                grad_norm = getattr(progress, \"grad_norm\", None)\n            num_tokens = getattr(progress, \"num_tokens\", None) if progress else None\n            eval_loss = eval_loss_override\n            if eval_loss is None and progress:\n                eval_loss = getattr(progress, \"eval_loss\", None)\n\n            return TrainingProgress(\n                job_id = job_id,\n                step = step,\n                total_steps = total,\n                loss = loss,\n                learning_rate = learning_rate,\n                progress_percent = progress_percent,\n                epoch = epoch,\n                elapsed_seconds = elapsed_seconds,\n                eta_seconds = eta_seconds,\n                grad_norm = grad_norm,\n                num_tokens = num_tokens,\n                eval_loss = eval_loss,\n            )\n\n        def format_sse(\n            data: str,\n            event: str = \"progress\",\n            event_id: Optional[int] = None,\n        ) -> str:\n            \"\"\"Format a single SSE message with id/event/data fields.\"\"\"\n            lines = []\n            if event_id is not None:\n                lines.append(f\"id: {event_id}\")\n            lines.append(f\"event: {event}\")\n            lines.append(f\"data: {data}\")\n            lines.append(\"\")  # trailing blank line\n            lines.append(\"\")  # double newline terminates the event\n            return \"\\n\".join(lines)\n\n        # ── Retry directive ──────────────────────────────────────\n        # Tell the browser to reconnect after 3 seconds if the connection drops\n        yield \"retry: 3000\\n\\n\"\n\n        # ── Replay missed steps on reconnect ─────────────────────\n        if resume_from_step is not None and backend.step_history:\n            replayed = 0\n            grad_norm_by_step = {\n                step_val: grad_val\n                for step_val, grad_val in zip(\n                    getattr(backend, \"grad_norm_step_history\", []),\n                    getattr(backend, \"grad_norm_history\", []),\n                )\n            }\n            for i, step_val in enumerate(backend.step_history):\n                if step_val > resume_from_step:\n                    loss_val = (\n                        backend.loss_history[i]\n                        if i < len(backend.loss_history)\n                        else 0.0\n                    )\n                    
lr_val = (\n                        backend.lr_history[i] if i < len(backend.lr_history) else 0.0\n                    )\n                    tp_replay = getattr(\n                        getattr(backend, \"trainer\", None), \"training_progress\", None\n                    )\n                    total_replay = (\n                        getattr(tp_replay, \"total_steps\", step_val)\n                        if tp_replay\n                        else step_val\n                    )\n                    epoch_replay = (\n                        getattr(tp_replay, \"epoch\", None) if tp_replay else None\n                    )\n                    payload = build_progress(\n                        step_val,\n                        loss_val,\n                        lr_val,\n                        total_replay,\n                        epoch_replay,\n                        progress = tp_replay,\n                        grad_norm_override = grad_norm_by_step.get(step_val),\n                    )\n                    yield format_sse(\n                        payload.model_dump_json(), event = \"progress\", event_id = step_val\n                    )\n                    replayed += 1\n            if replayed:\n                logger.info(f\"SSE reconnect: replayed {replayed} missed steps\")\n\n        # ── Initial status (only on fresh connections) ───────────\n        if resume_from_step is None:\n            is_active = backend.is_training_active()\n            tp = getattr(getattr(backend, \"trainer\", None), \"training_progress\", None)\n            initial_total_steps = getattr(tp, \"total_steps\", 0) if tp else 0\n            initial_epoch = getattr(tp, \"epoch\", None) if tp else None\n\n            initial_progress = build_progress(\n                step = 0,\n                loss = 0.0,\n                learning_rate = 0.0,\n                total_steps = initial_total_steps,\n                epoch = initial_epoch,\n                progress = tp,\n            )\n            yield format_sse(\n                initial_progress.model_dump_json(), event = \"progress\", event_id = 0\n            )\n\n            # If not active, send final state and exit\n            if not is_active:\n                if backend.step_history:\n                    final_step = backend.step_history[-1]\n                    final_loss = (\n                        backend.loss_history[-1] if backend.loss_history else 0.0\n                    )\n                    final_lr = backend.lr_history[-1] if backend.lr_history else 0.0\n                    final_total_steps = (\n                        getattr(tp, \"total_steps\", final_step) if tp else final_step\n                    )\n                    final_epoch = getattr(tp, \"epoch\", None) if tp else None\n                    payload = build_progress(\n                        final_step,\n                        final_loss,\n                        final_lr,\n                        final_total_steps,\n                        final_epoch,\n                        progress = tp,\n                    )\n                    yield format_sse(\n                        payload.model_dump_json(), event = \"complete\", event_id = final_step\n                    )\n                else:\n                    yield format_sse(\n                        build_progress(-1, 0.0, 0.0, 0, progress = tp).model_dump_json(),\n                        event = \"complete\",\n                        event_id = 0,\n                    )\n                return\n\n        # ── Live 
polling loop ────────────────────────────────────\n        last_step = resume_from_step if resume_from_step is not None else -1\n        no_update_count = 0\n        max_no_updates = (\n            1800  # Timeout after 30 minutes (large models need time for compilation)\n        )\n\n        while backend.is_training_active():\n            try:\n                if backend.step_history:\n                    current_step = backend.step_history[-1]\n                    current_loss = (\n                        backend.loss_history[-1] if backend.loss_history else 0.0\n                    )\n                    current_lr = backend.lr_history[-1] if backend.lr_history else 0.0\n                    tp_inner = getattr(\n                        getattr(backend, \"trainer\", None), \"training_progress\", None\n                    )\n                    current_total_steps = (\n                        getattr(tp_inner, \"total_steps\", current_step)\n                        if tp_inner\n                        else current_step\n                    )\n                    current_epoch = (\n                        getattr(tp_inner, \"epoch\", None) if tp_inner else None\n                    )\n\n                    # Only send if step changed\n                    if current_step != last_step:\n                        progress_payload = build_progress(\n                            current_step,\n                            current_loss,\n                            current_lr,\n                            current_total_steps,\n                            current_epoch,\n                            progress = tp_inner,\n                        )\n                        yield format_sse(\n                            progress_payload.model_dump_json(),\n                            event = \"progress\",\n                            event_id = current_step,\n                        )\n                        last_step = current_step\n                        no_update_count = 0\n                    else:\n                        no_update_count += 1\n                        # Send heartbeat every 10 seconds\n                        if no_update_count % 10 == 0:\n                            heartbeat_payload = build_progress(\n                                current_step,\n                                current_loss,\n                                current_lr,\n                                current_total_steps,\n                                current_epoch,\n                                progress = tp_inner,\n                            )\n                            yield format_sse(\n                                heartbeat_payload.model_dump_json(),\n                                event = \"heartbeat\",\n                                event_id = current_step,\n                            )\n                else:\n                    # No steps yet, but training is active (model loading, etc.)\n                    no_update_count += 1\n                    if no_update_count % 5 == 0:\n                        # Pull total_steps and status from trainer so\n                        # the frontend can show \"Tokenizing…\" etc.\n                        tp_prep = getattr(\n                            getattr(backend, \"trainer\", None),\n                            \"training_progress\",\n                            None,\n                        )\n                        prep_total = (\n                            getattr(tp_prep, \"total_steps\", 0) if tp_prep else 0\n                        )\n   
                     preparing_payload = build_progress(\n                            0,\n                            0.0,\n                            0.0,\n                            prep_total,\n                            progress = tp_prep,\n                        )\n                        yield format_sse(\n                            preparing_payload.model_dump_json(),\n                            event = \"heartbeat\",\n                            event_id = 0,\n                        )\n\n                # Timeout check\n                if no_update_count > max_no_updates:\n                    logger.warning(\"Progress stream timeout - no updates received\")\n                    tp_timeout = getattr(\n                        getattr(backend, \"trainer\", None), \"training_progress\", None\n                    )\n                    timeout_payload = build_progress(\n                        last_step, 0.0, 0.0, 0, progress = tp_timeout\n                    )\n                    yield format_sse(\n                        timeout_payload.model_dump_json(),\n                        event = \"error\",\n                        event_id = last_step if last_step >= 0 else 0,\n                    )\n                    break\n\n                await asyncio.sleep(1)  # Poll every second\n\n            except Exception as e:\n                logger.error(f\"Error in progress stream: {e}\", exc_info = True)\n                tp_error = getattr(\n                    getattr(backend, \"trainer\", None), \"training_progress\", None\n                )\n                error_payload = build_progress(0, 0.0, 0.0, 0, progress = tp_error)\n                yield format_sse(\n                    error_payload.model_dump_json(),\n                    event = \"error\",\n                    event_id = last_step if last_step >= 0 else 0,\n                )\n                break\n\n        # ── Final \"complete\" event ───────────────────────────────\n        final_step = backend.step_history[-1] if backend.step_history else last_step\n        final_loss = backend.loss_history[-1] if backend.loss_history else 0.0\n        final_lr = backend.lr_history[-1] if backend.lr_history else 0.0\n        final_tp = getattr(getattr(backend, \"trainer\", None), \"training_progress\", None)\n        final_total_steps = (\n            getattr(final_tp, \"total_steps\", final_step) if final_tp else final_step\n        )\n        final_epoch = getattr(final_tp, \"epoch\", None) if final_tp else None\n        final_payload = build_progress(\n            final_step,\n            final_loss,\n            final_lr,\n            final_total_steps,\n            final_epoch,\n            progress = final_tp,\n        )\n        yield format_sse(\n            final_payload.model_dump_json(),\n            event = \"complete\",\n            event_id = final_step if final_step >= 0 else 0,\n        )\n\n    return StreamingResponse(\n        event_generator(),\n        media_type = \"text/event-stream\",\n        headers = {\n            \"Cache-Control\": \"no-cache\",\n            \"Connection\": \"keep-alive\",\n            \"X-Accel-Buffering\": \"no\",\n        },\n    )\n"
  },
  {
    "path": "studio/backend/run.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nRun script for Unsloth UI Backend.\nWorks independently and can be moved to any directory.\n\"\"\"\n\nimport os\nimport sys\n\n# Suppress annoying C-level dependency warnings globally (e.g. SwigPyPacked)\nos.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n\nfrom pathlib import Path\n\n# Add the backend directory to Python path\nbackend_dir = Path(__file__).parent\nif str(backend_dir) not in sys.path:\n    sys.path.insert(0, str(backend_dir))\n\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef _resolve_external_ip() -> str:\n    \"\"\"\n    Resolve the machine's external IP address.\n\n    Tries (in order):\n    1. GCE metadata server (instant, works on Google Cloud VMs)\n    2. ifconfig.me (works anywhere with internet)\n    3. LAN IP via UDP socket trick (fallback)\n    \"\"\"\n    import urllib.request\n    import socket\n\n    # 1. Try GCE metadata server (responds in <10ms on GCE, times out fast elsewhere)\n    try:\n        req = urllib.request.Request(\n            \"http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip\",\n            headers = {\"Metadata-Flavor\": \"Google\"},\n        )\n        with urllib.request.urlopen(req, timeout = 1) as resp:\n            ip = resp.read().decode().strip()\n            if ip:\n                return ip\n    except Exception:\n        pass\n\n    # 2. Try public IP service\n    try:\n        with urllib.request.urlopen(\"https://ifconfig.me\", timeout = 3) as resp:\n            ip = resp.read().decode().strip()\n            if ip:\n                return ip\n    except Exception:\n        pass\n\n    # 3. Fallback: LAN IP via UDP socket trick\n    try:\n        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n        s.connect((\"8.8.8.8\", 80))\n        ip = s.getsockname()[0]\n        s.close()\n        return ip\n    except Exception:\n        return \"0.0.0.0\"\n\n\ndef _is_port_free(host: str, port: int) -> bool:\n    \"\"\"Check if a port is available for binding.\"\"\"\n    import socket\n\n    try:\n        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n            s.bind((host, port))\n            return True\n    except OSError:\n        return False\n\n\ndef _find_free_port(host: str, start: int, max_attempts: int = 20) -> int:\n    \"\"\"Find a free port starting from `start`, trying up to max_attempts ports.\"\"\"\n    for offset in range(max_attempts):\n        candidate = start + offset\n        if _is_port_free(host, candidate):\n            return candidate\n    raise RuntimeError(\n        f\"Could not find a free port in range {start}-{start + max_attempts - 1}\"\n    )\n\n\ndef _graceful_shutdown(server = None):\n    \"\"\"Explicitly shut down all subprocess backends and the uvicorn server.\n\n    Called from signal handlers to ensure child processes are cleaned up\n    before the parent exits. This is critical on Windows where atexit\n    handlers are unreliable after Ctrl+C.\n    \"\"\"\n    logger.info(\"Graceful shutdown initiated — cleaning up subprocesses...\")\n\n    # 1. Shut down uvicorn server (releases the listening socket)\n    if server is not None:\n        server.should_exit = True\n\n    # 2. 
Clean up inference subprocess (if instantiated)\n    try:\n        from core.inference.orchestrator import _inference_backend\n\n        if _inference_backend is not None:\n            _inference_backend._shutdown_subprocess(timeout = 5.0)\n    except Exception as e:\n        logger.warning(\"Error shutting down inference subprocess: %s\", e)\n\n    # 3. Clean up export subprocess (if instantiated)\n    try:\n        from core.export.orchestrator import _export_backend\n\n        if _export_backend is not None:\n            _export_backend._shutdown_subprocess(timeout = 5.0)\n    except Exception as e:\n        logger.warning(\"Error shutting down export subprocess: %s\", e)\n\n    # 4. Clean up training subprocess (if active)\n    try:\n        from core.training.training import _training_backend\n\n        if _training_backend is not None:\n            _training_backend.force_terminate()\n    except Exception as e:\n        logger.warning(\"Error shutting down training subprocess: %s\", e)\n\n    # 5. Kill llama-server subprocess (if loaded)\n    try:\n        from routes.inference import _llama_cpp_backend\n\n        if _llama_cpp_backend is not None:\n            _llama_cpp_backend._kill_process()\n    except Exception as e:\n        logger.warning(\"Error shutting down llama-server: %s\", e)\n\n    logger.info(\"All subprocesses cleaned up\")\n\n\n# The uvicorn server instance — set by run_server(), used by callers\n# that need to tell the server to exit (e.g. signal handlers).\n_server = None\n\n# Shutdown event — used to wake the main loop on signal\n_shutdown_event = None\n\n\ndef run_server(\n    host: str = \"0.0.0.0\",\n    port: int = 8888,\n    frontend_path: Path = Path(__file__).resolve().parent.parent / \"frontend\" / \"dist\",\n    silent: bool = False,\n):\n    \"\"\"\n    Start the FastAPI server.\n\n    Args:\n        host: Host to bind to\n        port: Port to bind to (auto-increments if in use)\n        frontend_path: Path to frontend build directory (optional)\n        silent: Suppress startup messages\n\n    Note:\n        Signal handlers are NOT registered here so that embedders\n        (e.g. 
Colab notebooks) keep their own interrupt semantics.\n        Standalone callers should register handlers after calling this.\n    \"\"\"\n    global _server, _shutdown_event\n\n    import nest_asyncio\n\n    nest_asyncio.apply()\n\n    import asyncio\n    from threading import Thread, Event\n    import time\n    import uvicorn\n\n    from main import app, setup_frontend\n    from utils.paths import ensure_studio_directories\n\n    # Create all standard directories on startup\n    ensure_studio_directories()\n\n    # Auto-find free port if requested port is in use\n    if not _is_port_free(host, port):\n        original_port = port\n        port = _find_free_port(host, port)\n        if not silent:\n            print(f\"Port {original_port} is in use, using port {port} instead\")\n\n    # Setup frontend if path provided\n    if frontend_path:\n        if setup_frontend(app, frontend_path):\n            if not silent:\n                print(f\"✅ Frontend loaded from {frontend_path}\")\n        else:\n            if not silent:\n                print(f\"⚠️ Frontend not found at {frontend_path}\")\n\n    # Create the uvicorn server and expose it for signal handlers\n    config = uvicorn.Config(\n        app, host = host, port = port, log_level = \"info\", access_log = False\n    )\n    _server = uvicorn.Server(config)\n    _shutdown_event = Event()\n\n    # Run server in a daemon thread\n    def _run():\n        asyncio.run(_server.serve())\n\n    thread = Thread(target = _run, daemon = True)\n    thread.start()\n    time.sleep(3)\n\n    if not silent:\n        display_host = _resolve_external_ip() if host == \"0.0.0.0\" else host\n\n        print(\"\")\n        print(\"=\" * 50)\n        print(f\"🦥 Open your web browser, and enter http://localhost:{port}\")\n        print(\"=\" * 50)\n        print(\"\")\n        print(\"=\" * 50)\n        print(f\"🦥 Unsloth Studio is running on port {port}\")\n        print(f\"   Local Access:          http://localhost:{port}\")\n        print(f\"   Worldwide Web Address: http://{display_host}:{port}\")\n        print(f\"   API:                   http://{display_host}:{port}/api\")\n        print(f\"   Health:                http://{display_host}:{port}/api/health\")\n        print(\"=\" * 50)\n\n    return app\n\n\n# For direct execution (also invoked by CLI via os.execvp / subprocess)\nif __name__ == \"__main__\":\n    import argparse\n    import signal\n\n    parser = argparse.ArgumentParser(description = \"Run Unsloth UI Backend server\")\n    parser.add_argument(\"--host\", default = \"0.0.0.0\", help = \"Host to bind to\")\n    parser.add_argument(\"--port\", type = int, default = 8888, help = \"Port to bind to\")\n    parser.add_argument(\n        \"--frontend\",\n        type = str,\n        default = Path(__file__).resolve().parent.parent / \"frontend\" / \"dist\",\n        help = \"Path to frontend build\",\n    )\n    parser.add_argument(\"--silent\", action = \"store_true\", help = \"Suppress output\")\n\n    args = parser.parse_args()\n\n    kwargs = dict(host = args.host, port = args.port, silent = args.silent)\n    if args.frontend is not None:\n        kwargs[\"frontend_path\"] = Path(args.frontend)\n    run_server(**kwargs)\n\n    # ── Signal handler — ensures subprocess cleanup on Ctrl+C ────\n    def _signal_handler(signum, frame):\n        _graceful_shutdown(_server)\n        _shutdown_event.set()\n\n    signal.signal(signal.SIGINT, _signal_handler)\n    signal.signal(signal.SIGTERM, _signal_handler)\n\n    # On Windows, some terminals 
send SIGBREAK for Ctrl+C / Ctrl+Break\n    if hasattr(signal, \"SIGBREAK\"):\n        signal.signal(signal.SIGBREAK, _signal_handler)\n\n    # Keep running until shutdown signal.\n    # NOTE: Event.wait() without a timeout blocks at the C level on Linux,\n    # which prevents Python from delivering SIGINT (Ctrl+C).  Using a\n    # short timeout in a loop lets the interpreter process pending signals.\n    while not _shutdown_event.is_set():\n        _shutdown_event.wait(timeout = 1)\n"
  },
  {
    "path": "studio/backend/state/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/state/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/tests/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/tests/conftest.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nShared pytest configuration for the backend test suite.\nEnsures that the backend root is on sys.path so that\n`import utils.utils` (and similar flat imports) resolve correctly.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\n\n# Add backend root to sys.path (mirrors how the app itself is launched)\n_backend_root = Path(__file__).resolve().parent.parent\nif str(_backend_root) not in sys.path:\n    sys.path.insert(0, str(_backend_root))\n"
  },
  {
    "path": "studio/backend/tests/test_data_recipe_seed.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom pathlib import Path\n\n\ndef test_seed_inspect_load_kwargs_disables_remote_code_execution():\n    seed_route = (\n        Path(__file__).resolve().parent.parent / \"routes\" / \"data_recipe\" / \"seed.py\"\n    ).read_text()\n\n    assert '\"trust_remote_code\": False' in seed_route\n"
  },
  {
    "path": "studio/backend/tests/test_utils.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nTests for utils/hardware and utils/utils — device detection, GPU memory, error formatting.\n\nThese tests are designed to pass on ANY platform:\n  • NVIDIA GPU  (CUDA backend, requires torch)\n  • Apple Silicon (MLX backend, requires mlx)\n  • CPU-only     (no GPU at all)\n\nNo ML framework is imported at the top level.\nTests that need torch/mlx internals for mocking are skipped when unavailable.\n\nRun with:\n    cd studio/backend\n    python -m pytest tests/test_utils.py -v\n\"\"\"\n\nimport platform\nfrom unittest.mock import patch, MagicMock\n\nimport pytest\n\n# --- Conditional framework imports ---\ntry:\n    import torch\n\n    HAS_TORCH = True\nexcept ImportError:\n    HAS_TORCH = False\n\ntry:\n    import mlx.core as mx\n\n    HAS_MLX = True\nexcept ImportError:\n    HAS_MLX = False\n\nneeds_torch = pytest.mark.skipif(not HAS_TORCH, reason = \"PyTorch not installed\")\nneeds_mlx = pytest.mark.skipif(not HAS_MLX, reason = \"MLX not installed\")\n\nfrom utils.hardware import (\n    get_device,\n    detect_hardware,\n    is_apple_silicon,\n    clear_gpu_cache,\n    get_gpu_memory_info,\n    log_gpu_memory,\n    DeviceType,\n)\nimport utils.hardware.hardware as _hw_module\nfrom utils.utils import format_error_message\n\n\n# ========== Helpers ==========\n\n\ndef _actual_device() -> str:\n    \"\"\"Return the real device string for the current machine.\"\"\"\n    if HAS_TORCH and torch.cuda.is_available():\n        return \"cuda\"\n    if is_apple_silicon() and HAS_MLX:\n        return \"mlx\"\n    return \"cpu\"\n\n\ndef _reset_and_detect():\n    \"\"\"Reset the cached DEVICE global and re-run detection.\"\"\"\n    _hw_module.DEVICE = None\n    return detect_hardware()\n\n\n# ========== get_device() ==========\n\n\nclass TestGetDevice:\n    \"\"\"Tests for get_device() — should agree with the real hardware.\"\"\"\n\n    def setup_method(self):\n        self._saved_device = _hw_module.DEVICE\n\n    def teardown_method(self):\n        _hw_module.DEVICE = self._saved_device\n\n    def test_returns_valid_device_type(self):\n        result = get_device()\n        assert result in (DeviceType.CUDA, DeviceType.MLX, DeviceType.CPU)\n\n    def test_matches_actual_hardware(self):\n        assert get_device().value == _actual_device()\n\n    # --- Mocked paths ---\n\n    @needs_torch\n    def test_returns_cuda_when_cuda_available(self):\n        with (\n            patch(\"utils.hardware.hardware._has_torch\", return_value = True),\n            patch(\"torch.cuda.is_available\", return_value = True),\n        ):\n            assert _reset_and_detect() == DeviceType.CUDA\n\n    @needs_mlx\n    def test_returns_mlx_when_on_apple_silicon_with_mlx(self):\n        with (\n            patch(\"utils.hardware.hardware._has_torch\", return_value = False),\n            patch(\"utils.hardware.hardware.is_apple_silicon\", return_value = True),\n            patch(\"utils.hardware.hardware._has_mlx\", return_value = True),\n        ):\n            assert _reset_and_detect() == DeviceType.MLX\n\n    def test_returns_cpu_when_nothing_available(self):\n        with (\n            patch(\"utils.hardware.hardware._has_torch\", return_value = False),\n            patch(\"utils.hardware.hardware.is_apple_silicon\", return_value = False),\n            patch(\"utils.hardware.hardware._has_mlx\", return_value = False),\n        ):\n            assert 
_reset_and_detect() == DeviceType.CPU\n\n\n# ========== is_apple_silicon() ==========\n\n\nclass TestIsAppleSilicon:\n    def test_returns_bool(self):\n        assert isinstance(is_apple_silicon(), bool)\n\n    def test_true_on_darwin_arm64(self):\n        with patch(\"utils.hardware.hardware.platform\") as mock_plat:\n            mock_plat.system.return_value = \"Darwin\"\n            mock_plat.machine.return_value = \"arm64\"\n            assert is_apple_silicon() is True\n\n    def test_false_on_linux_x86(self):\n        with patch(\"utils.hardware.hardware.platform\") as mock_plat:\n            mock_plat.system.return_value = \"Linux\"\n            mock_plat.machine.return_value = \"x86_64\"\n            assert is_apple_silicon() is False\n\n    def test_false_on_darwin_x86(self):\n        \"\"\"Intel Mac should return False.\"\"\"\n        with patch(\"utils.hardware.hardware.platform\") as mock_plat:\n            mock_plat.system.return_value = \"Darwin\"\n            mock_plat.machine.return_value = \"x86_64\"\n            assert is_apple_silicon() is False\n\n\n# ========== clear_gpu_cache() ==========\n\n\nclass TestClearGpuCache:\n    \"\"\"clear_gpu_cache() must never raise, regardless of platform.\"\"\"\n\n    def test_does_not_raise(self):\n        clear_gpu_cache()\n\n    @needs_torch\n    def test_calls_cuda_cache_when_cuda(self):\n        with (\n            patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.CUDA),\n            patch(\"torch.cuda.empty_cache\") as mock_empty,\n            patch(\"torch.cuda.ipc_collect\") as mock_ipc,\n        ):\n            clear_gpu_cache()\n            mock_empty.assert_called_once()\n            mock_ipc.assert_called_once()\n\n    @needs_mlx\n    def test_mlx_does_not_raise(self):\n        \"\"\"MLX cache clear is a no-op — should just succeed.\"\"\"\n        with patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.MLX):\n            clear_gpu_cache()\n\n    def test_noop_on_cpu(self):\n        with patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.CPU):\n            clear_gpu_cache()\n\n\n# ========== get_gpu_memory_info() ==========\n\n\nclass TestGetGpuMemoryInfo:\n    def test_returns_dict(self):\n        result = get_gpu_memory_info()\n        assert isinstance(result, dict)\n\n    def test_has_available_key(self):\n        assert \"available\" in get_gpu_memory_info()\n\n    def test_has_backend_key(self):\n        assert \"backend\" in get_gpu_memory_info()\n\n    def test_backend_matches_device(self):\n        result = get_gpu_memory_info()\n        assert result[\"backend\"] == get_device().value\n\n    # --- When a GPU IS available ---\n\n    @pytest.mark.skipif(\n        _actual_device() == \"cpu\", reason = \"No GPU available on this machine\"\n    )\n    def test_gpu_available_fields(self):\n        result = get_gpu_memory_info()\n        assert result[\"available\"] is True\n        assert result[\"total_gb\"] > 0\n        assert result[\"allocated_gb\"] >= 0\n        assert result[\"free_gb\"] >= 0\n        assert 0 <= result[\"utilization_pct\"] <= 100\n        assert \"device_name\" in result\n\n    # --- CUDA-specific mocked test ---\n\n    @needs_torch\n    def test_cuda_path_returns_correct_fields(self):\n        mock_props = MagicMock()\n        mock_props.total_memory = 16 * (1024**3)\n        mock_props.name = \"NVIDIA Test GPU\"\n\n        with (\n            patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.CUDA),\n            
patch(\"torch.cuda.current_device\", return_value = 0),\n            patch(\"torch.cuda.get_device_properties\", return_value = mock_props),\n            patch(\"torch.cuda.memory_allocated\", return_value = 4 * (1024**3)),\n            patch(\"torch.cuda.memory_reserved\", return_value = 6 * (1024**3)),\n        ):\n            result = get_gpu_memory_info()\n\n        assert result[\"available\"] is True\n        assert result[\"backend\"] == \"cuda\"\n        assert result[\"device_name\"] == \"NVIDIA Test GPU\"\n        assert abs(result[\"total_gb\"] - 16.0) < 0.01\n        assert abs(result[\"allocated_gb\"] - 4.0) < 0.01\n        assert abs(result[\"free_gb\"] - 12.0) < 0.01\n        assert abs(result[\"utilization_pct\"] - 25.0) < 0.1\n\n    # --- MLX-specific mocked test ---\n\n    @needs_mlx\n    def test_mlx_path_returns_correct_fields(self):\n        mock_psutil_mem = MagicMock()\n        mock_psutil_mem.total = 32 * (1024**3)  # 32 GB unified\n\n        mock_psutil = MagicMock()\n        mock_psutil.virtual_memory.return_value = mock_psutil_mem\n\n        with (\n            patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.MLX),\n            patch.dict(\"sys.modules\", {\"psutil\": mock_psutil}),\n        ):\n            result = get_gpu_memory_info()\n\n        assert result[\"available\"] is True\n        assert result[\"backend\"] == \"mlx\"\n        assert \"Apple Silicon\" in result[\"device_name\"]\n        assert abs(result[\"total_gb\"] - 32.0) < 0.01\n\n    # --- CPU-only path ---\n\n    def test_cpu_path_returns_unavailable(self):\n        with patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.CPU):\n            result = get_gpu_memory_info()\n        assert result[\"available\"] is False\n        assert result[\"backend\"] == \"cpu\"\n\n    # --- Error resilience ---\n\n    @needs_torch\n    def test_cuda_error_returns_unavailable(self):\n        with (\n            patch(\"utils.hardware.hardware.get_device\", return_value = DeviceType.CUDA),\n            patch(\n                \"torch.cuda.current_device\",\n                side_effect = RuntimeError(\"CUDA init failed\"),\n            ),\n        ):\n            result = get_gpu_memory_info()\n        assert result[\"available\"] is False\n        assert \"error\" in result\n\n\n# ========== log_gpu_memory() ==========\n\n\nclass TestLogGpuMemory:\n    def test_does_not_raise(self):\n        log_gpu_memory(\"test\")\n\n    def test_logs_gpu_info_when_available(self, caplog):\n        fake_info = {\n            \"available\": True,\n            \"backend\": \"cuda\",\n            \"device_name\": \"FakeGPU\",\n            \"allocated_gb\": 2.0,\n            \"total_gb\": 16.0,\n            \"utilization_pct\": 12.5,\n            \"free_gb\": 14.0,\n        }\n        import structlog\n        from loggers import get_logger\n\n        with (\n            patch(\n                \"utils.hardware.hardware.get_gpu_memory_info\", return_value = fake_info\n            ),\n            caplog.at_level(logging.INFO, logger = \"utils.hardware.hardware\"),\n        ):\n            log_gpu_memory(\"unit-test\")\n\n        assert \"unit-test\" in caplog.text\n        assert \"CUDA\" in caplog.text\n        assert \"FakeGPU\" in caplog.text\n\n    def test_logs_cpu_fallback_when_no_gpu(self, caplog):\n        fake_info = {\"available\": False, \"backend\": \"cpu\"}\n        import structlog\n        from loggers import get_logger\n\n        with (\n            patch(\n               
 \"utils.hardware.hardware.get_gpu_memory_info\", return_value = fake_info\n            ),\n            caplog.at_level(logging.INFO, logger = \"utils.hardware.hardware\"),\n        ):\n            log_gpu_memory(\"cpu-test\")\n\n        assert \"No GPU available\" in caplog.text\n\n\n# ========== format_error_message() ==========\n\n\nclass TestFormatErrorMessage:\n    def test_not_found(self):\n        err = Exception(\"Repository not found for unsloth/test\")\n        msg = format_error_message(err, \"unsloth/test\")\n        assert \"not found\" in msg.lower()\n        assert \"test\" in msg\n\n    def test_unauthorized(self):\n        err = Exception(\"401 Unauthorized\")\n        msg = format_error_message(err, \"some/model\")\n        assert \"authentication\" in msg.lower() or \"unauthorized\" in msg.lower()\n\n    def test_gated_model(self):\n        err = Exception(\"Access to model requires authentication\")\n        msg = format_error_message(err, \"meta/llama\")\n        assert \"authentication\" in msg.lower()\n\n    def test_invalid_token(self):\n        err = Exception(\"Invalid user token\")\n        msg = format_error_message(err, \"any/model\")\n        assert \"invalid\" in msg.lower()\n\n    # --- OOM on CUDA ---\n\n    @needs_torch\n    def test_cuda_oom(self):\n        err = Exception(\"CUDA out of memory\")\n        with patch(\"utils.hardware.get_device\", return_value = DeviceType.CUDA):\n            msg = format_error_message(err, \"big/model\")\n        assert \"GPU\" in msg\n        assert \"big/model\" not in msg\n        assert \"model\" in msg\n\n    # --- OOM on MLX ---\n\n    @needs_mlx\n    def test_mlx_oom(self):\n        err = Exception(\"MLX backend out of memory\")\n        with patch(\"utils.hardware.get_device\", return_value = DeviceType.MLX):\n            msg = format_error_message(err, \"unsloth/huge-model\")\n        assert \"Apple Silicon\" in msg\n\n    # --- OOM on CPU ---\n\n    def test_cpu_oom(self):\n        err = Exception(\"not enough memory to allocate\")\n        with patch(\"utils.hardware.get_device\", return_value = DeviceType.CPU):\n            msg = format_error_message(err, \"any/model\")\n        assert \"system\" in msg.lower()\n\n    # --- Generic fallback ---\n\n    def test_generic_error(self):\n        err = Exception(\"Something completely unexpected\")\n        msg = format_error_message(err, \"any/model\")\n        assert msg == \"Something completely unexpected\"\n"
  },
  {
    "path": "studio/backend/utils/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/backend/utils/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "studio/backend/utils/cache_cleanup.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nUtility for cleaning up the Unsloth compiled cache directory.\n\nThe unsloth_compiled_cache is created by unsloth_zoo/compiler.py during\nFastModel.from_pretrained() and contains model-type-specific compiled Python\nfiles. It should be cleared between model loads to avoid stale artefacts.\n\"\"\"\n\nimport shutil\nimport structlog\nfrom loggers import get_logger\nfrom pathlib import Path\n\nlogger = get_logger(__name__)\n\n# Possible locations where unsloth_compiled_cache may appear\n_BACKEND_DIR = Path(__file__).resolve().parent.parent  # studio/backend\n_PROJECT_ROOT = _BACKEND_DIR.parent.parent  # repo root\n\n_CACHE_DIRS = [\n    _BACKEND_DIR / \"unsloth_compiled_cache\",\n    _PROJECT_ROOT / \"unsloth_compiled_cache\",\n    _PROJECT_ROOT / \"studio\" / \"tmp\" / \"unsloth_compiled_cache\",\n]\n\n\ndef clear_unsloth_compiled_cache() -> None:\n    \"\"\"Remove every known unsloth_compiled_cache directory (idempotent).\"\"\"\n    for cache_dir in _CACHE_DIRS:\n        if cache_dir.exists():\n            logger.info(f\"Removing unsloth compiled cache: {cache_dir}\")\n            shutil.rmtree(cache_dir, ignore_errors = True)\n"
  },
  {
    "path": "studio/backend/utils/datasets/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nDataset utilities package.\n\nThis package provides utilities for dataset format detection, conversion,\nand processing for LLM and VLM fine-tuning workflows.\n\nModules:\n- format_detection: Detect dataset formats (Alpaca, ShareGPT, ChatML)\n- format_conversion: Convert between dataset formats\n- chat_templates: Apply chat templates to datasets\n- vlm_processing: Vision-Language Model processing utilities\n- data_collators: Custom data collators for training\n- model_mappings: Model-to-template mapping constants\n\"\"\"\n\n# Format detection\nfrom .format_detection import (\n    detect_dataset_format,\n    detect_custom_format_heuristic,\n    detect_multimodal_dataset,\n    detect_vlm_dataset_structure,\n)\n\n# Format conversion\nfrom .format_conversion import (\n    standardize_chat_format,\n    convert_chatml_to_alpaca,\n    convert_alpaca_to_chatml,\n    convert_to_vlm_format,\n    convert_llava_to_vlm_format,\n    convert_sharegpt_with_images_to_vlm_format,\n)\n\n# Chat templates\nfrom .chat_templates import (\n    apply_chat_template_to_dataset,\n    get_dataset_info_summary,\n    get_tokenizer_chat_template,\n    DEFAULT_ALPACA_TEMPLATE,\n)\n\n# VLM processing\nfrom .vlm_processing import (\n    generate_smart_vlm_instruction,\n)\n\n# Data collators\nfrom .data_collators import (\n    DataCollatorSpeechSeq2SeqWithPadding,\n    DeepSeekOCRDataCollator,\n    VLMDataCollator,\n)\n\n# Model mappings (constants)\nfrom .model_mappings import (\n    TEMPLATE_TO_MODEL_MAPPER,\n    MODEL_TO_TEMPLATE_MAPPER,\n    TEMPLATE_TO_RESPONSES_MAPPER,\n)\n\n# Legacy imports from the original dataset_utils.py for backward compatibility\n# These functions have not yet been refactored into separate modules\nfrom .dataset_utils import (\n    check_dataset_format,\n    format_and_template_dataset,\n    format_dataset,\n)\n\n# Public API\n__all__ = [\n    # Detection\n    \"detect_dataset_format\",\n    \"detect_custom_format_heuristic\",\n    \"detect_multimodal_dataset\",\n    \"detect_vlm_dataset_structure\",\n    # Conversion\n    \"standardize_chat_format\",\n    \"convert_chatml_to_alpaca\",\n    \"convert_alpaca_to_chatml\",\n    \"convert_to_vlm_format\",\n    \"convert_llava_to_vlm_format\",\n    \"convert_sharegpt_with_images_to_vlm_format\",\n    # Templates\n    \"apply_chat_template_to_dataset\",\n    \"get_dataset_info_summary\",\n    \"get_tokenizer_chat_template\",\n    \"DEFAULT_ALPACA_TEMPLATE\",\n    # VLM\n    \"generate_smart_vlm_instruction\",\n    # Collators\n    \"DataCollatorSpeechSeq2SeqWithPadding\",\n    \"DeepSeekOCRDataCollator\",\n    \"VLMDataCollator\",\n    # Mappings\n    \"TEMPLATE_TO_MODEL_MAPPER\",\n    \"MODEL_TO_TEMPLATE_MAPPER\",\n    \"TEMPLATE_TO_RESPONSES_MAPPER\",\n    # Main entry points\n    \"check_dataset_format\",\n    \"format_and_template_dataset\",\n    \"format_dataset\",\n]\n"
  },
  {
    "path": "studio/backend/utils/datasets/chat_templates.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nChat template application utilities for dataset processing.\n\nThis module contains functions for applying chat templates to datasets\nand generating dataset info summaries.\n\"\"\"\n\nfrom torch.utils.data import IterableDataset\n\nfrom .format_detection import detect_dataset_format, detect_multimodal_dataset, detect_custom_format_heuristic\nfrom .model_mappings import MODEL_TO_TEMPLATE_MAPPER\nfrom loggers import get_logger\nlogger = get_logger(__name__)\n\n\n\n\nDEFAULT_ALPACA_TEMPLATE = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}\"\"\"\n\n\ndef get_tokenizer_chat_template(tokenizer, model_name):\n    \"\"\"\n    Gets appropriate chat template for tokenizer based on model.\n    Uses Unsloth's get_chat_template if model is in the mapper.\n\n    Args:\n        tokenizer: HuggingFace tokenizer\n        model_name: Model class name (e.g., \"Gemma3ForCausalLM\")\n\n    Returns:\n        tokenizer: Tokenizer with appropriate chat template applied\n    \"\"\"\n    try:\n        from unsloth.chat_templates import get_chat_template\n    except ImportError:\n        # Unsloth not available, return tokenizer as-is\n        return tokenizer\n\n    # Normalize model_name to lowercase for matching\n    model_name_lower = model_name.lower()\n\n    # Check if model matches any template in mapper\n    matched_template = None\n\n    # Direct match in MODEL_TO_TEMPLATE_MAPPER\n    if model_name_lower in MODEL_TO_TEMPLATE_MAPPER:\n        matched_template = MODEL_TO_TEMPLATE_MAPPER[model_name_lower]\n        logger.info(f\"📝 Applying Unsloth chat template: {matched_template}\")\n        try:\n            tokenizer = get_chat_template(\n                tokenizer,\n                chat_template = matched_template,\n            )\n        except Exception as e:\n            logger.info(f\"⚠️ Failed to apply Unsloth template '{matched_template}': {e}\")\n            logger.info(f\"   Falling back to tokenizer's default chat template\")\n    else:\n        # Check if tokenizer actually has a chat_template set\n        has_chat_template = (\n            hasattr(tokenizer, 'chat_template')\n            and tokenizer.chat_template is not None\n        )\n        if has_chat_template:\n            logger.info(f\"📝 Using tokenizer's own chat template (no Unsloth template match)\")\n        else:\n            # Base model with no chat template — apply default ChatML\n            logger.info(f\"📝 No chat template found — applying default ChatML template (base model)\")\n            try:\n                tokenizer = get_chat_template(\n                    tokenizer,\n                    chat_template = \"chatml\",\n                )\n            except Exception as e:\n                logger.info(f\"⚠️ Failed to apply default ChatML template: {e}\")\n                logger.info(f\"   Falling back to tokenizer as-is\")\n\n    return tokenizer\n\n\ndef get_dataset_info_summary(dataset_info):\n    \"\"\"\n    Returns a human-readable summary for UI display.\n    \"\"\"\n    detected_format = dataset_info[\"detected_format\"]\n    final_format = dataset_info[\"final_format\"]\n\n    format_descriptions = {\n        \"alpaca\": \"Alpaca format (instruction/input/output)\",\n    
    \"sharegpt\": \"ShareGPT format (needs standardization)\",\n        \"chatml_messages\": \"ChatML format (messages column) - OpenAI compatible\",\n        \"chatml_conversations\": \"ChatML format (conversations column) - HuggingFace standard\",\n        \"unknown\": \"Unknown format\"\n    }\n\n    return {\n        \"detected_format\": detected_format,\n        \"final_format\": final_format,\n        \"detected_description\": format_descriptions.get(detected_format, \"Unknown\"),\n        \"final_description\": format_descriptions.get(final_format, \"Unknown\"),\n        \"chat_column\": dataset_info[\"chat_column\"],\n        \"is_standardized\": dataset_info[\"is_standardized\"],\n        \"warnings\": dataset_info.get(\"warnings\", []),\n        \"ready_for_training\": dataset_info[\"is_standardized\"] and final_format != \"unknown\"\n    }\n\n\ndef apply_chat_template_to_dataset(\n    dataset_info,\n    tokenizer,\n    model_name = None,\n    custom_prompt_template = None,\n    add_eos_token = False,\n    remove_bos_prefix = False,\n    custom_format_mapping = None,\n    auto_detect_mapping = True,\n    batch_size = 1000,\n    num_proc = None,\n    progress_callback = None,\n):\n    \"\"\"\n    Applies chat template to dataset based on its format.\n\n    Args:\n        dataset_info: Output from format_dataset() with metadata\n        tokenizer: Tokenizer with chat template\n        custom_prompt_template: Optional string template for custom formatting\n        add_eos_token: If True, appends tokenizer.eos_token to each text\n        remove_bos_prefix: If True, removes '<bos>' prefix (for Gemma, etc.)\n        custom_format_mapping: Dict mapping custom columns to standard format\n        batch_size: Batch size for processing\n        num_proc: Number of processes\n\n    Returns:\n        dict with dataset, success status, warnings, and errors\n    \"\"\"\n    dataset = dataset_info[\"dataset\"]\n    final_format = dataset_info[\"final_format\"]\n    chat_column = dataset_info[\"chat_column\"]\n    is_standardized = dataset_info[\"is_standardized\"]\n\n    warnings = list(dataset_info.get(\"warnings\", []))\n    errors = []\n\n    # Get EOS token if needed\n    eos_token = \"\"\n    if add_eos_token:\n        if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:\n            eos_token = tokenizer.eos_token\n        else:\n            warnings.append(\"add_eos_token=True but tokenizer has no eos_token\")\n\n    # CUSTOM FORMAT MAPPING (for non-standard datasets)\n    if final_format == \"unknown\":\n        # Try auto-detection if no custom mapping provided\n        if custom_format_mapping is None and auto_detect_mapping:\n            # Check if format_dataset already tried and failed\n            if not dataset_info.get(\"auto_detection_attempted\", False):\n                custom_format_mapping = detect_custom_format_heuristic(dataset)\n                if custom_format_mapping:\n                    warnings.append(f\"Auto-detected column mapping: {custom_format_mapping}\")\n                else:\n                    errors.append(\"Could not auto-detect format mapping\")\n                    return {\n                        \"dataset\": dataset,\n                        \"success\": False,\n                        \"warnings\": warnings,\n                        \"errors\": errors\n                    }\n            else:\n                # Already failed once in format_dataset, don't retry\n                errors.append(\n                    \"Format remains unknown 
after detection attempts. \"\n                    \"Please provide custom_format_mapping to specify column roles manually.\"\n                )\n                return {\n                    \"dataset\": dataset,\n                    \"success\": False,\n                    \"warnings\": warnings,\n                    \"errors\": errors\n                }\n\n        if custom_format_mapping:\n            warnings.append(f\"Applying custom format mapping: {custom_format_mapping}\")\n            is_user_provided = dataset_info.get(\"custom_format_mapping\") is not None\n\n            def _apply_custom_mapping(examples):\n                conversations = []\n                num_examples = len(examples[list(examples.keys())[0]])\n\n                # Only preserve unmapped columns if auto-detected\n                preserved_columns = {}\n                if not is_user_provided:\n                    all_columns = set(examples.keys())\n                    mapped_columns = set(custom_format_mapping.keys())\n                    non_mapped_columns = all_columns - mapped_columns\n\n                    for col in non_mapped_columns:\n                        preserved_columns[col] = examples[col]\n\n                for i in range(num_examples):\n                    convo = []\n                    role_order = ['system', 'user', 'assistant']\n\n                    for target_role in role_order:\n                        for col_name, role in custom_format_mapping.items():\n                            if role == target_role and col_name in examples:\n                                content = examples[col_name][i]\n\n                                if is_user_provided:\n                                    # User explicitly mapped - include even if empty\n                                    convo.append({\"role\": role, \"content\": str(content) if content else \"\"})\n                                else:\n                                    # Auto-detected - skip empty\n                                    if content and str(content).strip():\n                                        convo.append({\"role\": role, \"content\": str(content)})\n\n                    conversations.append(convo)\n\n                result = {\"conversations\": conversations}\n                if not is_user_provided:\n                    result.update(preserved_columns)\n                return result\n\n            try:\n                dataset = dataset.map(_apply_custom_mapping, batched = True, batch_size = batch_size)\n                # Update to use conversations format\n                final_format = \"chatml_conversations\"\n                chat_column = \"conversations\"\n                is_standardized = True\n                warnings.append(\"Successfully converted to ChatML format via custom mapping\")\n            except Exception as e:\n                errors.append(f\"Custom format mapping failed: {e}\")\n                return {\n                    \"dataset\": dataset,\n                    \"success\": False,\n                    \"warnings\": warnings,\n                    \"errors\": errors\n                }\n\n    # ALPACA FORMAT\n    if final_format == \"alpaca\":\n\n        # Set alpaca chat template on tokenizer for saving (if not already set)\n        # This ensures the template is saved with the model for inference\n        if not (hasattr(tokenizer, 'chat_template') and tokenizer.chat_template):\n            try:\n                from unsloth.chat_templates import get_chat_template\n                
tokenizer = get_chat_template(tokenizer, chat_template = \"alpaca\")\n                logger.info(f\"📝 Set alpaca chat template on tokenizer for model saving\")\n            except Exception as e:\n                logger.info(f\"⚠️ Could not set alpaca template on tokenizer: {e}\")\n\n        # Use custom template if provided\n        def _format_alpaca_custom(examples):\n            texts = []\n            for i in range(len(examples[\"instruction\"])):\n                fields = {\n                    \"instruction\": examples[\"instruction\"][i],\n                    \"input\": examples.get(\"input\", [\"\"] * len(examples[\"instruction\"]))[i],\n                    \"output\": examples[\"output\"][i]\n                }\n\n                try:\n                    text = DEFAULT_ALPACA_TEMPLATE.format(fields[\"instruction\"], fields[\"input\"], fields[\"output\"])\n                    text += eos_token\n                    texts.append(text)\n                except KeyError as e:\n                    errors.append(f\"Custom template missing field: {e}\")\n                    texts.append(\"\")\n\n            return {\"text\": texts}\n\n        formatted_fn = _format_alpaca_custom\n\n        try:\n            dataset_map_kwargs = {\n                'batched': True,\n                'batch_size': batch_size,\n            }\n\n            if not isinstance(dataset, IterableDataset):\n                from utils.hardware import safe_num_proc\n                if num_proc is None or type(num_proc) is not int:\n                    num_proc = safe_num_proc()\n                else:\n                    num_proc = safe_num_proc(num_proc)\n                dataset_map_kwargs['num_proc'] = num_proc\n                dataset_map_kwargs['desc'] = \"Applying template to Alpaca format\"\n\n            formatted_dataset = dataset.map(formatted_fn, **dataset_map_kwargs)\n\n            return {\n                \"dataset\": formatted_dataset,\n                \"success\": True,\n                \"warnings\": warnings,\n                \"errors\": errors\n            }\n        except Exception as e:\n            errors.append(f\"Failed to format Alpaca dataset: {e}\")\n            return {\n                \"dataset\": dataset,\n                \"success\": False,\n                \"warnings\": warnings,\n                \"errors\": errors\n            }\n\n    # CHATML FORMATS\n    elif final_format in [\"chatml_messages\", \"chatml_conversations\"]:\n\n        if not is_standardized:\n            warnings.append(\"Dataset may not be fully standardized\")\n\n        # Apply Unsloth chat template if model matches\n        if model_name:\n            tokenizer = get_tokenizer_chat_template(tokenizer, model_name)\n\n        def _format_chatml(examples):\n            convos = examples[chat_column]\n            texts = []\n\n            for convo in convos:\n                try:\n                    text = tokenizer.apply_chat_template(\n                        convo,\n                        tokenize = False,\n                        add_generation_prompt = False\n                    )\n\n                    if remove_bos_prefix:\n                        text = text.removeprefix('<bos>')\n                    text += eos_token\n\n                    texts.append(text)\n                except Exception as e:\n                    if len(texts) == 0:\n                        warnings.append(f\"Chat template failed: {e}\")\n                    texts.append(\"\")\n\n            return {\"text\": texts}\n\n        
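# Streaming datasets (IterableDataset) do not accept num_proc or desc in .map(), so those kwargs are only added for map-style datasets below.\n        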
try:\n            dataset_map_kwargs = {\n                'batched': True,\n                'batch_size': batch_size,\n            }\n\n            if not isinstance(dataset, IterableDataset):\n                from utils.hardware import safe_num_proc\n                if num_proc is None or type(num_proc) is not int:\n                    num_proc = safe_num_proc()\n                else:\n                    num_proc = safe_num_proc(num_proc)\n                dataset_map_kwargs['num_proc'] = num_proc\n                dataset_map_kwargs['desc'] = f\"Applying chat template to {final_format}\"\n\n            # Monitor tqdm progress from dataset.map() and relay to callback\n            _tqdm_monitor_stop = None\n            if progress_callback and not isinstance(dataset, IterableDataset):\n                import threading\n                from tqdm.auto import tqdm as _tqdm_cls\n\n                _tqdm_monitor_stop = threading.Event()\n                _total = len(dataset) if hasattr(dataset, \"__len__\") else 0\n                _desc = f\"Applying chat template to {final_format}\"\n\n                def _poll_tqdm():\n                    while not _tqdm_monitor_stop.is_set():\n                        for bar in list(getattr(_tqdm_cls, \"_instances\", set())):\n                            try:\n                                n = bar.n or 0\n                                total = bar.total or _total\n                                if total > 0 and n > 0:\n                                    pct = min(int(n * 100 / total), 100)\n                                    progress_callback(\n                                        status_message = f\"{_desc}... {pct}% ({n:,}/{total:,})\"\n                                    )\n                            except (AttributeError, ReferenceError):\n                                pass\n                        _tqdm_monitor_stop.wait(3)\n\n                threading.Thread(target = _poll_tqdm, daemon = True).start()\n\n            formatted_dataset = dataset.map(_format_chatml, **dataset_map_kwargs)\n\n            if _tqdm_monitor_stop is not None:\n                _tqdm_monitor_stop.set()\n\n            return {\n                \"dataset\": formatted_dataset,\n                \"success\": True,\n                \"warnings\": warnings,\n                \"errors\": errors\n            }\n        except Exception as e:\n            errors.append(f\"Failed to format ChatML dataset: {e}\")\n            return {\n                \"dataset\": dataset,\n                \"success\": False,\n                \"warnings\": warnings,\n                \"errors\": errors\n            }\n\n    # UNKNOWN FORMAT\n    else:\n        errors.append(\n            f\"Cannot apply chat template to format: {final_format}. \"\n            f\"This should not happen after custom mapping.\"\n        )\n        return {\n            \"dataset\": dataset,\n            \"success\": False,\n            \"warnings\": warnings,\n            \"errors\": errors\n        }\n"
  },
  {
    "path": "studio/backend/utils/datasets/data_collators.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nData collators for dataset processing.\n\nThis module contains custom data collators for training,\nparticularly for VLM/OCR processing.\n\"\"\"\n\nimport torch\nfrom dataclasses import dataclass\nfrom typing import Any, List, Optional, Union\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n    \"\"\"\n    Data collator for Whisper speech-to-text training.\n\n    Pads input features (audio) and label sequences (text) separately,\n    masks padding in labels with -100, and strips leading BOS token.\n    Mirrors the collator from the Whisper.ipynb notebook.\n    \"\"\"\n\n    processor: Any\n\n    def __call__(self, features: List[dict]) -> dict:\n        input_features = [\n            {\"input_features\": feature[\"input_features\"]} for feature in features\n        ]\n        batch = self.processor.feature_extractor.pad(\n            input_features, return_tensors = \"pt\"\n        )\n\n        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors = \"pt\")\n\n        labels = labels_batch[\"input_ids\"].masked_fill(\n            labels_batch.attention_mask.ne(1), -100\n        )\n\n        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n            labels = labels[:, 1:]\n\n        batch[\"labels\"] = labels\n        return batch\n\n\n@dataclass\nclass DeepSeekOCRDataCollator:\n    \"\"\"\n    Data collator for DeepSeek OCR VLM training.\n\n    Handles:\n    - Image processing via processor\n    - Text tokenization\n    - Proper label masking for instruction fine-tuning\n    \"\"\"\n\n    processor: Any  # Qwen2VLProcessor or similar\n    max_length: int = 2048\n    ignore_index: int = -100\n\n    def __call__(self, batch: List[dict]) -> dict:\n        \"\"\"\n        Collate a batch of samples.\n\n        Args:\n            batch: List of dicts, each with 'messages' containing\n                   [{'role': 'user', 'content': [...]}, {'role': 'assistant', 'content': [...]}]\n\n        Returns:\n            dict with input_ids, attention_mask, labels, pixel_values, etc.\n        \"\"\"\n        from PIL import Image\n\n        # Extract messages and images\n        all_messages = []\n        all_images = []\n\n        for sample in batch:\n            messages = sample[\"messages\"]\n            all_messages.append(messages)\n\n            # Extract PIL images from content\n            for msg in messages:\n                content = msg.get(\"content\", [])\n                if isinstance(content, list):\n                    for item in content:\n                        if isinstance(item, dict) and item.get(\"type\") == \"image\":\n                            img = item.get(\"image\")\n                            if img is not None and hasattr(img, \"size\"):  # PIL Image\n                                all_images.append(img)\n\n        # Process with the VL processor\n        try:\n            # Qwen2VL style processing\n            texts = [\n                self.processor.apply_chat_template(\n                    msgs, tokenize = False, add_generation_prompt = False\n                )\n                for msgs in all_messages\n            ]\n\n            # Process with images\n            inputs = 
self.processor(\n                text = texts,\n                images = all_images if all_images else None,\n                return_tensors = \"pt\",\n                padding = True,\n                truncation = True,\n                max_length = self.max_length,\n            )\n\n            # Create labels (mask input, keep output)\n            labels = inputs[\"input_ids\"].clone()\n\n            # Simple masking: mask padding tokens\n            labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_index\n\n            inputs[\"labels\"] = labels\n\n            return inputs\n\n        except Exception as e:\n            logger.info(f\"⚠️ DeepSeekOCRDataCollator error: {e}\")\n            raise\n\n\n@dataclass\nclass VLMDataCollator:\n    \"\"\"\n    Generic VLM data collator that works with various processors.\n\n    Supports:\n    - Qwen2VL\n    - LLaVA\n    - Other VL models with compatible processors\n    \"\"\"\n\n    processor: Any\n    max_length: int = 2048\n    ignore_index: int = -100\n    mask_input_tokens: bool = True  # Whether to mask user tokens in labels\n\n    def __call__(self, batch: List[dict]) -> dict:\n        \"\"\"\n        Collate a batch of VLM samples.\n        \"\"\"\n        all_messages = []\n        all_images = []\n\n        for sample in batch:\n            messages = sample.get(\"messages\", [])\n            all_messages.append(messages)\n\n            # Extract images\n            for msg in messages:\n                content = msg.get(\"content\", [])\n                if isinstance(content, list):\n                    for item in content:\n                        if isinstance(item, dict):\n                            img = item.get(\"image\")\n                            if img is not None:\n                                all_images.append(img)\n\n        # Apply chat template\n        texts = [\n            self.processor.apply_chat_template(\n                msgs, tokenize = False, add_generation_prompt = False\n            )\n            for msgs in all_messages\n        ]\n\n        # Process inputs\n        inputs = self.processor(\n            text = texts,\n            images = all_images if all_images else None,\n            return_tensors = \"pt\",\n            padding = True,\n            truncation = True,\n            max_length = self.max_length,\n        )\n\n        # Create labels\n        labels = inputs[\"input_ids\"].clone()\n\n        # Mask padding\n        if hasattr(self.processor, \"tokenizer\"):\n            pad_token_id = self.processor.tokenizer.pad_token_id\n        else:\n            pad_token_id = self.processor.pad_token_id\n\n        if pad_token_id is not None:\n            labels[labels == pad_token_id] = self.ignore_index\n\n        inputs[\"labels\"] = labels\n\n        return inputs\n"
  },
  {
    "path": "studio/backend/utils/datasets/dataset_utils.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nDataset utilities for format detection, conversion, and template application.\n\nThis module provides the main entry points for dataset processing:\n- check_dataset_format: Lightweight check if manual mapping is needed (for frontend)\n- format_dataset: Detects and normalizes dataset formats\n- format_and_template_dataset: End-to-end processing with chat template application\n\nAll internal utilities have been moved to separate modules:\n- format_detection: detect_dataset_format, detect_multimodal_dataset, etc.\n- format_conversion: standardize_chat_format, convert_chatml_to_alpaca, etc.\n- chat_templates: apply_chat_template_to_dataset, get_tokenizer_chat_template, etc.\n- vlm_processing: generate_smart_vlm_instruction\n- data_collators: DeepSeekOCRDataCollator, VLMDataCollator\n- model_mappings: TEMPLATE_TO_MODEL_MAPPER\n\"\"\"\n\nimport json\n\n# Import from modular files\nfrom .format_detection import (\n    detect_dataset_format,\n    detect_multimodal_dataset,\n    detect_vlm_dataset_structure,\n    detect_custom_format_heuristic,\n)\nfrom .format_conversion import (\n    standardize_chat_format,\n    convert_chatml_to_alpaca,\n    convert_alpaca_to_chatml,\n    convert_to_vlm_format,\n    convert_llava_to_vlm_format,\n    convert_sharegpt_with_images_to_vlm_format,\n)\nfrom .chat_templates import (\n    apply_chat_template_to_dataset,\n    get_dataset_info_summary,\n    get_tokenizer_chat_template,\n    DEFAULT_ALPACA_TEMPLATE,\n)\nfrom .vlm_processing import generate_smart_vlm_instruction\nfrom .data_collators import DeepSeekOCRDataCollator, VLMDataCollator\nfrom .model_mappings import TEMPLATE_TO_MODEL_MAPPER\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef check_dataset_format(dataset, is_vlm: bool = False) -> dict:\n    \"\"\"\n    Lightweight format check without processing - for frontend validation.\n\n    Use this to quickly determine if user needs to manually map columns\n    before calling the full format_and_template_dataset().\n\n    Args:\n        dataset: HuggingFace dataset\n        is_vlm: Whether this is a Vision-Language Model dataset\n\n    Returns:\n        dict: {\n            \"requires_manual_mapping\": bool - True if user must map columns,\n            \"detected_format\": str - The detected format,\n            \"columns\": list - Available column names for mapping UI,\n            \"suggested_mapping\": dict or None - Auto-detected mapping if available,\n            \"detected_image_column\": str or None - For VLM only,\n            \"detected_text_column\": str or None - For VLM only,\n        }\n    \"\"\"\n    columns = (\n        list(dataset.column_names)\n        if hasattr(dataset, \"column_names\")\n        else list(next(iter(dataset)).keys())\n    )\n\n    # Auto-detect multimodal data regardless of is_vlm flag\n    multimodal_info = detect_multimodal_dataset(dataset)\n    is_audio = multimodal_info.get(\"is_audio\", False)\n\n    # Common audio fields for all return paths\n    audio_fields = {\n        \"is_audio\": is_audio,\n        \"detected_audio_column\": multimodal_info.get(\"detected_audio_column\"),\n        \"detected_speaker_column\": multimodal_info.get(\"detected_speaker_column\"),\n    }\n\n    if is_vlm:\n        vlm_structure = detect_vlm_dataset_structure(dataset)\n        requires_mapping = vlm_structure[\"format\"] == \"unknown\"\n\n      
  warning = None\n        if requires_mapping:\n            img_col = vlm_structure.get(\"image_column\")\n            txt_col = vlm_structure.get(\"text_column\")\n            missing = []\n            if not img_col:\n                missing.append(\"image\")\n            if not txt_col:\n                missing.append(\"text\")\n            if missing:\n                warning = (\n                    f\"Could not auto-detect {' or '.join(missing)} column. \"\n                    \"Please assign image and text columns manually.\"\n                )\n\n        return {\n            \"requires_manual_mapping\": requires_mapping,\n            \"detected_format\": vlm_structure[\"format\"],\n            \"columns\": columns,\n            \"suggested_mapping\": None,\n            \"detected_image_column\": vlm_structure.get(\"image_column\"),\n            \"detected_text_column\": vlm_structure.get(\"text_column\"),\n            \"is_image\": multimodal_info[\"is_image\"],\n            \"multimodal_columns\": multimodal_info.get(\"multimodal_columns\"),\n            \"warning\": warning,\n            **audio_fields,\n        }\n\n    if is_audio:\n        # Audio dataset — require manual mapping only when columns can't be auto-detected\n        detected_audio = multimodal_info.get(\"detected_audio_column\")\n        detected_text = multimodal_info.get(\"detected_text_column\")\n        needs_mapping = not detected_audio or not detected_text\n        return {\n            \"requires_manual_mapping\": needs_mapping,\n            \"detected_format\": \"audio\",\n            \"columns\": columns,\n            \"suggested_mapping\": None,\n            \"detected_image_column\": None,\n            \"detected_text_column\": multimodal_info.get(\"detected_text_column\"),\n            \"is_image\": False,\n            \"multimodal_columns\": multimodal_info.get(\"audio_columns\"),\n            **audio_fields,\n        }\n\n    # Text / LLM flow\n    detected = detect_dataset_format(dataset)\n\n    # If format is unknown, try heuristic detection\n    if detected[\"format\"] == \"unknown\":\n        heuristic_mapping = detect_custom_format_heuristic(dataset)\n        if heuristic_mapping:\n            return {\n                \"requires_manual_mapping\": False,\n                \"detected_format\": \"custom_heuristic\",\n                \"columns\": columns,\n                \"suggested_mapping\": heuristic_mapping,\n                \"detected_image_column\": None,\n                \"detected_text_column\": None,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_columns\": multimodal_info.get(\"multimodal_columns\"),\n                **audio_fields,\n            }\n        else:\n            # Heuristic failed — user must map manually (or use AI Assist)\n            return {\n                \"requires_manual_mapping\": True,\n                \"detected_format\": \"unknown\",\n                \"columns\": columns,\n                \"suggested_mapping\": None,\n                \"detected_image_column\": None,\n                \"detected_text_column\": None,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_columns\": multimodal_info.get(\"multimodal_columns\"),\n                \"warning\": (\n                    f\"Could not auto-detect column roles for columns: {columns}. 
\"\n                    \"Please assign roles manually, or use AI Assist.\"\n                ),\n                **audio_fields,\n            }\n\n    # Known format detected\n    return {\n        \"requires_manual_mapping\": False,\n        \"detected_format\": detected[\"format\"],\n        \"columns\": columns,\n        \"suggested_mapping\": None,\n        \"detected_image_column\": None,\n        \"detected_text_column\": None,\n        \"is_image\": multimodal_info[\"is_image\"],\n        \"multimodal_columns\": multimodal_info.get(\"multimodal_columns\"),\n        **audio_fields,\n    }\n\n\n# Normalise any format-specific role to canonical chatml (user/assistant/system)\n_TO_CHATML = {\n    \"user\": \"user\",\n    \"human\": \"user\",\n    \"instruction\": \"user\",\n    \"assistant\": \"assistant\",\n    \"gpt\": \"assistant\",\n    \"output\": \"assistant\",\n    \"system\": \"system\",\n    \"input\": \"system\",\n}\n_CHATML_ROLE_ORDER = (\"system\", \"user\", \"assistant\")\n_CHATML_TO_ALPACA = {\"user\": \"instruction\", \"system\": \"input\", \"assistant\": \"output\"}\n\n\ndef _apply_user_mapping(dataset, mapping: dict, batch_size: int = 1000):\n    \"\"\"\n    Apply user-provided column mapping to convert dataset to conversations format.\n\n    Accepts chatml (user/assistant/system), sharegpt (human/gpt/system), and\n    alpaca (instruction/input/output) role names — all normalised to chatml output.\n\n    If the mapping contains ``__``-prefixed metadata keys (from the conversion\n    advisor), routes to template-based conversion instead of simple role mapping.\n\n    Returns:\n        Dataset with single 'conversations' column\n    \"\"\"\n    # Split metadata from column roles\n    meta = {k: v for k, v in mapping.items() if k.startswith(\"__\")}\n    column_roles = {k: v for k, v in mapping.items() if not k.startswith(\"__\")}\n\n    if meta:\n        return _apply_template_mapping(dataset, column_roles, meta, batch_size)\n\n    # ── Simple mode (original logic) ──\n    # Pre-compute: group columns by canonical chatml role\n    role_groups: dict[str, list[str]] = {r: [] for r in _CHATML_ROLE_ORDER}\n    for col_name, role in column_roles.items():\n        canonical = _TO_CHATML.get(role)\n        if canonical:\n            role_groups[canonical].append(col_name)\n\n    def _convert(examples):\n        num = len(next(iter(examples.values())))\n        conversations = []\n        for i in range(num):\n            convo = []\n            for chatml_role in _CHATML_ROLE_ORDER:\n                for col in role_groups[chatml_role]:\n                    if col in examples:\n                        content = examples[col][i]\n                        convo.append(\n                            {\n                                \"role\": chatml_role,\n                                \"content\": str(content) if content else \"\",\n                            }\n                        )\n            conversations.append(convo)\n        return {\"conversations\": conversations}\n\n    return dataset.map(\n        _convert,\n        batched = True,\n        batch_size = batch_size,\n        remove_columns = dataset.column_names,\n    )\n\n\ndef _extract_column_value(val, col: str, label_mapping: dict) -> str:\n    \"\"\"Extract a string value from a column, handling complex types and label mapping.\"\"\"\n    # Handle complex types (dicts, lists) — extract useful text instead of raw repr\n    if isinstance(val, dict):\n        # Common pattern: {\"text\": [...]} in QA datasets\n   
     if \"text\" in val:\n            inner = val[\"text\"]\n            str_val = inner[0] if isinstance(inner, list) and inner else str(inner)\n        else:\n            str_val = json.dumps(val, ensure_ascii = False)\n    elif isinstance(val, list):\n        str_val = val[0] if len(val) == 1 else \", \".join(str(v) for v in val)\n    else:\n        str_val = str(val) if val is not None else \"\"\n\n    # Apply label mapping if this column has one\n    if col in label_mapping and isinstance(label_mapping[col], dict):\n        str_val = label_mapping[col].get(str_val, str_val)\n\n    return str_val\n\n\ndef _apply_template_mapping(\n    dataset, column_roles: dict, meta: dict, batch_size: int = 1000\n):\n    \"\"\"\n    Apply advisor-driven mapping for non-conversational datasets.\n\n    Groups columns by their assigned role (user/assistant), concatenates\n    values within each role into a single message, and injects an optional\n    system prompt.  Label mapping is applied to convert integer labels\n    to human-readable strings.\n\n    Returns:\n        Dataset with single 'conversations' column\n    \"\"\"\n    system_prompt = meta.get(\"__system_prompt\", \"\")\n    label_mapping = meta.get(\"__label_mapping\", {})  # {col: {int_str: label_str}}\n\n    # Group columns by canonical chatml role\n    role_groups: dict[str, list[str]] = {\"user\": [], \"assistant\": []}\n    for col, role in column_roles.items():\n        canonical = _TO_CHATML.get(role, role)\n        if canonical in role_groups:\n            role_groups[canonical].append(col)\n\n    import logging as _log\n\n    _log.getLogger(__name__).info(\n        f\"Applying role mapping: sys={bool(system_prompt)}, \"\n        f\"user_cols={role_groups['user']}, asst_cols={role_groups['assistant']}, \"\n        f\"label_map={list(label_mapping.keys())}\"\n    )\n\n    def _convert(examples):\n        num = len(next(iter(examples.values())))\n        conversations = []\n        for i in range(num):\n            convo = []\n\n            # System prompt (generated, static across all rows)\n            if system_prompt:\n                convo.append({\"role\": \"system\", \"content\": system_prompt})\n\n            # User message: concatenate all user-role column values\n            user_parts = []\n            for col in role_groups[\"user\"]:\n                if col in examples:\n                    user_parts.append(\n                        _extract_column_value(examples[col][i], col, label_mapping)\n                    )\n            if user_parts:\n                convo.append({\"role\": \"user\", \"content\": \"\\n\".join(user_parts)})\n\n            # Assistant message: concatenate all assistant-role column values\n            asst_parts = []\n            for col in role_groups[\"assistant\"]:\n                if col in examples:\n                    asst_parts.append(\n                        _extract_column_value(examples[col][i], col, label_mapping)\n                    )\n            if asst_parts:\n                convo.append({\"role\": \"assistant\", \"content\": \"\\n\".join(asst_parts)})\n\n            conversations.append(convo)\n        return {\"conversations\": conversations}\n\n    return dataset.map(\n        _convert,\n        batched = True,\n        batch_size = batch_size,\n        remove_columns = dataset.column_names,\n    )\n\n\ndef _apply_user_mapping_alpaca(dataset, mapping: dict, batch_size: int = 1000):\n    \"\"\"\n    Apply user-provided column mapping to convert dataset to Alpaca format.\n\n    
Accepts any format's role names — normalises via _TO_CHATML, then maps\n    user → instruction, system → input, assistant → output.\n\n    Returns:\n        Dataset with instruction/input/output columns\n    \"\"\"\n    col_for: dict[str, str | None] = {\n        \"instruction\": None,\n        \"input\": None,\n        \"output\": None,\n    }\n    for col_name, role in mapping.items():\n        canonical = _TO_CHATML.get(role)\n        alpaca_field = _CHATML_TO_ALPACA.get(canonical) if canonical else None\n        if alpaca_field:\n            col_for[alpaca_field] = col_name\n\n    def _convert(examples):\n        num = len(next(iter(examples.values())))\n        instructions, inputs, outputs = [], [], []\n        for i in range(num):\n            for field, dest in (\n                (\"instruction\", instructions),\n                (\"input\", inputs),\n                (\"output\", outputs),\n            ):\n                col = col_for[field]\n                val = (\n                    str(examples[col][i])\n                    if col and col in examples and examples[col][i]\n                    else \"\"\n                )\n                dest.append(val)\n        return {\"instruction\": instructions, \"input\": inputs, \"output\": outputs}\n\n    return dataset.map(\n        _convert,\n        batched = True,\n        batch_size = batch_size,\n        remove_columns = dataset.column_names,\n    )\n\n\ndef format_dataset(\n    dataset,\n    format_type = \"auto\",\n    tokenizer = None,\n    aliases_for_system = [\n        \"system\",\n    ],\n    aliases_for_user = [\n        \"user\",\n        \"human\",\n        \"input\",\n    ],\n    aliases_for_assistant = [\n        \"gpt\",\n        \"assistant\",\n        \"output\",\n    ],\n    batch_size = 1000,\n    num_proc = None,\n    auto_detect_custom = True,\n    custom_format_mapping = None,\n):\n    \"\"\"\n    Formats dataset and returns metadata.\n\n    Returns:\n        dict: {\n            \"dataset\": processed dataset,\n            \"detected_format\": original format detected,\n            \"final_format\": final format after processing,\n            \"chat_column\": column name with chat data,\n            \"is_standardized\": whether role names are standardized,\n            \"requires_manual_mapping\": True if format detection failed and user must map columns,\n            \"warnings\": list of warning messages\n        }\n    \"\"\"\n\n    # Detect multimodal first (needed for all flows)\n    multimodal_info = detect_multimodal_dataset(dataset)\n\n    # If user provided explicit mapping, skip detection and apply in the requested format\n    if custom_format_mapping:\n        try:\n            if format_type == \"alpaca\":\n                mapped_dataset = _apply_user_mapping_alpaca(\n                    dataset, custom_format_mapping, batch_size\n                )\n                final_format = \"alpaca\"\n                chat_column = None\n            else:\n                # auto / chatml / sharegpt / conversational — all produce chatml conversations\n                # (sharegpt is always standardized to role/content internally)\n                mapped_dataset = _apply_user_mapping(\n                    dataset, custom_format_mapping, batch_size\n                )\n                final_format = \"chatml_conversations\"\n                chat_column = \"conversations\"\n\n            return {\n                \"dataset\": mapped_dataset,\n                \"detected_format\": \"user_mapped\",\n                
\"final_format\": final_format,\n                \"chat_column\": chat_column,\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [\n                    f\"Applied user-provided column mapping ({format_type}): {custom_format_mapping}\"\n                ],\n            }\n        except Exception as e:\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"user_mapped\",\n                \"final_format\": \"unknown\",\n                \"chat_column\": None,\n                \"is_standardized\": False,\n                \"requires_manual_mapping\": True,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [f\"Failed to apply user mapping: {e}\"],\n            }\n\n    # Detect current format\n    detected = detect_dataset_format(dataset)\n    warnings = []\n\n    # Add multimodal warning if detected\n    if multimodal_info[\"is_image\"]:\n        warnings.append(\n            f\"Multimodal dataset detected. Found columns: {multimodal_info['multimodal_columns']}\"\n        )\n\n    # AUTO MODE: Keep format but standardize if needed\n    if format_type == \"auto\":\n        # Alpaca - keep as is\n        if detected[\"format\"] == \"alpaca\":\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"alpaca\",\n                \"final_format\": \"alpaca\",\n                \"chat_column\": None,\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [],\n            }\n\n        # ShareGPT - needs standardization\n        elif detected[\"format\"] == \"sharegpt\":\n            try:\n                standardized = standardize_chat_format(\n                    dataset,\n                    tokenizer,\n                    aliases_for_system,\n                    aliases_for_user,\n                    aliases_for_assistant,\n                    batch_size,\n                    num_proc,\n                )\n                return {\n                    \"dataset\": standardized,\n                    \"detected_format\": \"sharegpt\",\n                    \"final_format\": f\"chatml_{detected['chat_column']}\",\n                    \"chat_column\": detected[\"chat_column\"],\n                    \"is_standardized\": True,\n                    \"requires_manual_mapping\": False,\n                    \"is_image\": multimodal_info[\"is_image\"],\n                    \"multimodal_info\": multimodal_info,\n                    \"warnings\": [],\n                }\n            except Exception as e:\n                warnings.append(f\"Failed to standardize ShareGPT format: {e}\")\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": \"sharegpt\",\n                    \"final_format\": \"sharegpt\",\n                    \"chat_column\": detected[\"chat_column\"],\n                    \"is_standardized\": False,\n                    \"requires_manual_mapping\": True,\n                    \"is_image\": multimodal_info[\"is_image\"],\n                    \"multimodal_info\": multimodal_info,\n      
              \"warnings\": warnings,\n                }\n\n        elif detected[\"format\"] == \"chatml\" and detected[\"chat_column\"] in [\n            \"conversations\",\n            \"messages\",\n            \"texts\",\n        ]:\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": f\"chatml_{detected['chat_column']}\",\n                \"final_format\": f\"chatml_{detected['chat_column']}\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": warnings,\n            }\n\n        # Unknown - try standardization, if fails pass as is\n        else:\n            warnings.append(\n                f\"Unknown format detected. Keys found: {detected['sample_keys']}\"\n            )\n\n            # NEW: Try heuristic detection\n            if auto_detect_custom:\n                custom_mapping = detect_custom_format_heuristic(dataset)\n                if custom_mapping:\n                    warnings.append(f\"Auto-detected column mapping: {custom_mapping}\")\n\n                    def _apply_auto_mapping(examples):\n                        conversations = []\n                        num_examples = len(examples[list(examples.keys())[0]])\n\n                        # Preserve non-mapped columns\n                        all_columns = set(examples.keys())\n                        mapped_columns = set(custom_mapping.keys())\n                        preserved_columns = {\n                            col: examples[col] for col in all_columns - mapped_columns\n                        }\n\n                        for i in range(num_examples):\n                            convo = []\n                            for target_role in [\"system\", \"user\", \"assistant\"]:\n                                for col_name, role in custom_mapping.items():\n                                    if role == target_role and col_name in examples:\n                                        content = examples[col_name][i]\n                                        if content and str(content).strip():\n                                            convo.append(\n                                                {\"role\": role, \"content\": str(content)}\n                                            )\n                            conversations.append(convo)\n\n                        return {\"conversations\": conversations, **preserved_columns}\n\n                    try:\n                        dataset = dataset.map(\n                            _apply_auto_mapping, batched = True, batch_size = batch_size\n                        )\n                        return {\n                            \"dataset\": dataset,\n                            \"detected_format\": \"unknown\",\n                            \"final_format\": \"chatml_conversations\",\n                            \"chat_column\": \"conversations\",\n                            \"is_standardized\": True,\n                            \"requires_manual_mapping\": False,\n                            \"is_image\": multimodal_info[\"is_image\"],\n                            \"multimodal_info\": multimodal_info,\n                            \"warnings\": warnings,\n                        }\n                    except Exception as e:\n                        
warnings.append(f\"Auto-detection failed: {e}\")\n\n            # Try standardization as a last resort\n            if detected[\"chat_column\"]:\n                try:\n                    standardized = standardize_chat_format(\n                        dataset,\n                        tokenizer,\n                        aliases_for_system,\n                        aliases_for_user,\n                        aliases_for_assistant,\n                        batch_size,\n                        num_proc,\n                    )\n                    warnings.append(\"Successfully standardized unknown format\")\n                    return {\n                        \"dataset\": standardized,\n                        \"detected_format\": \"unknown\",\n                        \"final_format\": f\"chatml_{detected['chat_column']}\",\n                        \"chat_column\": detected[\"chat_column\"],\n                        \"is_standardized\": True,\n                        \"requires_manual_mapping\": False,\n                        \"is_image\": multimodal_info[\"is_image\"],\n                        \"multimodal_info\": multimodal_info,\n                        \"warnings\": warnings,\n                    }\n                except Exception as e:\n                    warnings.append(\n                        f\"Could not standardize: {e}. Passing dataset as-is.\"\n                    )\n\n            # Return as-is with warnings\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"unknown\",\n                \"final_format\": \"unknown\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": False,\n                \"requires_manual_mapping\": True,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": warnings,\n            }\n\n    # ALPACA MODE: Convert to Alpaca\n    elif format_type == \"alpaca\":\n        if detected[\"format\"] == \"alpaca\":\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"alpaca\",\n                \"final_format\": \"alpaca\",\n                \"chat_column\": None,\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [],\n            }\n\n        elif detected[\"format\"] in [\"sharegpt\", \"chatml\"]:\n            # First standardize if ShareGPT\n            if detected[\"format\"] == \"sharegpt\":\n                dataset = standardize_chat_format(\n                    dataset,\n                    tokenizer,\n                    aliases_for_system,\n                    aliases_for_user,\n                    aliases_for_assistant,\n                    batch_size,\n                    num_proc,\n                )\n\n            # Then convert to Alpaca\n            converted = convert_chatml_to_alpaca(dataset, batch_size, num_proc)\n            return {\n                \"dataset\": converted,\n                \"detected_format\": detected[\"format\"],\n                \"final_format\": \"alpaca\",\n                \"chat_column\": None,\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": 
multimodal_info,\n                \"warnings\": [],\n            }\n\n        else:\n            warnings.append(f\"Cannot convert unknown format to Alpaca\")\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"unknown\",\n                \"final_format\": \"unknown\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": False,\n                \"requires_manual_mapping\": True,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": warnings,\n            }\n\n    # CHATML MODE: Convert to ChatML\n    elif format_type in [\"chatml\", \"conversational\", \"sharegpt\"]:\n        if detected[\"format\"] == \"alpaca\":\n            converted = convert_alpaca_to_chatml(dataset, batch_size, num_proc)\n            return {\n                \"dataset\": converted,\n                \"detected_format\": \"alpaca\",\n                \"final_format\": \"chatml_conversations\",\n                \"chat_column\": \"conversations\",\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [],\n            }\n\n        elif detected[\"format\"] == \"sharegpt\":\n            standardized = standardize_chat_format(\n                dataset,\n                tokenizer,\n                aliases_for_system,\n                aliases_for_user,\n                aliases_for_assistant,\n                batch_size,\n                num_proc,\n            )\n            return {\n                \"dataset\": standardized,\n                \"detected_format\": \"sharegpt\",\n                \"final_format\": f\"chatml_{detected['chat_column']}\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [],\n            }\n\n        elif detected[\"format\"] == \"chatml\":\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": f\"chatml_{detected['chat_column']}\",\n                \"final_format\": f\"chatml_{detected['chat_column']}\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": True,\n                \"requires_manual_mapping\": False,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": [],\n            }\n\n        else:\n            warnings.append(f\"Unknown format, attempting standardization\")\n            if detected[\"chat_column\"]:\n                try:\n                    standardized = standardize_chat_format(\n                        dataset,\n                        tokenizer,\n                        aliases_for_system,\n                        aliases_for_user,\n                        aliases_for_assistant,\n                        batch_size,\n                        num_proc,\n                    )\n                    return {\n                        \"dataset\": standardized,\n                        \"detected_format\": \"unknown\",\n                        \"final_format\": 
f\"chatml_{detected['chat_column']}\",\n                        \"chat_column\": detected[\"chat_column\"],\n                        \"is_standardized\": True,\n                        \"requires_manual_mapping\": False,\n                        \"is_image\": multimodal_info[\"is_image\"],\n                        \"multimodal_info\": multimodal_info,\n                        \"warnings\": warnings,\n                    }\n                except Exception as e:\n                    warnings.append(f\"Standardization failed: {e}\")\n\n            return {\n                \"dataset\": dataset,\n                \"detected_format\": \"unknown\",\n                \"final_format\": \"unknown\",\n                \"chat_column\": detected[\"chat_column\"],\n                \"is_standardized\": False,\n                \"requires_manual_mapping\": True,\n                \"is_image\": multimodal_info[\"is_image\"],\n                \"multimodal_info\": multimodal_info,\n                \"warnings\": warnings,\n            }\n\n    else:\n        raise ValueError(f\"Unknown format_type: {format_type}\")\n\n\ndef format_and_template_dataset(\n    dataset,\n    model_name,\n    tokenizer,\n    is_vlm = False,\n    format_type = \"auto\",\n    # VLM-specific parameters\n    vlm_instruction = None,  # Now optional - will auto-generate\n    vlm_text_column = None,\n    vlm_image_column = None,\n    dataset_name = None,\n    custom_prompt_template = None,\n    add_eos_token = False,\n    remove_bos_prefix = False,\n    custom_format_mapping = None,\n    auto_detect_custom = True,\n    auto_detect_mapping = True,\n    aliases_for_system = [\n        \"system\",\n    ],\n    aliases_for_user = [\n        \"user\",\n        \"human\",\n        \"input\",\n    ],\n    aliases_for_assistant = [\n        \"gpt\",\n        \"assistant\",\n        \"output\",\n    ],\n    batch_size = 1000,\n    num_proc = None,\n    progress_callback = None,\n):\n    \"\"\"\n    Convenience function that combines format_dataset and apply_chat_template_to_dataset.\n    Perfect for UI workflows - one function does everything!\n\n    Returns:\n        dict: {\n            \"dataset\": Final dataset with 'text' column,\n            \"detected_format\": Original format,\n            \"final_format\": Format after processing,\n            \"success\": Whether template application succeeded,\n            \"requires_manual_mapping\": True if format detection failed and user must map columns,\n            \"warnings\": List of warnings,\n            \"errors\": List of errors,\n            \"summary\": Human-readable summary\n        }\n    \"\"\"\n\n    # VLM FLOW\n    if is_vlm:\n        warnings = []\n        errors = []\n\n        multimodal_info = detect_multimodal_dataset(dataset)\n\n        # NEW: If user provided explicit mapping for VLM, use it directly\n        if custom_format_mapping:\n            # Expect mapping like: {\"image_col\": \"image\", \"caption_col\": \"text\"}\n            user_vlm_image_column = None\n            user_vlm_text_column = None\n\n            for col, role in custom_format_mapping.items():\n                if role == \"image\":\n                    user_vlm_image_column = col\n                elif role in [\"text\", \"user\", \"caption\", \"assistant\"]:\n                    user_vlm_text_column = col\n\n            if user_vlm_image_column and user_vlm_text_column:\n                try:\n                    dataset = convert_to_vlm_format(\n                        dataset,\n                        
instruction = vlm_instruction,\n                        text_column = user_vlm_text_column,\n                        image_column = user_vlm_image_column,\n                        dataset_name = dataset_name,\n                        progress_callback = progress_callback,\n                    )\n                    warnings.append(\n                        f\"Applied user VLM mapping: image='{user_vlm_image_column}', text='{user_vlm_text_column}'\"\n                    )\n\n                    return {\n                        \"dataset\": dataset,\n                        \"detected_format\": \"user_mapped\",\n                        \"final_format\": \"vlm_messages\",\n                        \"chat_column\": \"messages\",\n                        \"is_vlm\": True,\n                        \"is_image\": True,\n                        \"multimodal_info\": multimodal_info,\n                        \"success\": True,\n                        \"requires_manual_mapping\": False,\n                        \"warnings\": warnings,\n                        \"errors\": [],\n                    }\n                except Exception as e:\n                    # User mapping failed — fall back to auto-detection instead\n                    # of giving up (handles stale cached mappings gracefully)\n                    warnings.append(\n                        f\"User VLM mapping (image='{user_vlm_image_column}', \"\n                        f\"text='{user_vlm_text_column}') failed: {e} — \"\n                        f\"falling back to auto-detection\"\n                    )\n                    logger.info(\n                        f\"⚠️ User VLM mapping failed, falling back to auto-detection...\"\n                    )\n                    custom_format_mapping = None  # clear so auto-detection runs below\n            else:\n                errors.append(\n                    f\"Invalid VLM mapping: need 'image' and 'text' roles. Got: {custom_format_mapping}\"\n                )\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": \"user_mapped\",\n                    \"final_format\": \"vlm_unknown\",\n                    \"is_vlm\": True,\n                    \"success\": False,\n                    \"requires_manual_mapping\": True,\n                    \"warnings\": warnings,\n                    \"errors\": errors,\n                }\n\n        # Auto-detect VLM structure\n        vlm_structure = detect_vlm_dataset_structure(dataset)\n\n        # Handle Llava format\n        if vlm_structure[\"format\"] == \"vlm_messages_llava\":\n            try:\n                dataset = convert_llava_to_vlm_format(dataset)\n                warnings.append(\n                    \"Converted from Llava format (image indices) to standard VLM format\"\n                )\n            except Exception as e:\n                errors.append(f\"Failed to convert Llava format: {e}\")\n                import traceback\n\n                traceback.print_exc()\n\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": \"vlm_messages_llava\",\n                    \"final_format\": \"vlm_conversion_failed\",\n                    \"is_vlm\": True,\n                    \"success\": False,\n                    \"requires_manual_mapping\": True,\n                    \"warnings\": warnings,\n                    \"errors\": errors,\n                }\n\n        # Handle ShareGPT/ChatML + image column (e.g. 
ShareGPT4V, LLaVA-style)\n        elif vlm_structure[\"format\"] == \"sharegpt_with_images\":\n            try:\n                dataset = convert_sharegpt_with_images_to_vlm_format(\n                    dataset,\n                    image_column = vlm_structure[\"image_column\"],\n                    messages_column = vlm_structure[\"messages_column\"],\n                    dataset_name = dataset_name,\n                    progress_callback = progress_callback,\n                )\n                warnings.append(\n                    \"Converted from ShareGPT+image format to standard VLM format\"\n                )\n            except Exception as e:\n                errors.append(f\"Failed to convert ShareGPT+image format: {e}\")\n                import traceback\n\n                traceback.print_exc()\n\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": \"sharegpt_with_images\",\n                    \"final_format\": \"vlm_conversion_failed\",\n                    \"is_vlm\": True,\n                    \"success\": False,\n                    \"requires_manual_mapping\": True,\n                    \"warnings\": warnings,\n                    \"errors\": errors,\n                }\n\n        # Handle simple format\n        elif vlm_structure[\"needs_conversion\"]:\n            if vlm_text_column is None:\n                vlm_text_column = vlm_structure[\"text_column\"]\n            if vlm_image_column is None:\n                vlm_image_column = vlm_structure[\"image_column\"]\n\n            if vlm_text_column is None or vlm_image_column is None:\n                columns = list(next(iter(dataset)).keys()) if dataset else []\n                issues = [\n                    f\"Could not auto-detect image and text columns from: {columns}\",\n                    f\"VLM structure detected: {vlm_structure.get('format', 'unknown')}\",\n                ]\n                friendly = None\n                try:\n                    from .llm_assist import llm_generate_dataset_warning\n\n                    friendly = llm_generate_dataset_warning(\n                        issues,\n                        dataset_name = dataset_name,\n                        modality = \"vision\",\n                        column_names = columns,\n                    )\n                except Exception:\n                    pass\n                errors.append(\n                    friendly\n                    or f\"Could not auto-detect image/text columns. Found: {vlm_structure}. 
\"\n                )\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": \"vlm_unknown\",\n                    \"final_format\": \"vlm_unknown\",\n                    \"is_vlm\": True,\n                    \"success\": False,\n                    \"requires_manual_mapping\": True,\n                    \"warnings\": warnings,\n                    \"errors\": errors,\n                }\n\n            try:\n                dataset = convert_to_vlm_format(\n                    dataset,\n                    instruction = vlm_instruction,\n                    text_column = vlm_text_column,\n                    image_column = vlm_image_column,\n                    dataset_name = dataset_name,\n                    progress_callback = progress_callback,\n                )\n\n                if vlm_instruction:\n                    warnings.append(\n                        f\"Using user-provided instruction: '{vlm_instruction}'\"\n                    )\n                else:\n                    warnings.append(\n                        \"Auto-generated instruction based on dataset analysis\"\n                    )\n\n            except Exception as e:\n                errors.append(f\"Failed to convert to VLM format: {e}\")\n                import traceback\n\n                traceback.print_exc()\n\n                return {\n                    \"dataset\": dataset,\n                    \"detected_format\": vlm_structure[\"format\"],\n                    \"final_format\": \"vlm_conversion_failed\",\n                    \"is_vlm\": True,\n                    \"success\": False,\n                    \"requires_manual_mapping\": True,\n                    \"warnings\": warnings,\n                    \"errors\": errors,\n                }\n\n        # Already in standard VLM format\n        elif vlm_structure[\"format\"] == \"vlm_messages\":\n            dataset = [sample for sample in dataset]\n            warnings.append(\"Dataset already in standard VLM messages format\")\n\n        # Return as list\n        return {\n            \"dataset\": dataset,\n            \"detected_format\": vlm_structure[\"format\"],\n            \"final_format\": \"vlm_messages\",\n            \"chat_column\": \"messages\",\n            \"is_vlm\": True,\n            \"is_image\": multimodal_info[\"is_image\"],\n            \"multimodal_info\": multimodal_info,\n            \"vlm_structure\": vlm_structure,\n            \"success\": True,\n            \"requires_manual_mapping\": False,\n            \"warnings\": warnings,\n            \"errors\": errors,\n        }\n\n    # LLM FLOW (Existing code)\n    else:\n        # Step 1: Format the dataset\n        n_rows = len(dataset) if hasattr(dataset, \"__len__\") else None\n        if progress_callback and n_rows:\n            progress_callback(status_message = f\"Formatting dataset ({n_rows:,} rows)...\")\n        dataset_info = format_dataset(\n            dataset,\n            format_type = format_type,\n            tokenizer = tokenizer,\n            auto_detect_custom = auto_detect_custom,\n            custom_format_mapping = custom_format_mapping,\n            aliases_for_system = aliases_for_system,\n            aliases_for_user = aliases_for_user,\n            aliases_for_assistant = aliases_for_assistant,\n            batch_size = batch_size,\n            num_proc = num_proc,\n        )\n\n        # Step 2: Apply chat template\n        detected = dataset_info.get(\"detected_format\", \"unknown\")\n        
if progress_callback and n_rows:\n            progress_callback(\n                status_message = f\"Applying chat template to {detected} ({n_rows:,} rows)...\"\n            )\n        # Gemma emits a leading <bos> that must be stripped for text-only chatml/sharegpt.\n        is_alpaca = format_type == \"alpaca\" or (\n            format_type == \"auto\" and dataset_info[\"detected_format\"] == \"alpaca\"\n        )\n        is_gemma = \"gemma\" in model_name.lower()\n        if is_gemma and not dataset_info[\"is_image\"] and not is_alpaca:\n            remove_bos_prefix = True\n        template_result = apply_chat_template_to_dataset(\n            dataset_info = dataset_info,\n            tokenizer = tokenizer,\n            model_name = model_name,\n            custom_prompt_template = custom_prompt_template,\n            add_eos_token = add_eos_token,\n            remove_bos_prefix = remove_bos_prefix,\n            custom_format_mapping = custom_format_mapping,\n            auto_detect_mapping = auto_detect_mapping,\n            batch_size = batch_size,\n            num_proc = num_proc,\n            progress_callback = progress_callback,\n        )\n\n        # Step 3: Generate summary\n        summary = get_dataset_info_summary(dataset_info)\n\n        # Combine results\n        all_warnings = dataset_info.get(\"warnings\", []) + template_result.get(\n            \"warnings\", []\n        )\n        all_errors = template_result.get(\"errors\", [])\n\n        # If format_dataset returned \"unknown\" but apply_chat_template rescued\n        # it via heuristic detection, update final_format to reflect reality.\n        final_format = dataset_info[\"final_format\"]\n        requires_manual = dataset_info.get(\"requires_manual_mapping\", False)\n        if final_format == \"unknown\" and template_result[\"success\"]:\n            out_ds = template_result[\"dataset\"]\n            if hasattr(out_ds, \"column_names\") and \"text\" in out_ds.column_names:\n                final_format = \"chatml_conversations\"\n                requires_manual = False\n\n        return {\n            \"dataset\": template_result[\"dataset\"],\n            \"detected_format\": dataset_info[\"detected_format\"],\n            \"final_format\": final_format,\n            \"chat_column\": dataset_info.get(\"chat_column\"),\n            \"is_vlm\": False,  # This is LLM flow\n            \"success\": template_result[\"success\"],\n            \"requires_manual_mapping\": requires_manual,\n            \"warnings\": all_warnings,\n            \"errors\": all_errors,\n            \"summary\": summary,\n        }\n"
  },
  {
    "path": "studio/backend/utils/datasets/format_conversion.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nFormat conversion utilities for dataset processing.\n\nThis module contains functions for converting between dataset formats\n(Alpaca, ShareGPT, ChatML) and standardizing chat formats.\n\"\"\"\n\nimport os\n\nfrom datasets import IterableDataset\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef standardize_chat_format(\n    dataset,\n    tokenizer = None,\n    aliases_for_system = [\n        \"system\",\n    ],\n    aliases_for_user = [\n        \"user\",\n        \"human\",\n        \"input\",\n    ],\n    aliases_for_assistant = [\n        \"gpt\",\n        \"assistant\",\n        \"output\",\n    ],\n    batch_size = 1000,\n    num_proc = None,\n):\n    \"\"\"\n    Our own standardization function that handles BOTH messages and conversations.\n    Converts non-standard role names and keys to standard format.\n    \"\"\"\n    import collections\n    import itertools\n    from datasets import IterableDataset\n\n    # Check if vision tokenizer is used\n    is_vlm = False\n    if tokenizer is not None:\n        if hasattr(tokenizer, \"image_processor\") or hasattr(tokenizer, \"tokenizer\"):\n            is_vlm = True\n\n    column_names = set(next(iter(dataset)).keys())\n\n    #   Check for both 'conversations' and 'messages'\n    chat_column = None\n    if \"conversations\" in column_names:\n        chat_column = \"conversations\"\n    elif \"messages\" in column_names:\n        chat_column = \"messages\"\n    elif \"texts\" in column_names:\n        chat_column = \"texts\"\n    else:\n        return dataset  # No chat column found\n\n    # Inspect structure\n    examples = itertools.islice(dataset, 10)\n    uniques = collections.defaultdict(list)\n    for example in examples:\n        for message in example[chat_column]:\n            for key, value in message.items():\n                if type(value) is not str:\n                    continue  # Skip non-string values\n                uniques[key].append(value)\n\n    if len(uniques.keys()) != 2:\n        return dataset  # Unexpected structure\n\n    keys = list(uniques.keys())\n    length_first = len(set(uniques[keys[0]]))\n    length_second = len(set(uniques[keys[1]]))\n\n    # Determine which is role and which is content\n    if length_first < length_second:\n        role_key = keys[0]\n        content_key = keys[1]\n    else:\n        role_key = keys[1]\n        content_key = keys[0]\n\n    # Mapping for aliases\n    aliases_mapping = {}\n    for x in aliases_for_system:\n        aliases_mapping[x] = \"system\"\n    for x in aliases_for_user:\n        aliases_mapping[x] = \"user\"\n    for x in aliases_for_assistant:\n        aliases_mapping[x] = \"assistant\"\n\n    def _standardize_dataset(examples):\n        convos = examples[chat_column]\n        all_convos = []\n        for convo in convos:\n            new_convo = []\n            for message in convo:\n                # Get original role and content\n                original_role = message.get(role_key, \"\")\n                original_content = message.get(content_key, \"\")\n\n                # Map to standard role name\n                standard_role = aliases_mapping.get(original_role, original_role)\n\n                # Handle VLM format\n                if is_vlm:\n                    original_content = [{\"type\": \"text\", \"text\": original_content}]\n\n                # Create dict 
with EXPLICIT ORDER\n                new_message = {\"role\": standard_role, \"content\": original_content}\n                new_convo.append(new_message)\n\n            all_convos.append(new_convo)\n\n        return {chat_column: all_convos}\n\n    dataset_map_kwargs = {\n        \"batched\": True,\n        \"batch_size\": batch_size,\n    }\n\n    if not isinstance(dataset, IterableDataset):\n        from utils.hardware import safe_num_proc\n\n        if num_proc is None or type(num_proc) is not int:\n            num_proc = safe_num_proc()\n        else:\n            num_proc = safe_num_proc(num_proc)\n\n        dataset_map_kwargs[\"num_proc\"] = num_proc\n        dataset_map_kwargs[\"desc\"] = \"Standardizing chat format\"\n\n    return dataset.map(_standardize_dataset, **dataset_map_kwargs)\n\n\ndef convert_chatml_to_alpaca(dataset, batch_size = 1000, num_proc = None):\n    \"\"\"\n    Converts ChatML format (messages OR conversations) to Alpaca format.\n    Handles both standardized and ShareGPT formats.\n\n    Supports:\n    - \"messages\" or \"conversations\" column\n    - \"role\"/\"content\" (standard) or \"from\"/\"value\" (ShareGPT)\n    \"\"\"\n    from torch.utils.data import IterableDataset\n\n    def _convert(examples):\n        # Auto-detect which column name is used\n        chatml_data = (\n            examples.get(\"messages\")\n            or examples.get(\"conversations\")\n            or examples.get(\"texts\")\n        )\n\n        if chatml_data is None:\n            raise ValueError(\n                \"No 'messages' or 'conversations' or 'texts' column found.\"\n            )\n\n        instructions = []\n        outputs = []\n        inputs = []\n\n        for convo in chatml_data:\n            instruction = \"\"\n            output = \"\"\n\n            for msg in convo:\n                # Handle both standard and ShareGPT formats\n                role = msg.get(\"role\") or msg.get(\"from\")\n                content = msg.get(\"content\") or msg.get(\"value\")\n\n                # Get first user message as instruction\n                if role in [\"user\", \"human\", \"input\"] and not instruction:\n                    instruction = content\n                # Get first assistant message as output\n                elif role in [\"assistant\", \"gpt\", \"output\"] and not output:\n                    output = content\n                    break  # Stop after first assistant response\n\n            instructions.append(instruction)\n            inputs.append(\"\")  # Alpaca typically has empty input\n            outputs.append(output)\n\n        return {\"instruction\": instructions, \"input\": inputs, \"output\": outputs}\n\n    dataset_map_kwargs = {\n        \"batched\": True,\n        \"batch_size\": batch_size,\n    }\n\n    if not isinstance(dataset, IterableDataset):\n        from utils.hardware import safe_num_proc\n\n        if num_proc is None or type(num_proc) is not int:\n            num_proc = safe_num_proc()\n        else:\n            num_proc = safe_num_proc(num_proc)\n\n        dataset_map_kwargs[\"num_proc\"] = num_proc\n        dataset_map_kwargs[\"desc\"] = \"Converting ChatML to Alpaca format\"\n\n    return dataset.map(_convert, **dataset_map_kwargs)\n\n\ndef convert_alpaca_to_chatml(dataset, batch_size = 1000, num_proc = None):\n    \"\"\"\n    Converts Alpaca format to ChatML format.\n\n    Output format: Uses 'conversations' column with standard 'role'/'content' structure.\n    \"\"\"\n    from torch.utils.data import IterableDataset\n\n    
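# Illustrative shape of the conversion done by _convert below (example values only):\n    #   {\"instruction\": \"Summarize the text\", \"input\": \"...\", \"output\": \"A short summary.\"}\n    #   -> {\"conversations\": [{\"role\": \"user\", \"content\": \"Summarize the text\\n\\n...\"},\n    #                          {\"role\": \"assistant\", \"content\": \"A short summary.\"}]}\n    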
def _convert(examples):\n        conversations = []\n\n        for i in range(len(examples[\"instruction\"])):\n            instruction = examples[\"instruction\"][i]\n            input_text = examples.get(\"input\", [\"\"] * len(examples[\"instruction\"]))[i]\n            output = examples[\"output\"][i]\n\n            # Combine instruction and input (if exists) for user message\n            if input_text and input_text.strip():\n                user_content = f\"{instruction}\\n\\n{input_text}\".strip()\n            else:\n                user_content = instruction\n\n            # Build conversation in standard ChatML format\n            convo = [\n                {\"role\": \"user\", \"content\": user_content},\n                {\"role\": \"assistant\", \"content\": output},\n            ]\n            conversations.append(convo)\n\n        return {\"conversations\": conversations}\n\n    dataset_map_kwargs = {\n        \"batched\": True,\n        \"batch_size\": batch_size,\n    }\n\n    if not isinstance(dataset, IterableDataset):\n        from utils.hardware import safe_num_proc\n\n        if num_proc is None or type(num_proc) is not int:\n            num_proc = safe_num_proc()\n        else:\n            num_proc = safe_num_proc(num_proc)\n\n        dataset_map_kwargs[\"num_proc\"] = num_proc\n        dataset_map_kwargs[\"desc\"] = \"Converting Alpaca to ChatML format\"\n\n    return dataset.map(_convert, **dataset_map_kwargs)\n\n\ndef _format_eta(seconds):\n    \"\"\"Format seconds into a human-readable ETA string.\"\"\"\n    if seconds < 60:\n        return f\"{seconds:.0f}s\"\n    elif seconds < 3600:\n        m, s = divmod(int(seconds), 60)\n        return f\"{m}m {s}s\"\n    else:\n        h, remainder = divmod(int(seconds), 3600)\n        m, _ = divmod(remainder, 60)\n        return f\"{h}h {m}m\"\n\n\ndef convert_to_vlm_format(\n    dataset,\n    instruction = None,\n    text_column = \"text\",\n    image_column = \"image\",\n    dataset_name = None,\n    progress_callback = None,\n):\n    \"\"\"\n    Converts simple {image, text} format to VLM messages format.\n\n    Returns a LIST, not a HuggingFace Dataset (to preserve PIL Images).\n\n    For URL-based image datasets, runs a 200-sample parallel probe first to\n    estimate download speed and failure rate, then reports time estimate or\n    warning through progress_callback before proceeding with the full conversion.\n\n    Args:\n        progress_callback: Optional callable(status_message=str) to report\n                          progress to the training overlay.\n\n    Returns:\n        list: List of dicts with 'messages' field\n    \"\"\"\n    from PIL import Image\n    from .vlm_processing import generate_smart_vlm_instruction\n\n    def _notify(msg):\n        \"\"\"Send status update to the training overlay if callback is available.\"\"\"\n        if progress_callback:\n            progress_callback(status_message = msg)\n\n    # Generate smart instruction if not provided\n    if instruction is None:\n        instruction_info = generate_smart_vlm_instruction(\n            dataset,\n            text_column = text_column,\n            image_column = image_column,\n            dataset_name = dataset_name,\n        )\n\n        instruction = instruction_info[\"instruction\"]\n        instruction_column = instruction_info.get(\"instruction_column\")\n        uses_dynamic = instruction_info[\"uses_dynamic_instruction\"]\n\n        logger.info(\n            f\"📝 Auto-detected instruction type: 
{instruction_info['instruction_type']}\"\n        )\n        logger.info(f\"📝 Confidence: {instruction_info['confidence']:.2f}\")\n        if not uses_dynamic:\n            logger.info(f\"📝 Using instruction: '{instruction}'\")\n        else:\n            logger.info(\n                f\"📝 Using dynamic instructions from column: '{instruction_column}'\"\n            )\n    else:\n        instruction_column = None\n        uses_dynamic = False\n\n    def _convert_single_sample(sample):\n        \"\"\"Convert a single sample to VLM format.\"\"\"\n        # Get image (might be PIL Image, local path, URL, or bare filename)\n        image_data = sample[image_column]\n\n        if isinstance(image_data, str):\n            if image_data.startswith((\"http://\", \"https://\")):\n                import fsspec\n                from io import BytesIO\n\n                with fsspec.open(image_data, \"rb\", expand = True) as f:\n                    image_data = Image.open(BytesIO(f.read())).convert(\"RGB\")\n            elif _image_lookup is not None and image_data in _image_lookup:\n                # Bare filename → resolve via HF repo lookup\n                from huggingface_hub import hf_hub_download\n\n                local_path = hf_hub_download(\n                    dataset_name,\n                    _image_lookup[image_data],\n                    repo_type = \"dataset\",\n                )\n                image_data = Image.open(local_path).convert(\"RGB\")\n            else:\n                image_data = Image.open(image_data).convert(\"RGB\")\n\n        # Get text (if list of strings, pick a random one — e.g. multiple captions)\n        text_data = sample[text_column]\n        if isinstance(text_data, list) and len(text_data) > 0:\n            import random\n\n            text_data = random.choice(text_data)\n\n        # Get instruction (static or dynamic)\n        if uses_dynamic and instruction_column:\n            current_instruction = sample[instruction_column]\n        else:\n            current_instruction = instruction\n\n        # Build VLM messages - simple structure\n        messages = [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"text\": current_instruction},\n                    {\"type\": \"image\", \"image\": image_data},  # PIL object\n                ],\n            },\n            {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": text_data}]},\n        ]\n\n        # Return dict with messages\n        return {\"messages\": messages}\n\n    total = len(dataset)\n    first_image = next(iter(dataset))[image_column]\n    has_urls = isinstance(first_image, str) and first_image.startswith(\n        (\"http://\", \"https://\")\n    )\n\n    # ── Bare-filename detection: images stored as filenames (e.g. \"img_001.png\")\n    #    that don't exist locally.  Build a basename→repo_path lookup so we can\n    #    resolve them via hf_hub_download during conversion.\n    _image_lookup = None\n    _IMAGE_EXTS = (\".png\", \".jpg\", \".jpeg\", \".webp\", \".gif\", \".bmp\", \".tiff\")\n    if (\n        not has_urls\n        and isinstance(first_image, str)\n        and not os.path.exists(first_image)\n        and dataset_name\n    ):\n        try:\n            from huggingface_hub import HfApi\n\n            _notify(\"Resolving image filenames from HF repo...\")\n            logger.info(\n                f\"🔍 Image column contains bare filenames (e.g. 
'{first_image}') — building repo lookup...\"\n            )\n            repo_files = HfApi().list_repo_files(dataset_name, repo_type = \"dataset\")\n            _image_lookup = {\n                os.path.basename(f): f\n                for f in repo_files\n                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS)\n            }\n            if first_image in _image_lookup:\n                logger.info(\n                    f\"✅ Matched {len(_image_lookup)} image files in repo (e.g. '{first_image}' → '{_image_lookup[first_image]}')\"\n                )\n            else:\n                logger.info(\n                    f\"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open\"\n                )\n                _image_lookup = None\n        except Exception as e:\n            logger.info(f\"⚠️ Failed to build HF repo image lookup: {e}\")\n            _image_lookup = None\n\n    # ── URL probe: 200 samples with parallel workers to estimate speed + failure rate ──\n    PROBE_SIZE = 200\n    MAX_FAIL_RATE = 0.3\n\n    if has_urls and total > PROBE_SIZE:\n        import time\n        from concurrent.futures import ThreadPoolExecutor, as_completed\n        from utils.hardware import safe_num_proc\n\n        num_workers = safe_num_proc()\n        _notify(f\"Probing {PROBE_SIZE} image URLs with {num_workers} workers...\")\n        logger.info(\n            f\"🔍 Probing {PROBE_SIZE}/{total} image URLs with {num_workers} workers...\"\n        )\n\n        probe_samples = [dataset[i] for i in range(PROBE_SIZE)]\n        probe_ok = 0\n        probe_fail = 0\n        probe_start = time.time()\n\n        with ThreadPoolExecutor(max_workers = num_workers) as executor:\n            futures = {\n                executor.submit(_convert_single_sample, s): s for s in probe_samples\n            }\n            for future in as_completed(futures):\n                try:\n                    future.result()\n                    probe_ok += 1\n                except Exception:\n                    probe_fail += 1\n\n        probe_elapsed = time.time() - probe_start\n        probe_total = probe_ok + probe_fail\n        fail_rate = probe_fail / probe_total if probe_total > 0 else 0\n        throughput = probe_total / probe_elapsed if probe_elapsed > 0 else 0\n\n        if fail_rate >= MAX_FAIL_RATE:\n            issues = [\n                f\"{fail_rate:.0%} of the first {PROBE_SIZE} image URLs failed to download ({probe_fail}/{probe_total})\",\n                \"Images are external URLs, not embedded in the dataset\",\n            ]\n            # Try LLM-friendly warning\n            friendly = None\n            try:\n                from .llm_assist import llm_generate_dataset_warning\n\n                friendly = llm_generate_dataset_warning(\n                    issues,\n                    dataset_name = dataset_name,\n                    modality = \"vision\",\n                    column_names = [image_column, text_column],\n                )\n            except Exception:\n                pass\n            msg = friendly or (\n                f\"⚠️ {fail_rate:.0%} of the first {PROBE_SIZE} images failed to download \"\n                f\"({probe_fail}/{probe_total}). \"\n                \"This dataset has too many broken or unreachable image URLs. 
\"\n                \"Consider using a dataset with embedded images instead.\"\n            )\n            logger.info(msg)\n            _notify(msg)\n            raise ValueError(msg)\n\n        # Estimate total time for remaining samples\n        remaining = total - PROBE_SIZE\n        estimated_seconds = remaining / throughput if throughput > 0 else 0\n        eta_str = _format_eta(estimated_seconds)\n\n        info_msg = (\n            f\"Downloading {total:,} images ({num_workers} workers, ~{throughput:.1f} img/s). \"\n            f\"Estimated time: ~{eta_str}\"\n        )\n        if probe_fail > 0:\n            info_msg += f\" | {fail_rate:.0%} broken URLs will be skipped\"\n\n        logger.info(\n            f\"✅ Probe passed: {probe_ok}/{probe_total} ok, {probe_fail} failed ({fail_rate:.0%}), {throughput:.1f} img/s\"\n        )\n        logger.info(f\"⏱️ Estimated time for {total:,} samples: ~{eta_str}\")\n        _notify(info_msg)\n\n    # ── Full conversion with progress ──\n    from tqdm import tqdm\n\n    logger.info(f\"🔄 Converting {total} samples to VLM format...\")\n    converted_list = []\n    failed_count = 0\n\n    if has_urls:\n        # Parallel conversion for URL-based datasets\n        import time\n        from concurrent.futures import ThreadPoolExecutor, as_completed\n        from utils.hardware import safe_num_proc\n\n        num_workers = safe_num_proc()\n        batch_size = 500\n        start_time = time.time()\n\n        for batch_start in range(0, total, batch_size):\n            batch_end = min(batch_start + batch_size, total)\n            batch_samples = [dataset[i] for i in range(batch_start, batch_end)]\n\n            with ThreadPoolExecutor(max_workers = num_workers) as executor:\n                futures = {\n                    executor.submit(_convert_single_sample, s): i\n                    for i, s in enumerate(batch_samples)\n                }\n                batch_results = [None] * len(batch_samples)\n                for future in as_completed(futures):\n                    idx = futures[future]\n                    try:\n                        batch_results[idx] = future.result()\n                    except Exception as e:\n                        failed_count += 1\n                        if failed_count == 1:\n                            print(\n                                f\"⚠️ First VLM conversion failure: {type(e).__name__}: {e}\"\n                            )\n                        if failed_count == 1:\n                            logger.info(\n                                f\"⚠️ First VLM conversion failure: {type(e).__name__}: {e}\"\n                            )\n\n            converted_list.extend(r for r in batch_results if r is not None)\n\n            # Progress update every batch\n            elapsed = time.time() - start_time\n            done = batch_end\n            rate = done / elapsed if elapsed > 0 else 0\n            remaining_time = (total - done) / rate if rate > 0 else 0\n            eta_str = _format_eta(remaining_time)\n            progress_msg = f\"Downloading images: {done:,}/{total:,} ({done*100//total}%) | ~{eta_str} remaining | {failed_count} skipped\"\n            logger.info(\n                f\"  [{done}/{total}] {rate:.1f} img/s, {failed_count} failed, ETA {eta_str}\"\n            )\n            _notify(progress_msg)\n    else:\n        # Sequential conversion for local/embedded images (fast, no I/O bottleneck)\n        pbar = tqdm(dataset, total = total, desc = \"Converting VLM samples\", unit = 
\"sample\")\n        for sample in pbar:\n            try:\n                converted_list.append(_convert_single_sample(sample))\n            except Exception as e:\n                failed_count += 1\n                if failed_count == 1:\n                    # Log the first failure to aid debugging\n                    print(f\"⚠️ First VLM conversion failure: {type(e).__name__}: {e}\")\n                if failed_count == 1:\n                    # Log the first failure to aid debugging\n                    logger.info(\n                        f\"⚠️ First VLM conversion failure: {type(e).__name__}: {e}\"\n                    )\n            pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False)\n        pbar.close()\n\n    if failed_count > 0:\n        fail_rate = failed_count / total\n        logger.info(\n            f\"⚠️ Skipped {failed_count}/{total} ({fail_rate:.0%}) samples with broken/unreachable images\"\n        )\n        # For datasets that skipped the probe (small URL datasets), check fail rate now\n        if has_urls and fail_rate >= MAX_FAIL_RATE:\n            issues = [\n                f\"{fail_rate:.0%} of images failed to download ({failed_count}/{total})\",\n                \"Images are external URLs, not embedded in the dataset\",\n            ]\n            friendly = None\n            try:\n                from .llm_assist import llm_generate_dataset_warning\n\n                friendly = llm_generate_dataset_warning(\n                    issues,\n                    dataset_name = dataset_name,\n                    modality = \"vision\",\n                    column_names = [image_column, text_column],\n                )\n            except Exception:\n                pass\n            msg = friendly or (\n                f\"⚠️ {fail_rate:.0%} of images failed to download ({failed_count}/{total}). \"\n                \"This dataset has too many broken or unreachable image URLs. \"\n                \"Consider using a dataset with embedded images instead.\"\n            )\n            _notify(msg)\n            raise ValueError(msg)\n\n    if len(converted_list) == 0:\n        issues = [\n            f\"All {total} samples failed during VLM conversion — no usable images found\",\n            f\"Image column '{image_column}' may contain URLs that are no longer accessible, \"\n            \"or local file paths that don't exist\",\n        ]\n        friendly = None\n        try:\n            from .llm_assist import llm_generate_dataset_warning\n\n            friendly = llm_generate_dataset_warning(\n                issues,\n                dataset_name = dataset_name,\n                modality = \"vision\",\n                column_names = [image_column, text_column],\n            )\n        except Exception:\n            pass\n        raise ValueError(\n            friendly\n            or (\n                f\"All {total} samples failed during VLM conversion — no usable images found. 
\"\n                \"This dataset may contain only image URLs that are no longer accessible.\"\n            )\n        )\n\n    logger.info(f\"✅ Converted {len(converted_list)}/{total} samples\")\n    _notify(f\"Converted {len(converted_list):,}/{total:,} images successfully\")\n\n    # Return list, NOT Dataset\n    return converted_list\n\n\ndef convert_sharegpt_with_images_to_vlm_format(\n    dataset,\n    image_column = \"image\",\n    messages_column = \"conversations\",\n    dataset_name = None,\n    progress_callback = None,\n):\n    \"\"\"\n    Converts ShareGPT/ChatML datasets that have a separate image column and\n    ``<image>`` placeholders inside the conversation text.\n\n    Example input::\n\n        {\n            \"image\": \"sam/images/sa_545504.jpg\",\n            \"conversations\": [\n                {\"from\": \"human\", \"value\": \"<image>\\\\nWhat is this photo about?\"},\n                {\"from\": \"gpt\",   \"value\": \"The image captures...\"}\n            ]\n        }\n\n    Returns a list of dicts in standard VLM messages format (PIL Images inline).\n    \"\"\"\n    from PIL import Image\n    from tqdm import tqdm\n\n    _IMAGE_EXTS = (\".png\", \".jpg\", \".jpeg\", \".webp\", \".gif\", \".bmp\", \".tiff\")\n    _ROLE_MAP = {\n        \"human\": \"user\",\n        \"user\": \"user\",\n        \"gpt\": \"assistant\",\n        \"assistant\": \"assistant\",\n        \"system\": \"system\",\n    }\n\n    def _notify(msg):\n        if progress_callback:\n            progress_callback(status_message = msg)\n\n    # ── Resolve image loading strategy (same 3-tier as convert_to_vlm_format) ──\n    total = len(dataset)\n    first_image = next(iter(dataset))[image_column]\n\n    _image_lookup = None\n    if (\n        isinstance(first_image, str)\n        and not first_image.startswith((\"http://\", \"https://\"))\n        and not os.path.exists(first_image)\n        and dataset_name\n    ):\n        try:\n            from huggingface_hub import HfApi\n\n            _notify(\"Resolving image filenames from HF repo...\")\n            logger.info(\n                f\"🔍 Image column contains bare filenames (e.g. '{first_image}') — building repo lookup...\"\n            )\n            repo_files = HfApi().list_repo_files(dataset_name, repo_type = \"dataset\")\n            _image_lookup = {\n                os.path.basename(f): f\n                for f in repo_files\n                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS)\n            }\n            # Also add the full relative paths as keys (for paths like \"sam/images/sa_545504.jpg\")\n            for f in repo_files:\n                if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS):\n                    _image_lookup[f] = f\n            if first_image in _image_lookup:\n                logger.info(\n                    f\"✅ Matched {len(_image_lookup)} image files in repo (e.g. 
'{first_image}' → '{_image_lookup[first_image]}')\"\n                )\n            else:\n                logger.info(\n                    f\"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open\"\n                )\n                _image_lookup = None\n        except Exception as e:\n            logger.info(f\"⚠️ Failed to build HF repo image lookup: {e}\")\n            _image_lookup = None\n\n    def _resolve_image(image_data):\n        \"\"\"Resolve image data to a PIL Image object.\"\"\"\n        if hasattr(image_data, \"size\") and hasattr(image_data, \"mode\"):\n            return image_data  # Already PIL\n        if isinstance(image_data, str):\n            if image_data.startswith((\"http://\", \"https://\")):\n                import fsspec\n                from io import BytesIO\n\n                with fsspec.open(image_data, \"rb\", expand = True) as f:\n                    return Image.open(BytesIO(f.read())).convert(\"RGB\")\n            elif _image_lookup is not None and image_data in _image_lookup:\n                from huggingface_hub import hf_hub_download\n\n                local_path = hf_hub_download(\n                    dataset_name,\n                    _image_lookup[image_data],\n                    repo_type = \"dataset\",\n                )\n                return Image.open(local_path).convert(\"RGB\")\n            else:\n                return Image.open(image_data).convert(\"RGB\")\n        if isinstance(image_data, dict) and (\n            \"bytes\" in image_data or \"path\" in image_data\n        ):\n            if image_data.get(\"bytes\"):\n                from io import BytesIO\n\n                return Image.open(BytesIO(image_data[\"bytes\"])).convert(\"RGB\")\n            if image_data.get(\"path\"):\n                return Image.open(image_data[\"path\"]).convert(\"RGB\")\n        raise ValueError(f\"Cannot resolve image: {type(image_data)}\")\n\n    def _convert_single_sample(sample):\n        \"\"\"Convert a single ShareGPT+image sample to standard VLM format.\"\"\"\n        pil_image = _resolve_image(sample[image_column])\n        conversation = sample[messages_column]\n\n        new_messages = []\n        for msg in conversation:\n            role_raw = msg.get(\"from\") or msg.get(\"role\", \"user\")\n            role = _ROLE_MAP.get(role_raw.lower(), role_raw.lower())\n            text = msg.get(\"value\") or msg.get(\"content\") or \"\"\n\n            # Split on <image> to interleave text and image content blocks\n            if \"<image>\" in text:\n                parts = text.split(\"<image>\")\n                content = []\n                for i, part in enumerate(parts):\n                    part = part.strip()\n                    if part:\n                        content.append({\"type\": \"text\", \"text\": part})\n                    if i < len(parts) - 1:\n                        content.append({\"type\": \"image\", \"image\": pil_image})\n                # If <image> was the entire text, content might just be the image\n                if not content:\n                    content.append({\"type\": \"image\", \"image\": pil_image})\n            else:\n                content = [{\"type\": \"text\", \"text\": text}]\n\n            new_messages.append({\"role\": role, \"content\": content})\n\n        return {\"messages\": new_messages}\n\n    # ── Full conversion with progress ──\n    logger.info(f\"🔄 Converting {total} samples from ShareGPT+image format...\")\n    converted_list 
= []\n    failed_count = 0\n\n    pbar = tqdm(dataset, total = total, desc = \"Converting ShareGPT+image\", unit = \"sample\")\n    for sample in pbar:\n        try:\n            converted_list.append(_convert_single_sample(sample))\n        except Exception as e:\n            failed_count += 1\n            if failed_count == 1:\n                logger.info(f\"⚠️ First conversion failure: {type(e).__name__}: {e}\")\n        pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False)\n    pbar.close()\n\n    if failed_count > 0:\n        logger.info(\n            f\"⚠️ Skipped {failed_count}/{total} ({failed_count*100//total}%) samples\"\n        )\n\n    if len(converted_list) == 0:\n        raise ValueError(\n            f\"All {total} samples failed during ShareGPT+image conversion — \"\n            \"no usable samples found.\"\n        )\n\n    logger.info(f\"✅ Converted {len(converted_list)}/{total} samples\")\n    _notify(f\"Converted {len(converted_list):,}/{total:,} samples successfully\")\n    return converted_list\n\n\ndef convert_llava_to_vlm_format(dataset):\n    \"\"\"\n    Converts Llava format to standard VLM format.\n\n    Llava format:\n    - messages: [{'content': [{'type': 'image', 'index': 0}, {'type': 'text', 'text': '...'}]}]\n    - images: [PIL_Image1, PIL_Image2, ...]\n\n    Standard VLM format:\n    - messages: [{'content': [{'type': 'image', 'image': PIL_Image}, {'type': 'text', 'text': '...'}]}]\n    \"\"\"\n    from PIL import Image\n\n    logger.info(\n        f\"🔄 Converting {len(dataset)} samples from Llava format to standard VLM format...\"\n    )\n\n    def _convert_single_sample(sample):\n        \"\"\"Convert a single llava sample to standard VLM format.\"\"\"\n        messages = sample[\"messages\"]\n        images = sample.get(\"images\", [])\n\n        # Process each message\n        new_messages = []\n        for msg in messages:\n            new_content = []\n\n            for item in msg[\"content\"]:\n                if item[\"type\"] == \"image\":\n                    # Replace index with actual PIL image\n                    if \"index\" in item and item[\"index\"] is not None:\n                        img_idx = item[\"index\"]\n                        if img_idx < len(images):\n                            pil_image = images[img_idx]\n                            # Ensure it's PIL\n                            if isinstance(pil_image, str):\n                                pil_image = Image.open(pil_image).convert(\"RGB\")\n\n                            new_content.append(\n                                {\n                                    \"type\": \"image\",\n                                    \"image\": pil_image,  # Actual PIL object\n                                }\n                            )\n                    else:\n                        # No index, try to use first image\n                        if len(images) > 0:\n                            pil_image = images[0]\n                            if isinstance(pil_image, str):\n                                pil_image = Image.open(pil_image).convert(\"RGB\")\n\n                            new_content.append({\"type\": \"image\", \"image\": pil_image})\n\n                elif item[\"type\"] == \"text\":\n                    # Keep text as-is (only type + text)\n                    new_content.append({\"type\": \"text\", \"text\": item.get(\"text\", \"\")})\n\n            new_messages.append({\"role\": msg[\"role\"], \"content\": new_content})\n\n        
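# By here, image-index entries have been resolved to PIL objects (out-of-range indices are dropped); text items pass through unchanged.\n        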
return {\"messages\": new_messages}\n\n    # Convert using list comprehension\n    converted_list = [_convert_single_sample(sample) for sample in dataset]\n\n    logger.info(f\"✅ Converted {len(converted_list)} samples\")\n    return converted_list\n"
  },
  {
    "path": "studio/backend/utils/datasets/format_detection.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nFormat detection utilities for dataset processing.\n\nThis module contains functions for detecting dataset formats (Alpaca, ShareGPT, ChatML),\ndetecting multimodal/VLM dataset structures, and heuristic-based column mapping.\n\"\"\"\n\nimport re\n\n\ndef _keyword_in_column(keyword: str, col_name: str) -> bool:\n    \"\"\"Word-boundary keyword match to avoid false positives like 'pic' in 'topic'.\"\"\"\n    return (\n        re.search(r\"\\b\" + re.escape(keyword) + r\"\\b\", col_name, re.IGNORECASE)\n        is not None\n    )\n\n\ndef detect_dataset_format(dataset):\n    \"\"\"\n    Detects dataset format by inspecting structure.\n\n    Returns:\n        dict: {\n            \"format\": \"alpaca\" | \"sharegpt\" | \"chatml\" | \"unknown\",\n            \"chat_column\": \"messages\" | \"conversations\" | None,\n            \"needs_standardization\": bool,\n            \"sample_keys\": list of keys found in messages (for debugging)\n        }\n    \"\"\"\n    column_names = set(next(iter(dataset)).keys())\n\n    # Check for Alpaca\n    alpaca_columns = {\"instruction\", \"output\"}\n    if alpaca_columns.issubset(column_names):\n        return {\n            \"format\": \"alpaca\",\n            \"chat_column\": None,\n            \"needs_standardization\": False,\n            \"sample_keys\": [],\n        }\n\n    # Check for chat-based formats (messages or conversations)\n    chat_column = None\n    if \"messages\" in column_names:\n        chat_column = \"messages\"\n    elif \"conversations\" in column_names:\n        chat_column = \"conversations\"\n    elif \"texts\" in column_names:\n        chat_column = \"texts\"\n\n    if chat_column:\n        # Inspect the structure to determine if ShareGPT or ChatML\n        try:\n            sample = next(iter(dataset))\n            chat_data = sample[chat_column]\n\n            if chat_data and len(chat_data) > 0:\n                first_msg = chat_data[0]\n                msg_keys = set(first_msg.keys())\n\n                # ShareGPT uses \"from\" and \"value\"\n                if \"from\" in msg_keys or \"value\" in msg_keys:\n                    return {\n                        \"format\": \"sharegpt\",\n                        \"chat_column\": chat_column,\n                        \"needs_standardization\": True,\n                        \"sample_keys\": list(msg_keys),\n                    }\n\n                # ChatML uses \"role\" and \"content\"\n                elif \"role\" in msg_keys and \"content\" in msg_keys:\n                    return {\n                        \"format\": \"chatml\",\n                        \"chat_column\": chat_column,\n                        \"needs_standardization\": False,\n                        \"sample_keys\": list(msg_keys),\n                    }\n\n                # Unknown structure but has chat column\n                else:\n                    return {\n                        \"format\": \"unknown\",\n                        \"chat_column\": chat_column,\n                        \"needs_standardization\": None,\n                        \"sample_keys\": list(msg_keys),\n                    }\n        except Exception as e:\n            return {\n                \"format\": \"unknown\",\n                \"chat_column\": chat_column,\n                \"needs_standardization\": None,\n                \"sample_keys\": [],\n         
       \"error\": str(e),\n            }\n\n    # No recognized format\n    return {\n        \"format\": \"unknown\",\n        \"chat_column\": None,\n        \"needs_standardization\": None,\n        \"sample_keys\": [],\n    }\n\n\ndef detect_custom_format_heuristic(dataset):\n    \"\"\"\n    Smart detection with priority scoring.\n\n    Strategy for ambiguous keywords like 'task':\n    1. Detect assistant first (unambiguous)\n    2. Detect user using high-priority keywords first\n    3. Check REMAINING columns for system keywords (including 'task')\n    4. Only if no system match, use 'task' as fallback user\n    \"\"\"\n    sample = next(iter(dataset))\n    all_columns = list(sample.keys())\n\n    mapping = {}\n\n    # Keywords\n    assistant_words = [\n        \"output\",\n        \"answer\",\n        \"response\",\n        \"assistant\",\n        \"completion\",\n        \"expected\",\n        \"recommendation\",\n        \"reply\",\n        \"result\",\n        \"target\",\n        \"solution\",\n        \"explanation\",\n        \"solve\",\n    ]\n\n    # Split into high/low priority\n    user_words_high_priority = [\n        \"input\",\n        \"question\",\n        \"query\",\n        \"prompt\",\n        \"instruction\",\n        \"request\",\n        \"snippet\",\n        \"user\",\n        \"text\",\n        \"problem\",\n        \"exercise\",\n    ]\n    user_words_low_priority = [\"task\"]  # Ambiguous - can be user OR system\n    user_words = user_words_high_priority + user_words_low_priority\n\n    system_words = [\n        \"system\",\n        \"context\",\n        \"description\",\n        \"persona\",\n        \"role\",\n        \"template\",\n        \"task\",  # Also in system\n    ]\n\n    # Metadata columns to ignore\n    metadata_exact_match = {\n        \"id\",\n        \"idx\",\n        \"index\",\n        \"key\",\n        \"timestamp\",\n        \"date\",\n        \"metadata\",\n        \"source\",\n        \"kind\",\n        \"type\",\n        \"category\",\n        \"score\",\n        \"label\",\n        \"tag\",\n        \"inference_mode\",\n    }\n\n    metadata_prefix_patterns = [\n        \"problem_type\",\n        \"problem_source\",\n        \"generation_model\",\n        \"pass_rate\",\n    ]\n\n    priority_patterns = {\n        \"generated\": 100,\n        \"gen_\": 90,\n        \"model_\": 80,\n        \"predicted\": 70,\n        \"completion\": 60,\n    }\n\n    def has_keyword(col_name, keywords):\n        \"\"\"Check if any keyword appears in column name.\"\"\"\n        col_lower = col_name.lower()\n        col_normalized = col_lower.replace(\"_\", \"\").replace(\"-\", \"\").replace(\" \", \"\")\n\n        for keyword in keywords:\n            if keyword in col_lower or keyword in col_normalized:\n                return True\n        return False\n\n    def is_metadata(col_name):\n        \"\"\"Check if column is likely metadata.\"\"\"\n        col_lower = col_name.lower()\n\n        if col_lower in metadata_exact_match:\n            return True\n\n        if col_lower in metadata_prefix_patterns:\n            return True\n\n        for pattern in metadata_prefix_patterns:\n            if (\n                col_lower.startswith(pattern.split(\"_\")[0] + \"_\")\n                and col_lower != pattern\n            ):\n                if \"_\" in col_lower:\n                    prefix = col_lower.split(\"_\")[0]\n                    if prefix in [\"generation\", \"pass\", \"inference\"]:\n                        return True\n\n        if 
len(col_lower) <= 2 and not col_lower in [\"qa\", \"q\", \"a\"]:\n            return True\n\n        return False\n\n    def get_priority_score(col_name):\n        \"\"\"Calculate priority score based on column name patterns.\"\"\"\n        col_lower = col_name.lower()\n        score = 0\n\n        for pattern, pattern_score in priority_patterns.items():\n            if pattern in col_lower:\n                score += pattern_score\n\n        return score\n\n    def get_content_length(col_name):\n        \"\"\"Get average content length for this column.\"\"\"\n        try:\n            if col_name in sample and sample[col_name]:\n                content = str(sample[col_name])\n                return len(content)\n            return 0\n        except:\n            return 0\n\n    def score_column(col_name, keywords, role_type, num_candidates):\n        \"\"\"Score a column for how likely it is to be a particular role.\"\"\"\n        if not has_keyword(col_name, keywords):\n            return 0\n\n        score = 0\n        score += 10\n\n        # Penalize ambiguous keywords when scoring for user\n        if role_type == \"user\":\n            col_lower = col_name.lower()\n            # If column is ONLY \"task\" (or task_xxx), give it lower priority for user role\n            if \"task\" in col_lower and not any(\n                kw in col_lower for kw in user_words_high_priority\n            ):\n                score -= 15  # Significant penalty so other user columns win\n\n        priority_bonus = get_priority_score(col_name)\n        score += priority_bonus\n\n        if role_type in [\"assistant\", \"user\"]:\n            avg_length = get_content_length(col_name)\n\n            if num_candidates > 1:\n                if avg_length > 1000:\n                    score += 50\n                elif avg_length > 200:\n                    score += 30\n                elif avg_length > 50:\n                    score += 10\n                elif avg_length < 50:\n                    score -= 20\n            else:\n                if avg_length > 1000:\n                    score += 50\n                elif avg_length > 200:\n                    score += 30\n                elif avg_length > 50:\n                    score += 10\n\n        return score\n\n    # Filter out metadata columns\n    content_columns = [col for col in all_columns if not is_metadata(col)]\n\n    # Count candidates first\n    assistant_potential = [\n        col for col in content_columns if has_keyword(col, assistant_words)\n    ]\n    user_potential = [col for col in content_columns if has_keyword(col, user_words)]\n\n    # STEP 1: Find best ASSISTANT column\n    assistant_candidates = []\n    for col in assistant_potential:\n        score = score_column(\n            col, assistant_words, \"assistant\", len(assistant_potential)\n        )\n        if score > 0:\n            assistant_candidates.append((col, score))\n\n    if assistant_candidates:\n        assistant_candidates.sort(key = lambda x: x[1], reverse = True)\n        assistant_col = assistant_candidates[0][0]\n        mapping[assistant_col] = \"assistant\"\n    else:\n        assistant_col = None\n\n    # STEP 2: Find best USER column (with penalty for ambiguous keywords)\n    user_candidates = []\n    for col in user_potential:\n        if col == assistant_col:\n            continue\n        score = score_column(col, user_words, \"user\", len(user_potential))\n        if score > 0:\n            user_candidates.append((col, score))\n\n    if user_candidates:\n    
    user_candidates.sort(key = lambda x: x[1], reverse = True)\n        user_col = user_candidates[0][0]\n        mapping[user_col] = \"user\"\n    else:\n        user_col = None\n\n    # STEP 3: Check ALL remaining columns for SYSTEM matches (priority check)\n    remaining_columns = [col for col in content_columns if col not in mapping]\n\n    system_col = None\n    for col in remaining_columns:\n        if has_keyword(col, system_words):\n            # Found a system match in remaining columns\n            mapping[col] = \"system\"\n            system_col = col\n            break\n\n    # STEP 4: Handle any additional remaining columns\n    if system_col:\n        remaining_columns = [col for col in remaining_columns if col != system_col]\n\n    if len(remaining_columns) >= 1:\n        remaining_col = remaining_columns[0]\n\n        # If no strong keyword match, decide based on what's missing\n        if not has_keyword(remaining_col, user_words + assistant_words):\n            mapping[remaining_col] = \"system\"\n        elif user_col is None:\n            # No user column yet, assign this as user\n            mapping[remaining_col] = \"user\"\n        else:\n            # Already have user + assistant, treat as system context\n            mapping[remaining_col] = \"system\"\n\n    # VALIDATION: Ensure we have at least user + assistant\n    has_user = any(role == \"user\" for role in mapping.values())\n    has_assistant = any(role == \"assistant\" for role in mapping.values())\n\n    if not has_user and len(remaining_columns) > 0:\n        for col in remaining_columns:\n            if col not in mapping:\n                mapping[col] = \"user\"\n                has_user = True\n                break\n\n    if has_user and has_assistant:\n        return mapping\n\n    return None\n\n\ndef detect_multimodal_dataset(dataset):\n    \"\"\"\n    Detects if dataset contains multimodal data (images and/or audio).\n\n    Two-pass approach for each modality:\n      1. Column-name heuristic (fast): checks for keywords.\n      2. 
Value-type inspection (reliable): checks actual sample values.\n\n    Returns:\n        dict: {\n            \"is_image\": bool,\n            \"multimodal_columns\": list of column names containing image data,\n            \"modality_types\": list of detected types (e.g., [\"image\", \"audio\"]),\n            \"is_audio\": bool,\n            \"audio_columns\": list of column names containing audio data,\n            \"detected_audio_column\": str or None,\n            \"detected_text_column\": str or None,\n        }\n    \"\"\"\n    sample = next(iter(dataset))\n    column_names = list(sample.keys())\n\n    # Keywords that indicate image data\n    image_keywords = [\n        \"image\",\n        \"img\",\n        \"pixel\",\n        \"jpg\",\n        \"jpeg\",\n        \"png\",\n        \"webp\",\n        \"bmp\",\n        \"gif\",\n        \"tiff\",\n        \"svg\",\n        \"photo\",\n        \"pic\",\n        \"picture\",\n        \"visual\",\n        \"file_name\",\n        \"filename\",\n    ]\n\n    # Keywords that indicate audio data\n    audio_keywords = [\"audio\", \"speech\", \"wav\", \"waveform\", \"sound\"]\n\n    multimodal_columns = []\n    audio_columns = []\n    modality_types = set()\n\n    # ── Image detection ─────────────────────────────────────\n    # Pass 1: column-name heuristic (word-boundary match to avoid\n    #          false positives like 'pic' in 'topic')\n    for col_name in column_names:\n        for keyword in image_keywords:\n            if _keyword_in_column(keyword, col_name):\n                multimodal_columns.append(col_name)\n                modality_types.add(keyword)\n                break\n\n    # Pass 2: inspect actual values\n    already_detected = set(multimodal_columns)\n    for col_name in column_names:\n        if col_name in already_detected:\n            continue\n        value = sample[col_name]\n        if _is_image_value(value):\n            multimodal_columns.append(col_name)\n            modality_types.add(\"image\")\n\n    # ── Audio detection ─────────────────────────────────────\n    # Pass 1: column-name heuristic (word-boundary match)\n    for col_name in column_names:\n        for keyword in audio_keywords:\n            if _keyword_in_column(keyword, col_name):\n                audio_columns.append(col_name)\n                modality_types.add(\"audio\")\n                break\n\n    # Pass 2: inspect actual values (catches non-obvious column names)\n    already_audio = set(audio_columns)\n    for col_name in column_names:\n        if col_name in already_audio:\n            continue\n        value = sample[col_name]\n        if _is_audio_value(value):\n            audio_columns.append(col_name)\n            modality_types.add(\"audio\")\n\n    # Filter out columns that are actually audio from the image list\n    # (e.g. 
a column named \"audio\" with {\"bytes\", \"path\"} could match _is_image_value)\n    if audio_columns:\n        audio_set = set(audio_columns)\n        multimodal_columns = [c for c in multimodal_columns if c not in audio_set]\n\n    # Detect text column for audio datasets\n    detected_text_col = None\n    if audio_columns:\n        text_keywords = [\"text\", \"sentence\", \"transcript\", \"transcription\", \"label\"]\n        for col_name in column_names:\n            if col_name.lower() in text_keywords:\n                detected_text_col = col_name\n                break\n\n    is_audio = len(audio_columns) > 0\n\n    # Detect speaker_id column for TTS datasets (CSM, Orpheus, Spark)\n    detected_speaker_col = None\n    if audio_columns:\n        speaker_keywords = [\"source\", \"speaker\", \"speaker_id\"]\n        for col_name in column_names:\n            if col_name.lower() in speaker_keywords:\n                detected_speaker_col = col_name\n                break\n\n    return {\n        \"is_image\": len(multimodal_columns) > 0,\n        \"multimodal_columns\": multimodal_columns,\n        \"modality_types\": list(modality_types),\n        \"is_audio\": is_audio,\n        \"audio_columns\": audio_columns,\n        \"detected_audio_column\": audio_columns[0] if audio_columns else None,\n        \"detected_text_column\": detected_text_col,\n        \"detected_speaker_column\": detected_speaker_col,\n    }\n\n\ndef _is_image_value(value) -> bool:\n    \"\"\"Check if a single sample value looks like image data.\"\"\"\n    if value is None:\n        return False\n\n    # PIL Image instance\n    try:\n        from PIL.Image import Image as PILImage\n\n        if isinstance(value, PILImage):\n            return True\n    except ImportError:\n        pass\n\n    # HF datasets Image feature stores decoded images as PIL or dicts with\n    # {\"bytes\": b\"...\", \"path\": \"...\"} when not yet decoded.\n    # Exclude audio dicts (decoded audio has \"array\" + \"sampling_rate\").\n    if isinstance(value, dict):\n        if \"array\" in value and \"sampling_rate\" in value:\n            return False  # This is audio, not image\n        if \"bytes\" in value and \"path\" in value:\n            # Check path extension to exclude audio files\n            path = value.get(\"path\") or \"\"\n            if isinstance(path, str) and any(\n                path.lower().endswith(ext) for ext in _AUDIO_EXTENSIONS\n            ):\n                return False\n            return True\n\n    # Raw bytes with a known image magic header\n    if isinstance(value, (bytes, bytearray)):\n        return _has_image_header(value)\n\n    # String that looks like an image file path or URL\n    _IMAGE_EXTS = (\".png\", \".jpg\", \".jpeg\", \".webp\", \".gif\", \".bmp\", \".tiff\", \".svg\")\n    if isinstance(value, str) and len(value) < 1000:\n        lower = value.strip().lower()\n        # Image URL (http://... 
ending in image extension)\n        if lower.startswith((\"http://\", \"https://\")) and any(\n            lower.split(\"?\")[0].endswith(ext) for ext in _IMAGE_EXTS\n        ):\n            return True\n        # Image file path (relative or absolute path ending in image extension)\n        if any(lower.endswith(ext) for ext in _IMAGE_EXTS):\n            return True\n\n    return False\n\n\n_AUDIO_EXTENSIONS = (\n    \".wav\",\n    \".mp3\",\n    \".flac\",\n    \".ogg\",\n    \".opus\",\n    \".m4a\",\n    \".aac\",\n    \".wma\",\n    \".webm\",\n)\n\n\ndef _is_audio_value(value) -> bool:\n    \"\"\"Check if a single sample value looks like audio data.\"\"\"\n    if value is None:\n        return False\n\n    # HF datasets Audio feature: decoded → {\"array\": np.ndarray, \"sampling_rate\": int}\n    if isinstance(value, dict):\n        if \"array\" in value and \"sampling_rate\" in value:\n            return True\n        # Undecoded/streaming → {\"bytes\": b\"...\", \"path\": \"some.wav\"}\n        if \"bytes\" in value or \"path\" in value:\n            path = value.get(\"path\") or \"\"\n            if isinstance(path, str) and any(\n                path.lower().endswith(ext) for ext in _AUDIO_EXTENSIONS\n            ):\n                return True\n\n    return False\n\n\ndef _has_image_header(data: bytes) -> bool:\n    \"\"\"Quick magic-byte check for common image formats.\"\"\"\n    if len(data) < 4:\n        return False\n    # JPEG\n    if data[:2] == b\"\\xff\\xd8\":\n        return True\n    # PNG\n    if data[:4] == b\"\\x89PNG\":\n        return True\n    # GIF\n    if data[:3] == b\"GIF\":\n        return True\n    # WebP\n    if data[:4] == b\"RIFF\" and len(data) >= 12 and data[8:12] == b\"WEBP\":\n        return True\n    # BMP\n    if data[:2] == b\"BM\":\n        return True\n    return False\n\n\ndef detect_vlm_dataset_structure(dataset):\n    \"\"\"\n    Detects if VLM dataset is:\n    - Standard VLM messages format (image objects in content)\n    - Llava format (image indices + separate images column)\n    - Simple format needing conversion (image + text columns)\n    \"\"\"\n    try:\n        sample = next(iter(dataset))\n    except StopIteration:\n        return {\n            \"format\": \"unknown\",\n            \"needs_conversion\": None,\n            \"image_column\": None,\n            \"text_column\": None,\n            \"messages_column\": None,\n        }\n\n    column_names = set(sample.keys())\n\n    # Check if has messages column\n    if \"messages\" in column_names:\n        messages = sample[\"messages\"]\n\n        if messages and len(messages) > 0:\n            first_msg = messages[0]\n            if \"content\" in first_msg:\n                content = first_msg[\"content\"]\n\n                if isinstance(content, list) and len(content) > 0:\n                    if isinstance(content[0], dict) and \"type\" in content[0]:\n                        # Check for llava format\n                        has_index = any(\n                            \"index\" in item\n                            for item in content\n                            if isinstance(item, dict)\n                        )\n                        has_images_column = \"images\" in column_names\n\n                        if has_index and has_images_column:\n                            return {\n                                \"format\": \"vlm_messages_llava\",\n                                \"needs_conversion\": True,\n                                \"messages_column\": 
\"messages\",\n                                \"image_column\": \"images\",\n                                \"text_column\": None,\n                            }\n\n                        # Standard VLM format\n                        has_image = any(\n                            \"image\" in item\n                            for item in content\n                            if isinstance(item, dict)\n                        )\n                        if has_image:\n                            return {\n                                \"format\": \"vlm_messages\",\n                                \"needs_conversion\": False,\n                                \"messages_column\": \"messages\",\n                                \"image_column\": None,\n                                \"text_column\": None,\n                            }\n\n    # Check for ShareGPT/ChatML conversations with <image> placeholder + companion image column\n    # (e.g. Lin-Chen/ShareGPT4V, LLaVA-style datasets)\n    for chat_col in (\"conversations\", \"messages\"):\n        if chat_col not in column_names:\n            continue\n        chat_data = sample[chat_col]\n        if not isinstance(chat_data, list) or len(chat_data) == 0:\n            continue\n        first_msg = chat_data[0]\n        if not isinstance(first_msg, dict):\n            continue\n        # Detect ShareGPT (from/value) or ChatML (role/content) keys\n        msg_text = first_msg.get(\"value\") or first_msg.get(\"content\")\n        if not isinstance(msg_text, str):\n            continue\n        # Check for <image> placeholder anywhere in the conversation\n        has_image_placeholder = any(\n            \"<image>\" in str(m.get(\"value\", \"\") or m.get(\"content\", \"\"))\n            for m in chat_data\n            if isinstance(m, dict)\n        )\n        if not has_image_placeholder:\n            continue\n        # Find companion image column\n        image_col = None\n        for col in column_names:\n            if col == chat_col:\n                continue\n            if _keyword_in_column(\"image\", col) or _keyword_in_column(\"img\", col):\n                image_col = col\n                break\n        if image_col:\n            return {\n                \"format\": \"sharegpt_with_images\",\n                \"needs_conversion\": True,\n                \"image_column\": image_col,\n                \"text_column\": None,\n                \"messages_column\": chat_col,\n            }\n\n    # Find image and text columns using metadata filtering\n\n    # Define metadata patterns to EXCLUDE\n    metadata_patterns = {\n        \"suffixes\": [\n            \"_id\",\n            \"_url\",\n            \"_name\",\n            \"_filename\",\n            \"_uri\",\n            \"_link\",\n            \"_key\",\n            \"_index\",\n        ],\n        \"prefixes\": [\n            \"id_\",\n            \"url_\",\n            \"name_\",\n            \"filename_\",\n            \"uri_\",\n            \"link_\",\n            \"key_\",\n            \"index_\",\n        ],\n    }\n\n    # Image-related keywords\n    image_keywords = [\n        \"image\",\n        \"img\",\n        \"photo\",\n        \"picture\",\n        \"pic\",\n        \"visual\",\n        \"scan\",\n        \"file_name\",\n        \"filename\",\n    ]\n\n    # Text-related keywords\n    text_keywords = [\n        \"text\",\n        \"caption\",\n        \"captions\",\n        \"description\",\n        \"answer\",\n        \"output\",\n        \"response\",\n        
\"label\",\n    ]\n\n    def is_metadata_column(col_name):\n        \"\"\"Check if column name looks like metadata.\"\"\"\n        col_lower = col_name.lower()\n\n        # Check suffixes\n        if any(col_lower.endswith(suffix) for suffix in metadata_patterns[\"suffixes\"]):\n            return True\n\n        # Check prefixes\n        if any(\n            col_lower.startswith(prefix) for prefix in metadata_patterns[\"prefixes\"]\n        ):\n            return True\n\n        return False\n\n    def _score_image_candidate(col, sample_value):\n        \"\"\"Score a candidate image column by how resolvable its value is.\"\"\"\n        # PIL Image object (highest priority - already loaded)\n        if hasattr(sample_value, \"size\") and hasattr(sample_value, \"mode\"):\n            return 100\n\n        # Dict with image data (bytes/path from HF Image feature)\n        if isinstance(sample_value, dict) and (\n            \"bytes\" in sample_value or \"path\" in sample_value\n        ):\n            return 75\n\n        if isinstance(sample_value, str):\n            # URL strings\n            if sample_value.startswith((\"http://\", \"https://\")):\n                return 70 if not is_metadata_column(col) else 55\n            # Bare file path\n            if is_metadata_column(col):\n                return 30\n            return 50\n\n        return 0\n\n    def _probe_image_candidate(col, sample_value):\n        \"\"\"Quick probe to check if an image candidate is actually reachable.\n        Returns True if likely valid, False if definitely broken.\"\"\"\n        import os\n\n        # PIL / dict — already loaded, always valid\n        if not isinstance(sample_value, str):\n            return True\n\n        # Local file — check it exists\n        if not sample_value.startswith((\"http://\", \"https://\")):\n            return os.path.exists(\n                sample_value\n            )  # bare filenames return False here, that's OK\n\n        # URL — quick HEAD request with short timeout\n        try:\n            import urllib.request\n\n            req = urllib.request.Request(sample_value, method = \"HEAD\")\n            resp = urllib.request.urlopen(req, timeout = 3)\n            return resp.status < 400\n        except Exception:\n            return False\n\n    def find_image_column():\n        \"\"\"Find image column by keyword match + value-based fallback.\n        When multiple candidates exist, probes them to find one that works.\"\"\"\n        candidates = []\n\n        # Pass 1: keyword-matched columns\n        for col in column_names:\n            if any(_keyword_in_column(keyword, col) for keyword in image_keywords):\n                sample_value = sample[col]\n                score = _score_image_candidate(col, sample_value)\n                if score > 0:\n                    candidates.append((col, score))\n\n        # Pass 2: value-based fallback — find columns with image URLs/paths\n        # even if the column name doesn't match image keywords\n        already = {c[0] for c in candidates}\n        for col in column_names:\n            if col in already:\n                continue\n            sample_value = sample[col]\n            if _is_image_value(sample_value):\n                score = _score_image_candidate(col, sample_value)\n                # Slightly penalise non-keyword columns so keyword matches win on ties\n                candidates.append((col, max(score - 5, 1)))\n\n        if not candidates:\n            return None\n\n        candidates.sort(key = lambda 
x: x[1], reverse = True)\n\n        # Single candidate or top candidate is PIL/dict — no probing needed\n        if len(candidates) == 1 or candidates[0][1] >= 75:\n            return candidates[0][0]\n\n        # Multiple string-based candidates — probe to find one that actually works\n        for col, score in candidates:\n            sample_value = sample[col]\n            if _probe_image_candidate(col, sample_value):\n                return col\n\n        # Nothing probed successfully — return highest-scored anyway and let\n        # conversion handle the error (it may still resolve via hf_hub_download)\n        return candidates[0][0]\n\n    def find_text_column():\n        \"\"\"Find text column by filtering out metadata and checking keywords.\"\"\"\n        candidates = []\n\n        for col in column_names:\n            # Skip metadata columns\n            if is_metadata_column(col):\n                continue\n\n            # Check if contains text keywords (word-boundary match)\n            if any(_keyword_in_column(keyword, col) for keyword in text_keywords):\n                # Verify it's actually text\n                sample_value = sample[col]\n\n                if isinstance(sample_value, str) and len(sample_value) > 0:\n                    # Longer text = higher priority (likely content, not just a label)\n                    priority = min(len(sample_value), 1000)  # Cap at 1000\n                    candidates.append((col, priority))\n                elif (\n                    isinstance(sample_value, list)\n                    and len(sample_value) > 0\n                    and isinstance(sample_value[0], str)\n                ):\n                    # List of strings (e.g. captions list) — lower priority than plain strings\n                    priority = min(len(sample_value[0]), 1000) // 2\n                    candidates.append((col, priority))\n\n        # Return highest priority candidate\n        if candidates:\n            candidates.sort(key = lambda x: x[1], reverse = True)\n            return candidates[0][0]\n\n        return None\n\n    found_image = find_image_column()\n    found_text = find_text_column()\n\n    if found_image and found_text:\n        return {\n            \"format\": \"simple_image_text\",\n            \"needs_conversion\": True,\n            \"image_column\": found_image,\n            \"text_column\": found_text,\n            \"messages_column\": None,\n        }\n\n    return {\n        \"format\": \"unknown\",\n        \"needs_conversion\": None,\n        \"image_column\": found_image,\n        \"text_column\": found_text,\n        \"messages_column\": None,\n    }\n"
  },
  {
    "path": "studio/backend/utils/datasets/llm_assist.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nLLM-assisted dataset analysis using an ephemeral GGUF helper model.\n\nComplements heuristic-based detection in format_detection.py and\nvlm_processing.py.  Only invoked when heuristics are uncertain.\n\nArchitecture:\n  - Instantiates LlamaCppBackend, loads model, runs completion(s), unloads.\n  - Not kept warm — VRAM is freed immediately after use.\n  - Gracefully degrades: returns None when unavailable (no binary, OOM, disabled).\n\"\"\"\n\nimport json\nimport logging\nimport os\nimport re\nimport textwrap\nimport time\nfrom itertools import islice\nfrom typing import Any, Optional\n\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\nDEFAULT_HELPER_MODEL_REPO = \"unsloth/Qwen3.5-4B-GGUF\"\nDEFAULT_HELPER_MODEL_VARIANT = \"UD-Q4_K_XL\"\n\nREADME_MAX_CHARS = 1500\n\n\ndef _strip_think_tags(text: str) -> str:\n    \"\"\"Strip <think>...</think> reasoning blocks emitted by some models.\n\n    If the model places its actual answer OUTSIDE the think block, we\n    discard the think block and keep the rest.  If the entire response\n    is INSIDE a think block (nothing useful outside), we extract and\n    return the inner content instead of discarding everything.\n    \"\"\"\n    if \"<think>\" not in text:\n        return text\n\n    # Try stripping think blocks — keep content outside them\n    stripped = re.sub(r\"<think>.*?</think>\\s*\", \"\", text, flags = re.DOTALL).strip()\n    if stripped:\n        return stripped\n\n    # Everything was inside <think> tags — extract the inner content of the last block\n    matches = re.findall(r\"<think>(.*?)</think>\", text, flags = re.DOTALL)\n    if matches:\n        return matches[-1].strip()\n\n    return text\n\n\ndef precache_helper_gguf():\n    \"\"\"\n    Pre-download the helper GGUF to HF cache.\n\n    Called on FastAPI startup in a background thread so subsequent\n    ``_run_with_helper()`` calls skip the download and only pay for\n    llama-server startup.  
No-op if already cached or disabled.\n    \"\"\"\n    if os.environ.get(\"UNSLOTH_HELPER_MODEL_DISABLE\", \"\").strip() in (\"1\", \"true\"):\n        return\n\n    repo = os.environ.get(\"UNSLOTH_HELPER_MODEL_REPO\", DEFAULT_HELPER_MODEL_REPO)\n    variant = os.environ.get(\n        \"UNSLOTH_HELPER_MODEL_VARIANT\", DEFAULT_HELPER_MODEL_VARIANT\n    )\n\n    try:\n        from huggingface_hub import HfApi, hf_hub_download\n        from huggingface_hub.utils import disable_progress_bars, enable_progress_bars\n\n        disable_progress_bars()\n        logging.getLogger(\"huggingface_hub\").setLevel(logging.WARNING)\n\n        # Find the GGUF file matching the variant\n        api = HfApi()\n        files = api.list_repo_files(repo, repo_type = \"model\")\n        gguf_files = [f for f in files if f.endswith(\".gguf\")]\n\n        # Find all GGUF files matching the variant (may be split into shards)\n        variant_lower = variant.lower().replace(\"-\", \"_\")\n        matching = sorted(\n            f for f in gguf_files if variant_lower in f.lower().replace(\"-\", \"_\")\n        )\n\n        if matching:\n            logger.info(\n                f\"Pre-caching helper GGUF: {repo}/{matching[0]}\"\n                + (f\" (+{len(matching) - 1} shards)\" if len(matching) > 1 else \"\")\n            )\n            for target in matching:\n                hf_hub_download(repo_id = repo, filename = target)\n            logger.info(f\"Helper GGUF cached: {len(matching)} file(s)\")\n        else:\n            logger.warning(f\"No GGUF matching variant '{variant}' in {repo}\")\n    except Exception as e:\n        logger.warning(f\"Failed to pre-cache helper GGUF: {e}\")\n    finally:\n        try:\n            enable_progress_bars()\n        except Exception as e:\n            pass\n\n\ndef _run_with_helper(prompt: str, max_tokens: int = 256) -> Optional[str]:\n    \"\"\"\n    Load helper model, run one chat completion, unload.\n\n    Returns the completion text, or None on any failure.\n    \"\"\"\n    if os.environ.get(\"UNSLOTH_HELPER_MODEL_DISABLE\", \"\").strip() in (\"1\", \"true\"):\n        return None\n\n    repo = os.environ.get(\"UNSLOTH_HELPER_MODEL_REPO\", DEFAULT_HELPER_MODEL_REPO)\n    variant = os.environ.get(\n        \"UNSLOTH_HELPER_MODEL_VARIANT\", DEFAULT_HELPER_MODEL_VARIANT\n    )\n\n    backend = None\n    try:\n        from core.inference.llama_cpp import LlamaCppBackend\n\n        backend = LlamaCppBackend()\n        logger.info(f\"Loading helper model: {repo} ({variant})\")\n\n        ok = backend.load_model(\n            hf_repo = repo,\n            hf_variant = variant,\n            model_identifier = f\"helper:{repo}:{variant}\",\n            is_vision = False,\n            n_ctx = 2048,\n            n_gpu_layers = -1,\n        )\n        if not ok:\n            logger.warning(\"Helper model failed to start\")\n            return None\n\n        messages = [{\"role\": \"user\", \"content\": prompt}]\n        logger.info(\n            \"Helper model request: enable_thinking=False (per-request override)\"\n        )\n        cumulative = \"\"\n        for text in backend.generate_chat_completion(\n            messages = messages,\n            temperature = 0.1,\n            top_p = 0.9,\n            top_k = 20,\n            max_tokens = max_tokens,\n            repetition_penalty = 1.0,\n            enable_thinking = False,  # Always disable thinking for AI Assist\n        ):\n            cumulative = text  # cumulative — last value is full text\n\n        result = 
cumulative.strip()\n        result = _strip_think_tags(result)\n        logger.info(f\"Helper model response ({len(result)} chars)\")\n        return result if result else None\n\n    except Exception as e:\n        logger.warning(f\"Helper model failed: {e}\")\n        return None\n\n    finally:\n        if backend is not None:\n            try:\n                backend.unload_model()\n                logger.info(\"Helper model unloaded\")\n            except Exception:\n                pass\n\n\n# ─── Public API ───────────────────────────────────────────────────────\n\n\ndef llm_generate_vlm_instruction(\n    column_names: list[str],\n    samples: list[dict],\n    dataset_name: Optional[str] = None,\n) -> Optional[dict]:\n    \"\"\"\n    Ask a helper LLM to generate a task-specific VLM instruction.\n\n    Called when heuristic instruction generation returns low confidence\n    or falls back to generic.\n\n    Args:\n        column_names: Column names in the dataset.\n        samples: 3-5 sample rows with text values (images replaced by \"<image>\").\n        dataset_name: Optional HF dataset identifier for context.\n\n    Returns:\n        {\"instruction\": str, \"confidence\": 0.85} or None.\n    \"\"\"\n    # Format samples for the prompt\n    formatted = \"\"\n    for i, row in enumerate(samples[:5], 1):\n        parts = []\n        for col in column_names:\n            val = str(row.get(col, \"\"))[:300]\n            parts.append(f\"  {col}: {val}\")\n        formatted += f\"Sample {i}:\\n\" + \"\\n\".join(parts) + \"\\n\\n\"\n\n    prompt = (\n        \"You are a dataset analyst. Given a vision-language dataset, generate ONE \"\n        \"instruction sentence that describes what the model should do with each image.\\n\\n\"\n        f\"Dataset: {dataset_name or 'unknown'}\\n\"\n        f\"Columns: {column_names}\\n\\n\"\n        f\"{formatted}\"\n        \"Write ONE instruction sentence. 
Examples:\\n\"\n        '- \"Solve the math problem shown in the image and explain your reasoning.\"\\n'\n        '- \"Transcribe all text visible in this image.\"\\n'\n        '- \"Answer the question about this image.\"\\n\\n'\n        \"Respond with ONLY the instruction sentence, nothing else.\"\n    )\n\n    result = _run_with_helper(prompt, max_tokens = 100)\n    if not result:\n        return None\n\n    # Clean up: strip quotes, ensure it's a single sentence\n    instruction = result.strip().strip('\"').strip(\"'\").strip()\n    # Reject obviously bad outputs (too short, too long, or multi-line)\n    if len(instruction) < 10 or len(instruction) > 200 or \"\\n\" in instruction:\n        logger.warning(f\"Helper model returned unusable instruction: {instruction!r}\")\n        return None\n\n    logger.info(f\"LLM-generated instruction: {instruction}\")\n    return {\n        \"instruction\": instruction,\n        \"confidence\": 0.85,\n    }\n\n\ndef llm_classify_columns(\n    column_names: list[str],\n    samples: list[dict],\n) -> Optional[dict[str, str]]:\n    \"\"\"\n    Ask a helper LLM to classify dataset columns into roles.\n\n    Called when heuristic column detection fails (returns None).\n\n    Args:\n        column_names: Column names in the dataset.\n        samples: 3-5 sample rows with values truncated to 200 chars.\n\n    Returns:\n        Dict mapping column_name → role (\"user\"|\"assistant\"|\"system\"|\"metadata\"),\n        or None on failure.\n    \"\"\"\n    formatted = \"\"\n    for i, row in enumerate(samples[:5], 1):\n        parts = []\n        for col in column_names:\n            val = str(row.get(col, \"\"))[:200]\n            parts.append(f\"  {col}: {val}\")\n        formatted += f\"Sample {i}:\\n\" + \"\\n\".join(parts) + \"\\n\\n\"\n\n    prompt = (\n        \"Classify each column in this dataset into one of these roles:\\n\"\n        \"- user: The input/question/prompt from the human\\n\"\n        \"- assistant: The expected output/answer/response from the AI\\n\"\n        \"- system: Context, persona, or task description\\n\"\n        \"- metadata: IDs, scores, labels, timestamps — not part of conversation\\n\\n\"\n        f\"Columns: {column_names}\\n\\n\"\n        f\"{formatted}\"\n        \"Respond with ONLY a JSON object mapping column names to roles.\\n\"\n        'Example: {\"question\": \"user\", \"answer\": \"assistant\", \"id\": \"metadata\"}'\n    )\n\n    result = _run_with_helper(prompt, max_tokens = 200)\n    if not result:\n        return None\n\n    # Parse JSON from response (may have markdown fences)\n    text = result.strip()\n    if text.startswith(\"```\"):\n        # Strip markdown code fence\n        lines = text.split(\"\\n\")\n        text = \"\\n\".join(lines[1:-1] if lines[-1].strip() == \"```\" else lines[1:])\n        text = text.strip()\n\n    try:\n        mapping = json.loads(text)\n    except json.JSONDecodeError:\n        # Try to find JSON object in the response\n        import re\n\n        match = re.search(r\"\\{[^}]+\\}\", text)\n        if match:\n            try:\n                mapping = json.loads(match.group())\n            except json.JSONDecodeError:\n                logger.warning(f\"Could not parse helper model JSON: {text!r}\")\n                return None\n        else:\n            logger.warning(f\"No JSON found in helper model response: {text!r}\")\n            return None\n\n    if not isinstance(mapping, dict):\n        return None\n\n    # Validate: all values must be valid roles\n    
valid_roles = {\"user\", \"assistant\", \"system\", \"metadata\"}\n    cleaned = {}\n    for col, role in mapping.items():\n        if (\n            col in column_names\n            and isinstance(role, str)\n            and role.lower() in valid_roles\n        ):\n            cleaned[col] = role.lower()\n\n    if not cleaned:\n        return None\n\n    # Must have at least user + assistant\n    roles_present = set(cleaned.values())\n    if \"user\" not in roles_present or \"assistant\" not in roles_present:\n        logger.warning(f\"Helper model mapping missing user/assistant: {cleaned}\")\n        return None\n\n    logger.info(f\"LLM-classified columns: {cleaned}\")\n    return cleaned\n\n\ndef llm_generate_dataset_warning(\n    issues: list[str],\n    dataset_name: Optional[str] = None,\n    modality: str = \"text\",\n    column_names: Optional[list[str]] = None,\n) -> Optional[str]:\n    \"\"\"\n    Ask the helper LLM to turn technical dataset issues into a user-friendly warning.\n\n    Works for all modalities (text, vision, audio).\n\n    Args:\n        issues: List of technical issue descriptions found during analysis.\n        dataset_name: Optional HF dataset name.\n        modality: \"text\", \"vision\", or \"audio\".\n        column_names: Optional list of column names for context.\n\n    Returns:\n        A human-friendly warning string, or None on failure.\n    \"\"\"\n    if not issues:\n        return None\n\n    issues_text = \"\\n\".join(f\"- {issue}\" for issue in issues)\n    cols_text = f\"\\nColumns: {column_names}\" if column_names else \"\"\n\n    prompt = (\n        \"You are a helpful assistant. A user is trying to fine-tune a model on a dataset.\\n\"\n        \"The following issues were found during dataset analysis:\\n\\n\"\n        f\"{issues_text}\\n\\n\"\n        f\"Dataset: {dataset_name or 'unknown'}\\n\"\n        f\"Modality: {modality}\"\n        f\"{cols_text}\\n\\n\"\n        \"Write a brief, friendly explanation of what's wrong and what the user can do about it.\\n\"\n        \"Keep it under 3 sentences. 
Be specific about the dataset.\"\n    )\n\n    result = _run_with_helper(prompt, max_tokens = 200)\n    if not result:\n        return None\n\n    warning = result.strip()\n    # Reject obviously bad outputs\n    if len(warning) < 10 or len(warning) > 500:\n        return None\n\n    logger.info(f\"LLM-generated warning: {warning}\")\n    return warning\n\n\n# ─── Dataset Conversion Advisor ──────────────────────────────────────\n\n\ndef _parse_json_response(text: str) -> Optional[dict]:\n    \"\"\"Parse JSON from LLM response, handling markdown fences and noise.\"\"\"\n    if not text:\n        return None\n\n    cleaned = text.strip()\n\n    # Strip markdown code fences\n    if cleaned.startswith(\"```\"):\n        lines = cleaned.split(\"\\n\")\n        end = -1 if lines[-1].strip().startswith(\"```\") else len(lines)\n        cleaned = \"\\n\".join(lines[1:end]).strip()\n\n    # Try direct parse\n    try:\n        obj = json.loads(cleaned)\n        if isinstance(obj, dict):\n            return obj\n    except json.JSONDecodeError:\n        pass\n\n    # Greedy match for outermost {...}\n    match = re.search(r\"\\{.*\\}\", cleaned, re.DOTALL)\n    if match:\n        try:\n            obj = json.loads(match.group())\n            if isinstance(obj, dict):\n                return obj\n        except json.JSONDecodeError:\n            pass\n\n    return None\n\n\ndef _generate_with_backend(backend, messages: list[dict], max_tokens: int = 512) -> str:\n    \"\"\"Run one chat completion on an already-loaded backend. Returns raw text.\"\"\"\n    logger.info(\"Advisor request: enable_thinking=False (per-request override)\")\n    cumulative = \"\"\n    for text in backend.generate_chat_completion(\n        messages = messages,\n        temperature = 0.1,\n        top_p = 0.9,\n        top_k = 20,\n        max_tokens = max_tokens,\n        repetition_penalty = 1.0,\n        enable_thinking = False,  # Always disable thinking for AI Assist\n    ):\n        cumulative = text\n    result = cumulative.strip()\n    result = _strip_think_tags(result)\n    return result\n\n\ndef fetch_hf_dataset_card(\n    dataset_name: str, hf_token: Optional[str] = None\n) -> tuple[Optional[str], Optional[dict]]:\n    \"\"\"\n    Fetch HF dataset card (README) and metadata.\n\n    Returns:\n        (readme_text, metadata_dict) or (None, None) on failure.\n    \"\"\"\n    try:\n        from huggingface_hub import DatasetCard\n\n        card = DatasetCard.load(dataset_name, token = hf_token)\n        readme = card.text or \"\"\n\n        # Truncate at sentence boundary\n        if len(readme) > README_MAX_CHARS:\n            cut = readme[:README_MAX_CHARS].rfind(\".\")\n            if cut > README_MAX_CHARS // 2:\n                readme = readme[: cut + 1] + \"\\n[...truncated]\"\n            else:\n                readme = readme[:README_MAX_CHARS] + \"\\n[...truncated]\"\n\n        # Extract metadata from YAML frontmatter\n        metadata = {}\n        if card.data:\n            for key in (\n                \"task_categories\",\n                \"task_ids\",\n                \"language\",\n                \"size_categories\",\n                \"tags\",\n                \"license\",\n                \"pretty_name\",\n            ):\n                val = getattr(card.data, key, None)\n                if val is not None:\n                    metadata[key] = val\n\n        logger.info(\n            f\"Fetched dataset card: {len(readme)} chars, {len(metadata)} metadata fields\"\n        )\n        return readme, 
metadata\n\n    except Exception as e:\n        logger.warning(f\"Could not fetch dataset card for {dataset_name}: {e}\")\n        return None, None\n\n\ndef _run_multi_pass_advisor(\n    columns: list[str],\n    samples: list[dict],\n    dataset_name: Optional[str] = None,\n    dataset_card: Optional[str] = None,\n    dataset_metadata: Optional[dict] = None,\n    model_name: Optional[str] = None,\n    model_type: Optional[str] = None,\n    hf_token: Optional[str] = None,\n) -> Optional[dict[str, Any]]:\n    \"\"\"\n    Multi-pass LLM analysis: classify → convert → validate.\n\n    Keeps model loaded across all passes. Returns combined result dict or None.\n    \"\"\"\n    if os.environ.get(\"UNSLOTH_HELPER_MODEL_DISABLE\", \"\").strip() in (\"1\", \"true\"):\n        return None\n\n    repo = os.environ.get(\"UNSLOTH_HELPER_MODEL_REPO\", DEFAULT_HELPER_MODEL_REPO)\n    variant = os.environ.get(\n        \"UNSLOTH_HELPER_MODEL_VARIANT\", DEFAULT_HELPER_MODEL_VARIANT\n    )\n\n    backend = None\n    try:\n        from core.inference.llama_cpp import LlamaCppBackend\n\n        backend = LlamaCppBackend()\n        logger.info(f\"Loading advisor model: {repo} ({variant})\")\n        t0 = time.monotonic()\n\n        ok = backend.load_model(\n            hf_repo = repo,\n            hf_variant = variant,\n            model_identifier = f\"advisor:{repo}:{variant}\",\n            is_vision = False,\n            n_ctx = 2048,\n            n_gpu_layers = -1,\n        )\n        if not ok:\n            logger.warning(\"Advisor model failed to start\")\n            return None\n\n        logger.info(f\"Advisor model loaded in {time.monotonic() - t0:.1f}s\")\n        # ── Format samples ──\n        samples_text = \"\"\n        for i, row in enumerate(samples[:5], 1):\n            parts = [f\"  {col}: {str(row.get(col, ''))[:200]}\" for col in columns]\n            samples_text += f\"Row {i}:\\n\" + \"\\n\".join(parts) + \"\\n\"\n\n        metadata_str = (\n            json.dumps(dataset_metadata, indent = 2, default = str)[:500]\n            if dataset_metadata\n            else \"N/A\"\n        )\n        card_excerpt = (dataset_card or \"\")[:1200] or \"N/A\"\n\n        # ── Target Model Hints ──\n        target_hints = \"\"\n        is_gemma_3n = False\n        if model_name:\n            try:\n                from utils.models.model_config import load_model_config\n\n                config = load_model_config(\n                    model_name,\n                    use_auth = True,\n                    token = hf_token,\n                    trust_remote_code = False,\n                )\n                archs = getattr(config, \"architectures\", [])\n                if archs and \"Gemma3nForConditionalGeneration\" in archs:\n                    is_gemma_3n = True\n            except Exception:\n                is_gemma_3n = \"gemma-3n\" in model_name.lower()\n\n        if model_type == \"audio\" and not is_gemma_3n:\n            target_hints = (\n                \"\\n\\nHINT: The user is training an AUDIO model. The dataset MUST contain \"\n                \"a column with audio files/paths. Ensure one such column is selected \"\n                \"as part of the input.\"\n            )\n        elif model_type == \"embeddings\":\n            target_hints = (\n                \"\\n\\nHINT: The user is training an EMBEDDING model. 
These models typically \"\n                \"do not use standard conversational input/output formats but instead use \"\n                \"specific formats like:\\n\"\n                \"- Pairs of texts for Semantic Textual Similarity (STS)\\n\"\n                \"- Premise, hypothesis, and label for Natural Language Inference (NLI)\\n\"\n                \"- Queries and positive/negative documents for information retrieval\\n\"\n                \"Ensure the dataset format mapped reflects these specialized tasks.\"\n            )\n\n        # ── Pass 1: Classify ──\n        logger.info(\"Pass 1: Classifying dataset...\")\n        t1 = time.monotonic()\n        messages1 = [\n            {\n                \"role\": \"system\",\n                \"content\": (\n                    \"You are a dataset analyst. Your job is to look at a HuggingFace dataset \"\n                    \"and figure out what kind of data it contains and whether it is already in \"\n                    \"a conversational format suitable for LLM fine-tuning. A dataset is \"\n                    '\"conversational\" if it already has columns like \"messages\", \"conversations\", '\n                    'or multiturn \"user\"/\"assistant\" pairs. Some datasets are NOT conversational '\n                    \"— they are things like summarization, question answering, translation, \"\n                    \"classification, etc. Those need conversion. You must respond with ONLY a \"\n                    \"valid JSON object. Do not write any explanation before or after the JSON.\"\n                    f\"{target_hints}\"\n                ),\n            },\n            {\n                \"role\": \"user\",\n                \"content\": textwrap.dedent(f\"\"\"\\\n                    Look at this HuggingFace dataset and classify it.\n\n                    DATASET CARD (excerpt):\n                    {card_excerpt}\n\n                    METADATA:\n                    {metadata_str}\n\n                    COLUMNS: {columns}\n\n                    SAMPLE DATA (first 3 rows):\n                    {samples_text}\n\n                    Based on the above, respond with this exact JSON structure:\n                    {{\n                        \"dataset_type\": \"<one of: summarization, question_answering, translation, classification, natural_language_inference, instruction_following, conversational, code_generation, other>\",\n                        \"is_conversational\": <true if the dataset already has message/conversation columns, false otherwise>,\n                        \"needs_conversion\": <true if it needs to be converted into user/assistant turns, false if it is already conversational>,\n                        \"description\": \"<one sentence describing what this dataset contains>\",\n                        \"task_description\": \"<one sentence describing the task: what input goes in and what output comes out>\"\n                    }}\n\n                    Respond with ONLY the JSON object. 
No markdown, no explanation.\"\"\"),\n            },\n        ]\n        raw1 = _generate_with_backend(backend, messages1, max_tokens = 256)\n        pass1 = _parse_json_response(raw1)\n        logger.info(f\"Pass 1 done ({time.monotonic() - t1:.1f}s): {pass1}\")\n\n        if not pass1:\n            logger.warning(f\"Advisor Pass 1 failed to produce JSON: {raw1[:200]}\")\n            return None\n\n        # If dataset is already conversational, skip passes 2-3\n        if pass1.get(\"is_conversational\") and not pass1.get(\"needs_conversion\"):\n            return {\n                \"success\": True,\n                \"dataset_type\": pass1.get(\"dataset_type\"),\n                \"is_conversational\": True,\n                \"user_notification\": (\n                    \"This dataset is already in conversational format. \"\n                    \"No conversion needed — columns can be mapped directly.\"\n                ),\n            }\n\n        # ── Pass 2: Map columns to roles ──\n        logger.info(\"Pass 2: Mapping columns to roles...\")\n\n        t2 = time.monotonic()\n        messages2 = [\n            {\n                \"role\": \"system\",\n                \"content\": (\n                    \"You are a data preparation assistant. Your job is to assign each column \"\n                    \"in a dataset to a conversation role for LLM fine-tuning. There are exactly \"\n                    \"two roles:\\n\"\n                    '- \"user\" = This column contains INPUT that the model will receive as a prompt.\\n'\n                    '- \"assistant\" = This column contains OUTPUT that the model should learn to generate.\\n\\n'\n                    \"CRITICAL RULES:\\n\"\n                    '1. There MUST be at least one column assigned to \"user\" AND at least one '\n                    'column assigned to \"assistant\". Never assign all columns to the same role.\\n'\n                    \"2. The column that contains the TARGET or OUTPUT or ANSWER or LABEL must \"\n                    'ALWAYS be assigned to \"assistant\". This is the thing the model should learn '\n                    \"to produce.\\n\"\n                    \"3. The columns that contain the SOURCE or INPUT or CONTEXT or QUESTION must \"\n                    'be assigned to \"user\". This is what the model receives.\\n'\n                    '4. 
Metadata columns like \"id\", \"index\", \"source\", \"url\", \"date\" should be '\n                    'set to \"skip\".\\n\\n'\n                    \"You must respond with ONLY a valid JSON object.\"\n                    f\"{target_hints}\"\n                ),\n            },\n            {\n                \"role\": \"user\",\n                \"content\": textwrap.dedent(f\"\"\"\\\n                    Here is a dataset that has been classified:\n\n                    CLASSIFICATION:\n                    {json.dumps(pass1, indent = 2)}\n\n                    COLUMNS AVAILABLE: {columns}\n\n                    SAMPLE DATA (first 3 rows):\n                    {samples_text}\n\n                    Your task: assign each column to either \"user\", \"assistant\", or \"skip\".\n\n                    Here are worked examples to guide you:\n\n                    Example 1 — Summarization dataset with columns [\"document\", \"summary\"]:\n                      \"document\" is the input text → \"user\"\n                      \"summary\" is the output the model should generate → \"assistant\"\n                      Result: {{\"document\": \"user\", \"summary\": \"assistant\"}}\n\n                    Example 2 — Question answering dataset with columns [\"context\", \"question\", \"answer\"]:\n                      \"context\" is input → \"user\"\n                      \"question\" is input → \"user\"\n                      \"answer\" is what the model should generate → \"assistant\"\n                      Result: {{\"context\": \"user\", \"question\": \"user\", \"answer\": \"assistant\"}}\n\n                    Example 3 — Classification dataset with columns [\"text\", \"label\"]:\n                      \"text\" is input → \"user\"\n                      \"label\" is the output the model should predict → \"assistant\"\n                      Result: {{\"text\": \"user\", \"label\": \"assistant\"}}\n\n                    Example 4 — Translation dataset with columns [\"en\", \"fr\"]:\n                      \"en\" is the source language (input) → \"user\"\n                      \"fr\" is the target language (output) → \"assistant\"\n                      Result: {{\"en\": \"user\", \"fr\": \"assistant\"}}\n\n                    Now apply this logic to the actual dataset columns listed above.\n\n                    Respond with this exact JSON structure:\n                    {{\n                        \"column_roles\": {{\n                            \"<column_name>\": \"<user|assistant|skip>\"\n                        }},\n                        \"label_mapping\": <if any column contains integer labels (like 0, 1, 2), provide a mapping like {{\"label\": {{\"0\": \"entailment\", \"1\": \"neutral\", \"2\": \"contradiction\"}}}}, otherwise null>,\n                        \"notes\": \"<brief explanation of why you assigned roles this way>\"\n                    }}\n\n                    REMEMBER: There must be at least one \"user\" column AND at least one \"assistant\" column. 
If all columns are \"user\", you made a mistake — the output/target column should be \"assistant\".\n\n                    Respond with ONLY the JSON object.\"\"\"),\n            },\n        ]\n        raw2 = _generate_with_backend(backend, messages2, max_tokens = 512)\n        pass2 = _parse_json_response(raw2)\n        logger.info(f\"Pass 2 done ({time.monotonic() - t2:.1f}s): {pass2}\")\n\n        if not pass2:\n            logger.warning(f\"Advisor Pass 2 failed to produce JSON: {raw2[:200]}\")\n            return None\n\n        # ── Extract and validate column roles from Pass 2 ──\n        column_roles = pass2.get(\"column_roles\", {})\n        label_map = pass2.get(\"label_mapping\") or {}  # may be null\n\n        # Validate: must have at least one user AND one assistant\n        roles_present = set(column_roles.values())\n        if \"user\" not in roles_present or \"assistant\" not in roles_present:\n            logger.warning(\n                f\"Pass 2 sanity fail: missing user or assistant role: {column_roles}\"\n            )\n            return None  # triggers fallback to simple classification\n\n        # ── Pass 3: System prompt (non-conversational datasets only) ──\n        sys_prompt = \"\"\n        dtype = pass1.get(\"dataset_type\", \"unknown\")\n        is_conv = pass1.get(\"is_conversational\", False)\n\n        if not is_conv:\n            logger.info(\"Pass 3: Generating system prompt...\")\n            t3 = time.monotonic()\n\n            # Format label mapping info for the prompt\n            label_info = \"\"\n            if label_map:\n                for col, mapping in label_map.items():\n                    if isinstance(mapping, dict) and mapping:\n                        pairs = \", \".join(f\"{k} = {v}\" for k, v in mapping.items())\n                        label_info += f\"\\nLabel mapping for '{col}': {pairs}\"\n\n            # Describe the role assignments for context\n            user_cols = [c for c, r in column_roles.items() if r == \"user\"]\n            asst_cols = [c for c, r in column_roles.items() if r == \"assistant\"]\n            task_desc = pass1.get(\"task_description\") or pass1.get(\"description\", \"\")\n\n            messages3 = [\n                {\n                    \"role\": \"user\",\n                    \"content\": textwrap.dedent(f\"\"\"\\\n                        I am building a fine-tuning dataset for an LLM. I need you to write a \\\n                        system prompt that will be included in every training example to tell \\\n                        the model what task it is performing.\n\n                        Here is the task information:\n                        - Dataset type: {dtype}\n                        - Task description: {task_desc}\n                        - The USER (input) columns are: {user_cols}\n                        - The ASSISTANT (output) columns are: {asst_cols}\n                        {label_info}\n\n                        Write a system prompt that:\n                        1. Explains what task the model is performing in plain language\n                        2. Describes what input it will receive\n                        3. Describes what output it should produce\n                        4. Is 2-4 sentences long\n\n                        Write ONLY the system prompt text. 
No quotes, no labels, no explanation around it.\"\"\"),\n                },\n            ]\n            raw3 = _generate_with_backend(backend, messages3, max_tokens = 256)\n            logger.info(\n                f\"Pass 3 done ({time.monotonic() - t3:.1f}s): {raw3[:200] if raw3 else None}\"\n            )\n\n            if raw3:\n                # Pass 3 returns raw text, not JSON — clean it up\n                cleaned = raw3.strip().strip('\"').strip(\"'\").strip()\n                if len(cleaned) >= 20 and cleaned.lower() not in (\"null\", \"none\", \"\"):\n                    sys_prompt = cleaned\n\n        # Build suggested_mapping (column → role, for the frontend dropdowns)\n        suggested_mapping = {}\n        for col, role in column_roles.items():\n            if col in columns and role in (\"user\", \"assistant\", \"system\"):\n                suggested_mapping[col] = role\n\n        # Build user notification from Pass 1 classification\n        desc = pass1.get(\"task_description\") or pass1.get(\"description\", \"\")\n        note_parts = [f\"This is a {dtype} dataset (not conversational).\"]\n        if desc:\n            note_parts.append(desc)\n        note_parts.append(\n            \"Columns have been mapped to conversation roles. You can adjust the mapping if needed.\"\n        )\n        user_notification = \" \".join(note_parts)\n\n        total_time = time.monotonic() - t0\n        logger.info(\n            f\"Advisor complete ({total_time:.1f}s): type={dtype}, mapping={suggested_mapping}, sys_prompt={bool(sys_prompt)}, label_map={bool(label_map)}\"\n        )\n\n        return {\n            \"success\": True,\n            \"suggested_mapping\": suggested_mapping,\n            \"system_prompt\": sys_prompt,\n            \"label_mapping\": label_map if label_map else None,\n            \"dataset_type\": dtype,\n            \"is_conversational\": is_conv,\n            \"user_notification\": user_notification,\n        }\n\n    except Exception as e:\n        logger.warning(f\"Advisor multi-pass failed: {e}\")\n        return None\n\n    finally:\n        if backend is not None:\n            try:\n                backend.unload_model()\n                logger.info(\"Advisor model unloaded\")\n            except Exception:\n                pass\n\n\ndef llm_conversion_advisor(\n    column_names: list[str],\n    samples: list[dict],\n    dataset_name: Optional[str] = None,\n    hf_token: Optional[str] = None,\n    model_name: Optional[str] = None,\n    model_type: Optional[str] = None,\n) -> Optional[dict[str, Any]]:\n    \"\"\"\n    Full conversion advisor: fetch HF card → multi-pass LLM analysis.\n\n    Falls back to simple llm_classify_columns() if the multi-pass advisor fails.\n\n    Returns:\n        Dict with keys: success, suggested_mapping, system_prompt, user_template,\n        assistant_template, label_mapping, dataset_type, is_conversational,\n        user_notification. 
Or None on complete failure.\n    \"\"\"\n    # Fetch HF dataset card if this looks like a HF dataset (has a slash)\n    dataset_card = None\n    dataset_metadata = None\n    if dataset_name and \"/\" in dataset_name:\n        dataset_card, dataset_metadata = fetch_hf_dataset_card(dataset_name, hf_token)\n\n    # Try multi-pass advisor\n    result = _run_multi_pass_advisor(\n        columns = column_names,\n        samples = samples,\n        dataset_name = dataset_name,\n        dataset_card = dataset_card,\n        dataset_metadata = dataset_metadata,\n        model_name = model_name,\n        model_type = model_type,\n        hf_token = hf_token,\n    )\n\n    if result and result.get(\"success\"):\n        logger.info(f\"Conversion advisor succeeded: type={result.get('dataset_type')}\")\n        return result\n\n    # Fallback: simple column classification\n    logger.info(\"Advisor failed, falling back to simple column classification\")\n    simple_mapping = llm_classify_columns(column_names, samples)\n    if simple_mapping:\n        return {\n            \"success\": True,\n            \"suggested_mapping\": {\n                col: role\n                for col, role in simple_mapping.items()\n                if role in (\"user\", \"assistant\", \"system\")\n            },\n            \"dataset_type\": None,\n            \"is_conversational\": None,\n            \"user_notification\": None,\n        }\n\n    return None\n"
  },
  {
    "path": "studio/backend/utils/datasets/model_mappings.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nModel and template mappings for dataset processing.\n\nThis module contains the mapping dictionaries that associate model names\nwith their corresponding chat templates and response markers.\n\"\"\"\n\nTEMPLATE_TO_MODEL_MAPPER = {\n    \"phi-3.5\": (\n        \"unsloth/Phi-3.5-mini-instruct-bnb-4bit\",\n        \"unsloth/Phi-3.5-mini-instruct\",\n        \"microsoft/Phi-3.5-mini-instruct\",\n    ),\n    \"phi-3\": (\n        \"unsloth/Phi-3-mini-4k-instruct-bnb-4bit\",\n        \"unsloth/Phi-3-mini-4k-instruct\",\n        \"microsoft/Phi-3-mini-4k-instruct\",\n        \"unsloth/Phi-3-medium-4k-instruct-bnb-4bit\",\n        \"unsloth/Phi-3-medium-4k-instruct\",\n        \"microsoft/Phi-3-medium-4k-instruct\",\n        \"unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit\",\n        \"unsloth/Phi-3-mini-4k-instruct-v0\",\n    ),\n    \"phi-4\": (\n        \"unsloth/phi-4-unsloth-bnb-4bit\",\n        \"unsloth/phi-4\",\n        \"microsoft/phi-4\",\n        \"unsloth/phi-4-bnb-4bit\",\n        \"unsloth/phi-4-reasoning-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-reasoning\",\n        \"microsoft/Phi-4-reasoning\",\n        \"unsloth/phi-4-reasoning-bnb-4bit\",\n        \"unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-reasoning-plus\",\n        \"microsoft/Phi-4-reasoning-plus\",\n        \"unsloth/phi-4-reasoning-plus-bnb-4bit\",\n        \"unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-mini-reasoning\",\n        \"microsoft/Phi-4-mini-reasoning\",\n        \"unsloth/phi-4-mini-reasoning-bnb-4bit\",\n        \"unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit\",\n        \"unsloth/Phi-4-mini-instruct\",\n        \"microsoft/Phi-4-mini-instruct\",\n        \"unsloth/Phi-4-mini-instruct-bnb-4bit\",\n    ),\n    \"mistral\": (\n        \"unsloth/mistral-7b-instruct-v0.1-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.1\",\n        \"mistralai/Mistral-7B-Instruct-v0.1\",\n        \"unsloth/mistral-7b-instruct-v0.2-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.2\",\n        \"mistralai/Mistral-7B-Instruct-v0.2\",\n        \"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.3\",\n        \"mistralai/Mistral-7B-Instruct-v0.3\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1\",\n        \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit\",\n        \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\",\n        \"unsloth/Mistral-Nemo-Instruct-2407\",\n        \"mistralai/Mistral-Nemo-Instruct-2407\",\n        \"unsloth/Mistral-Large-Instruct-2407-bnb-4bit\",\n        \"mistralai/Mistral-Large-Instruct-2407\",\n        \"unsloth/Mistral-Small-Instruct-2409-bnb-4bit\",\n        \"unsloth/Mistral-Small-Instruct-2409\",\n        \"mistralai/Mistral-Small-Instruct-2409\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501\",\n        \"mistralai/Mistral-Small-24B-Instruct-2501\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503\",\n        \"mistralai/Mistral-Small-3.1-24B-Instruct-2503\",\n        
\"unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506\",\n        \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit\",\n    ),\n    \"llama\": (\n        \"meta-llama/Llama-2-13b-chat-hf\",\n        \"unsloth/llama-2-7b-chat-bnb-4bit\",\n        \"unsloth/llama-2-7b-chat\",\n        \"meta-llama/Llama-2-7b-chat-hf\",\n    ),\n    \"llama3\": (\n        \"unsloth/llama-3-8b-Instruct-bnb-4bit\",\n        \"unsloth/llama-3-8b-Instruct\",\n        \"meta-llama/Meta-Llama-3-8B-Instruct\",\n        \"unsloth/llama-3-70b-Instruct-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3-70B-Instruct\",\n    ),\n    \"llama-3.1\": (\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct\",\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.1-8B-Instruct\",\n        \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"unsloth/Llama-3.1-8B-Instruct-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3.1-405B-Instruct\",\n        \"unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-70B-Instruct\",\n        \"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n        \"unsloth/Llama-3.1-Storm-8B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Storm-8B\",\n        \"akjindal53244/Llama-3.1-Storm-8B\",\n        \"unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit\",\n        \"unsloth/Hermes-3-Llama-3.1-8B\",\n        \"NousResearch/Hermes-3-Llama-3.1-8B\",\n        \"unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit\",\n        \"unsloth/Hermes-3-Llama-3.1-70B\",\n        \"NousResearch/Hermes-3-Llama-3.1-70B\",\n        \"unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit\",\n        \"NousResearch/Hermes-3-Llama-3.1-405B\",\n        \"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.1-Nemotron-70B-Instruct\",\n        \"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF\",\n        \"unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Tulu-3-8B\",\n        \"allenai/Llama-3.1-Tulu-3-8B\",\n        \"unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Tulu-3-70B\",\n        \"allenai/Llama-3.1-Tulu-3-70B\",\n    ),\n    \"llama-3.2\": (\n        \"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-1B-Instruct\",\n        \"meta-llama/Llama-3.2-1B-Instruct\",\n        \"unsloth/Llama-3.2-1B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-3B-Instruct\",\n        \"meta-llama/Llama-3.2-3B-Instruct\",\n        \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-90B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n    ),\n    \"llama-3.3\": (\n        \"unsloth/Llama-3.3-70B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.3-70B-Instruct\",\n        \"meta-llama/Llama-3.3-70B-Instruct\",\n    ),\n    
\"gemma\": (\n        \"unsloth/gemma-7b-it-bnb-4bit\",\n        \"unsloth/gemma-7b-it\",\n        \"google/gemma-7b-it\",\n        \"google/gemma-2b-it\",\n        \"unsloth/gemma-1.1-2b-it-bnb-4bit\",\n        \"unsloth/gemma-1.1-2b-it\",\n        \"google/gemma-1.1-2b-it\",\n        \"unsloth/gemma-1.1-7b-it-bnb-4bit\",\n        \"unsloth/gemma-1.1-7b-it\",\n        \"google/gemma-1.1-7b-it\",\n    ),\n    \"gemma2\": (\n        \"unsloth/gemma-2-9b-it-bnb-4bit\",\n        \"unsloth/gemma-2-9b-it\",\n        \"google/gemma-2-9b-it\",\n        \"unsloth/gemma-2-27b-it-bnb-4bit\",\n        \"unsloth/gemma-2-27b-it\",\n        \"google/gemma-2-27b-it\",\n        \"unsloth/gemma-2-2b-it-bnb-4bit\",\n        \"unsloth/gemma-2-2b-it\",\n        \"google/gemma-2-2b-it\",\n    ),\n    \"gemma-3\": (\n        \"unsloth/gemma-3-1b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-1b-it\",\n        \"google/gemma-3-1b-it\",\n        \"unsloth/gemma-3-1b-it-bnb-4bit\",\n        \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-4b-it\",\n        \"google/gemma-3-4b-it\",\n        \"unsloth/gemma-3-4b-it-bnb-4bit\",\n        \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-12b-it\",\n        \"google/gemma-3-12b-it\",\n        \"unsloth/gemma-3-12b-it-bnb-4bit\",\n        \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-27b-it\",\n        \"google/gemma-3-27b-it\",\n        \"unsloth/gemma-3-27b-it-bnb-4bit\",\n        \"unsloth/gemma-3-270m-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-270m-it\",\n        \"google/gemma-3-270m-it\",\n        \"unsloth/gemma-3-270m-it-bnb-4bit\",\n        \"unsloth/gemma-3-270m-unsloth-bnb-4bit\",\n        \"unsloth/medgemma-4b-it-unsloth-bnb-4bit\",\n        \"unsloth/medgemma-4b-it\",\n        \"google/medgemma-4b-it\",\n        \"unsloth/medgemma-4b-it-bnb-4bit\",\n        \"unsloth/medgemma-27b-text-it-unsloth-bnb-4bit\",\n        \"unsloth/medgemma-27b-text-it\",\n        \"google/medgemma-27b-text-it\",\n        \"unsloth/medgemma-27b-text-it-bnb-4bit\",\n    ),\n    \"gemma3n\": (\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E4B-it\",\n        \"google/gemma-3n-E4B-it\",\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E2B-it\",\n        \"google/gemma-3n-E2B-it\",\n        \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\",\n    ),\n    \"qwen2.5\": (\n        \"unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-0.5B-Instruct\",\n        \"unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-3B-Instruct\",\n        \"Qwen/Qwen2.5-3B-Instruct\",\n        \"unsloth/Qwen2.5-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-7B-Instruct\",\n        \"Qwen/Qwen2.5-7B-Instruct\",\n        \"unsloth/Qwen2.5-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-14B-Instruct\",\n        \"Qwen/Qwen2.5-14B-Instruct\",\n        \"unsloth/Qwen2.5-14B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-32B-Instruct-bnb-4bit\",\n        
\"unsloth/Qwen2.5-32B-Instruct\",\n        \"Qwen/Qwen2.5-32B-Instruct\",\n        \"unsloth/Qwen2.5-72B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-72B-Instruct\",\n        \"Qwen/Qwen2.5-72B-Instruct\",\n        \"unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Math-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-7B-Instruct\",\n        \"Qwen/Qwen2.5-Math-7B-Instruct\",\n        \"unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-72B-Instruct\",\n        \"Qwen/Qwen2.5-Math-72B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-0.5B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-3B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-3B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-7B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-14B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-32B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct\",\n        \"Qwen/Qwen2.5-VL-3B-Instruct\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct\",\n        \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct\",\n        \"Qwen/Qwen2.5-VL-32B-Instruct\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct\",\n        \"Qwen/Qwen2.5-VL-72B-Instruct\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit\",\n        \"unsloth/OpenThinker-7B-unsloth-bnb-4bit\",\n        \"unsloth/OpenThinker-7B\",\n        \"open-thoughts/OpenThinker-7B\",\n        \"unsloth/OpenThinker-7B-bnb-4bit\",\n    ),\n    \"qwen3\": (\n        \"unsloth/Qwen3-0.6B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-0.6B\",\n        \"Qwen/Qwen3-0.6B\",\n        \"unsloth/Qwen3-0.6B-bnb-4bit\",\n        \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-1.7B\",\n        \"Qwen/Qwen3-1.7B\",\n        \"unsloth/Qwen3-1.7B-bnb-4bit\",\n        \"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B\",\n        \"Qwen/Qwen3-4B\",\n        \"unsloth/Qwen3-4B-bnb-4bit\",\n        \"unsloth/Qwen3-8B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-8B\",\n        \"Qwen/Qwen3-8B\",\n        \"unsloth/Qwen3-8B-bnb-4bit\",\n        \"unsloth/Qwen3-14B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-14B\",\n        \"Qwen/Qwen3-14B\",\n        \"unsloth/Qwen3-14B-bnb-4bit\",\n        \"unsloth/Qwen3-32B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-32B\",\n        \"Qwen/Qwen3-32B\",\n        
\"unsloth/Qwen3-32B-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B\",\n        \"Qwen/Qwen3-30B-A3B\",\n        \"unsloth/Qwen3-30B-A3B-bnb-4bit\",\n    ),\n    \"qwen3-instruct\": (\n        \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Instruct-2507\",\n        \"Qwen/Qwen3-4B-Instruct-2507\",\n        \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-Instruct-2507\",\n        \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n        \"unsloth/Qwen3-Coder-30B-A3B-Instruct\",\n        \"Qwen/Qwen3-Coder-30B-A3B-Instruct\",\n        \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Instruct-2507\",\n        \"Qwen/Qwen3-4B-Instruct-2507\",\n        \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n    ),\n    \"qwen3-thinking\": (\n        \"unsloth/QwQ-32B-Preview-bnb-4bit\",\n        \"unsloth/QwQ-32B-Preview\",\n        \"Qwen/QwQ-32B-Preview\",\n        \"unsloth/QwQ-32B-unsloth-bnb-4bit\",\n        \"unsloth/QwQ-32B\",\n        \"Qwen/QwQ-32B\",\n        \"unsloth/QwQ-32B-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Thinking-2507\",\n        \"Qwen/Qwen3-4B-Thinking-2507\",\n        \"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-Thinking-2507\",\n        \"Qwen/Qwen3-30B-A3B-Thinking-2507\",\n    ),\n    \"qwen3.5\": (\n        \"unsloth/Qwen3.5-0.8B\",\n        \"unsloth/Qwen3.5-2B\",\n        \"unsloth/Qwen3.5-4B\",\n        \"unsloth/Qwen3.5-27B\",\n        \"unsloth/Qwen3.5-35B-A3B\",\n    ),\n    \"zephyr\": (\n        \"unsloth/zephyr-sft-bnb-4bit\",\n        \"unsloth/zephyr-sft\",\n        \"HuggingFaceH4/mistral-7b-sft-beta\",\n    ),\n    \"chatml\": (\n        \"unsloth/yi-6b-bnb-4bit\",\n        \"unsloth/yi-6b\",\n        \"01-ai/Yi-6B\",\n        \"unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit\",\n        \"unsloth/Hermes-2-Pro-Mistral-7B\",\n        \"NousResearch/Hermes-2-Pro-Mistral-7B\",\n        \"unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit\",\n        \"unsloth/OpenHermes-2.5-Mistral-7B\",\n        \"teknium/OpenHermes-2.5-Mistral-7B\",\n    ),\n    \"gpt-oss\": (\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-20b\",\n        \"openai/gpt-oss-20b\",\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-120b\",\n        \"openai/gpt-oss-120b\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n    ),\n    \"starling\": (\n        \"unsloth/Starling-LM-7B-beta-bnb-4bit\",\n        \"unsloth/Starling-LM-7B-beta\",\n        \"Nexusflow/Starling-LM-7B-beta\",\n    ),\n    \"yi-chat\": (\n        \"unsloth/yi-34b-chat-bnb-4bit\",\n        \"01-ai/Yi-6B-Chat\",\n        \"01-ai/Yi-34B-Chat\",\n    ),\n    \"glm\": (\n        \"unsloth/GLM-4.7-Flash-unsloth-bnb-4bit\",\n        \"unsloth/GLM-4.7-Flash\",\n        \"THUDM/GLM-4.7-Flash\",\n        \"unsloth/GLM-4.7-Flash-bnb-4bit\",\n    ),\n}\n\nMODEL_TO_TEMPLATE_MAPPER = {}\n\nfor key, values in TEMPLATE_TO_MODEL_MAPPER.items():\n    for value in values:\n        MODEL_TO_TEMPLATE_MAPPER[value] = key\n\n    # Get lowercased\n    lowered_key = key.lower()\n    for value in values:\n        MODEL_TO_TEMPLATE_MAPPER[value.lower()] = lowered_key\n\n\nTEMPLATE_TO_RESPONSES_MAPPER = {\n    \"gemma-3\": {\n        \"instruction\": \"<start_of_turn>user\\n\",\n        \"response\": \"<start_of_turn>model\\n\",\n   
 },\n    \"gemma3n\": {\n        \"instruction\": \"<start_of_turn>user\\n\",\n        \"response\": \"<start_of_turn>model\\n\",\n    },\n    \"qwen3.5\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"qwen3-instruct\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"qwen3-thinking\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n<think>\\n\",\n    },\n    \"qwen3\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"qwen2.5\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"llama-3.2\": {\n        \"instruction\": \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        \"response\": \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    },\n    \"llama-3.3\": {\n        \"instruction\": \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        \"response\": \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    },\n    \"llama-3.1\": {\n        \"instruction\": \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        \"response\": \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    },\n    \"llama3\": {\n        \"instruction\": \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        \"response\": \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    },\n    \"phi-3\": {\n        \"instruction\": \"<|user|>\\n\",\n        \"response\": \"<|assistant|>\\n\",\n    },\n    \"phi-3.5\": {\n        \"instruction\": \"<|user|>\\n\",\n        \"response\": \"<|assistant|>\\n\",\n    },\n    \"phi-4\": {\n        \"instruction\": \"<|im_start|>user<|im_sep|>\",\n        \"response\": \"<|im_start|>assistant<|im_sep|>\",\n    },\n    \"mistral\": {\n        \"instruction\": \"[INST] \",\n        \"response\": \" [/INST]\",\n    },\n    \"llama\": {\n        \"instruction\": \"[INST] \",\n        \"response\": \" [/INST]\",\n    },\n    \"chatml\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"zephyr\": {\n        \"instruction\": \"<|user|>\\n\",\n        \"response\": \"<|assistant|>\\n\",\n    },\n    \"unsloth\": {\n        \"instruction\": \">>> User: \",\n        \"response\": \">>> Assistant: \",\n    },\n    \"vicuna\": {\n        \"instruction\": \"USER: \",\n        \"response\": \"ASSISTANT: \",\n    },\n    \"alpaca\": {\n        \"instruction\": \"### Instruction:\\n\",\n        \"response\": \"### Response:\\n\",\n    },\n    \"gemma\": {\n        \"instruction\": \"<start_of_turn>user\\n\",\n        \"response\": \"<start_of_turn>model\\n\",\n    },\n    \"gemma2\": {\n        \"instruction\": \"<start_of_turn>user\\n\",\n        \"response\": \"<start_of_turn>model\\n\",\n    },\n    \"gpt-oss\": {\n        \"instruction\": \"<|start|>user<|message|>\",\n        \"response\": \"<|start|>assistant<|channel|>final<|message|>\",\n    },\n    \"lfm-2\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    },\n    \"starling\": {\n        \"instruction\": \"GPT4 Correct User: \",\n        \"response\": \"GPT4 Correct Assistant: \",\n    },\n    \"yi-chat\": {\n        \"instruction\": \"<|im_start|>user\\n\",\n        \"response\": \"<|im_start|>assistant\\n\",\n    
},\n    \"glm\": {\n        \"instruction\": \"[gMASK]<sop><|user|>\",\n        \"response\": \"<|assistant|><think>\",\n    },\n}\n"
  },
  {
    "path": "studio/backend/utils/datasets/vlm_processing.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nVLM (Vision-Language Model) processing utilities.\n\nThis module contains functions for generating smart instructions\nfor VLM datasets based on content analysis and heuristics.\n\"\"\"\n\nimport re\nfrom itertools import islice\n\n\ndef generate_smart_vlm_instruction(\n    dataset,\n    text_column = \"text\",\n    image_column = \"image\",\n    dataset_name = None,\n):\n    \"\"\"\n    Generate smart, context-aware instruction for VLM datasets using heuristics.\n\n    Strategy:\n    1. Check for explicit question/instruction columns → use that\n    2. Infer from text column name + sample content\n    3. Analyze dataset name for task hints\n    4. Fall back to generic instruction\n\n    Returns:\n        dict: {\n            \"instruction\": str or None,  # None means use column content\n            \"instruction_type\": \"explicit\" | \"inferred\" | \"generic\",\n            \"uses_dynamic_instruction\": bool,  # True if instruction varies per sample\n            \"confidence\": float,  # 0.0 to 1.0\n        }\n    \"\"\"\n    column_names = set(next(iter(dataset)).keys())\n    sample = next(iter(dataset))\n\n    # ===== LEVEL 1: Explicit Instruction Columns =====\n    # Check for columns that contain per-sample instructions\n    question_columns = [\"question\", \"query\", \"prompt\", \"instruction\", \"user_prompt\"]\n\n    for col in question_columns:\n        if col in column_names:\n            # Check if this column has varied content (not just empty/same)\n            sample_content = sample[col]\n            if sample_content and str(sample_content).strip():\n                return {\n                    \"instruction\": None,  # Signal to use column content\n                    \"instruction_column\": col,\n                    \"instruction_type\": \"explicit\",\n                    \"uses_dynamic_instruction\": True,\n                    \"confidence\": 1.0,\n                }\n\n    # ===== LEVEL 2: Infer from Column Names + Content =====\n    text_col_lower = text_column.lower()\n\n    # Sample the text content to detect patterns\n    text_sample = str(sample.get(text_column, \"\"))[:500]  # First 500 chars\n\n    # Task-specific keywords and their instructions\n    task_patterns = {\n        # OCR / Transcription\n        \"ocr\": {\n            \"keywords\": [\"ocr\", \"transcribe\", \"transcript\"],\n            \"content_hints\": [\n                r\"[A-Za-z\\u0600-\\u06FF]{10,}\"\n            ],  # Long text passages (Latin/Arabic)\n            \"instruction\": \"Transcribe all the text shown in this image.\",\n            \"confidence\": 0.9,\n        },\n        # LaTeX / Math\n        \"latex\": {\n            \"keywords\": [\"latex\", \"math\", \"formula\", \"equation\"],\n            \"content_hints\": [r\"\\\\[a-z]+\\{\", r\"\\^\", r\"_\", r\"\\\\frac\"],  # LaTeX commands\n            \"instruction\": \"Convert this image to LaTeX notation.\",\n            \"confidence\": 0.95,\n        },\n        # Caption / Description\n        \"caption\": {\n            \"keywords\": [\"caption\", \"description\", \"describe\"],\n            \"content_hints\": [],\n            \"instruction\": \"Provide a detailed description of this image.\",\n            \"confidence\": 0.85,\n        },\n        # Medical / Radiology\n        \"medical\": {\n            \"keywords\": [\n                \"medical\",\n      
          \"radiology\",\n                \"xray\",\n                \"ct\",\n                \"mri\",\n                \"scan\",\n                \"diagnosis\",\n            ],\n            \"content_hints\": [r\"\\b(lesion|radiograph|patient|diagnosis|findings)\\b\"],\n            \"instruction\": \"Analyze this medical image and describe the key findings.\",\n            \"confidence\": 0.9,\n        },\n        # Code / Programming\n        \"code\": {\n            \"keywords\": [\"code\", \"program\", \"function\", \"algorithm\"],\n            \"content_hints\": [r\"def |class |function|import |return \"],\n            \"instruction\": \"Explain what this code visualization shows.\",\n            \"confidence\": 0.85,\n        },\n        # Chart / Graph\n        \"chart\": {\n            \"keywords\": [\"chart\", \"graph\", \"plot\", \"visualization\", \"diagram\"],\n            \"content_hints\": [r\"\\b(axis|legend|bar|line|pie|scatter)\\b\"],\n            \"instruction\": \"Describe this chart or graph, including key data points and trends.\",\n            \"confidence\": 0.85,\n        },\n        # Document / Text Recognition\n        \"document\": {\n            \"keywords\": [\"document\", \"page\", \"paragraph\", \"article\"],\n            \"content_hints\": [r\"\\n.*\\n.*\\n\"],  # Multi-line text\n            \"instruction\": \"Extract and transcribe the text from this document image.\",\n            \"confidence\": 0.85,\n        },\n    }\n\n    # Check column name matches\n    best_match = None\n    best_score = 0.0\n\n    for task_name, task_info in task_patterns.items():\n        score = 0.0\n\n        # Check column name\n        if any(keyword in text_col_lower for keyword in task_info[\"keywords\"]):\n            score += 0.5\n\n        # Check dataset name if provided\n        if dataset_name and any(\n            keyword in dataset_name.lower() for keyword in task_info[\"keywords\"]\n        ):\n            score += 0.3\n\n        # Check content patterns\n        for pattern in task_info[\"content_hints\"]:\n            if re.search(pattern, text_sample, re.IGNORECASE):\n                score += 0.4\n                break\n\n        if score > best_score:\n            best_score = score\n            best_match = task_info\n\n    if best_match and best_score > 0.5:  # Confidence threshold\n        return {\n            \"instruction\": best_match[\"instruction\"],\n            \"instruction_column\": None,\n            \"instruction_type\": \"inferred\",\n            \"uses_dynamic_instruction\": False,\n            \"confidence\": min(best_score, best_match[\"confidence\"]),\n        }\n\n    # ===== LEVEL 3: Analyze Dataset Name =====\n    if dataset_name:\n        name_lower = dataset_name.lower()\n\n        # Common dataset name patterns\n        if \"vqa\" in name_lower or \"question\" in name_lower:\n            return {\n                \"instruction\": \"Answer the question about this image.\",\n                \"instruction_column\": None,\n                \"instruction_type\": \"inferred\",\n                \"uses_dynamic_instruction\": False,\n                \"confidence\": 0.75,\n            }\n\n        if \"coco\" in name_lower or \"flickr\" in name_lower:\n            return {\n                \"instruction\": \"Provide a detailed caption for this image.\",\n                \"instruction_column\": None,\n                \"instruction_type\": \"inferred\",\n                \"uses_dynamic_instruction\": False,\n                \"confidence\": 0.75,\n 
           }\n\n    # ===== LEVEL 4: LLM-Assisted Instruction Generation =====\n    try:\n        from .llm_assist import llm_generate_vlm_instruction\n\n        sample_rows = []\n        for s in islice(dataset, 5):\n            row = {}\n            for col in s:\n                val = s[col]\n                if hasattr(val, \"size\") and hasattr(val, \"mode\"):  # PIL Image\n                    row[col] = \"<image>\"\n                elif isinstance(val, list):\n                    row[col] = str(val)[:300]\n                else:\n                    row[col] = str(val)[:300]\n            sample_rows.append(row)\n\n        llm_result = llm_generate_vlm_instruction(\n            column_names = list(column_names),\n            samples = sample_rows,\n            dataset_name = dataset_name,\n        )\n        if llm_result and llm_result.get(\"instruction\"):\n            print(\n                f\"\\n[DEBUG] LLM-assisted VLM instruction generated: \"\n                f\"'{llm_result['instruction']}' (confidence={llm_result.get('confidence', 'N/A')})\\n\",\n                flush = True,\n            )\n            return {\n                \"instruction\": llm_result[\"instruction\"],\n                \"instruction_column\": None,\n                \"instruction_type\": \"llm_assisted\",\n                \"uses_dynamic_instruction\": False,\n                \"confidence\": llm_result.get(\"confidence\", 0.85),\n            }\n    except Exception as e:\n        import logging\n\n        logging.getLogger(__name__).debug(f\"LLM-assisted instruction skipped: {e}\")\n\n    # ===== LEVEL 5: Generic Fallback =====\n    return {\n        \"instruction\": \"Describe this image in detail.\",\n        \"instruction_column\": None,\n        \"instruction_type\": \"generic\",\n        \"uses_dynamic_instruction\": False,\n        \"confidence\": 0.5,\n    }\n"
  },
  {
    "path": "studio/backend/utils/hardware/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nHardware detection and GPU utilities\n\"\"\"\n\nfrom .hardware import (\n    DeviceType,\n    DEVICE,\n    CHAT_ONLY,\n    detect_hardware,\n    get_device,\n    is_apple_silicon,\n    clear_gpu_cache,\n    get_gpu_memory_info,\n    log_gpu_memory,\n    get_gpu_summary,\n    get_package_versions,\n    get_gpu_utilization,\n    get_physical_gpu_count,\n    get_visible_gpu_count,\n    safe_num_proc,\n)\n\n__all__ = [\n    \"DeviceType\",\n    \"DEVICE\",\n    \"CHAT_ONLY\",\n    \"detect_hardware\",\n    \"get_device\",\n    \"is_apple_silicon\",\n    \"clear_gpu_cache\",\n    \"get_gpu_memory_info\",\n    \"log_gpu_memory\",\n    \"get_gpu_summary\",\n    \"get_package_versions\",\n    \"get_gpu_utilization\",\n    \"get_physical_gpu_count\",\n    \"get_visible_gpu_count\",\n    \"safe_num_proc\",\n]\n"
  },
  {
    "path": "studio/backend/utils/hardware/hardware.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nHardware detection — run once at startup, read everywhere.\n\nUsage:\n    # At FastAPI lifespan startup:\n    from utils.hardware import detect_hardware\n    detect_hardware()\n\n    # Anywhere else:\n    from utils.hardware import DEVICE, DeviceType, is_apple_silicon\n    if DEVICE == DeviceType.CUDA:\n        import torch\n        ...\n\"\"\"\n\nimport platform\nimport structlog\nfrom loggers import get_logger\nfrom enum import Enum\nfrom typing import Optional, Dict, Any\n\nlogger = get_logger(__name__)\n\n\n# ========== Device Enum ==========\n\n\nclass DeviceType(str, Enum):\n    \"\"\"Supported compute backends. Inherits from str so it serializes cleanly in JSON.\"\"\"\n\n    CUDA = \"cuda\"\n    MLX = \"mlx\"\n    CPU = \"cpu\"\n\n\n# ========== Global State (set once by detect_hardware) ==========\n\nDEVICE: Optional[DeviceType] = None\nCHAT_ONLY: bool = True  # No CUDA GPU -> GGUF chat only (Mac, CPU-only, etc.)\n\n\n# ========== Detection ==========\n\n\ndef is_apple_silicon() -> bool:\n    \"\"\"Check if running on Apple Silicon hardware (pure platform check, no ML imports).\"\"\"\n    return platform.system() == \"Darwin\" and platform.machine() == \"arm64\"\n\n\ndef _has_torch() -> bool:\n    \"\"\"Check if PyTorch is importable.\"\"\"\n    try:\n        import torch\n\n        return True\n    except ImportError:\n        return False\n\n\ndef _has_mlx() -> bool:\n    \"\"\"Check if MLX is importable.\"\"\"\n    try:\n        import mlx.core\n\n        return True\n    except ImportError:\n        return False\n\n\ndef detect_hardware() -> DeviceType:\n    \"\"\"\n    Detect the best available compute device and set the module-level DEVICE global.\n\n    Should be called exactly once during FastAPI lifespan startup.\n    Safe to call multiple times (idempotent).\n\n    Detection order:\n      1. CUDA  (NVIDIA GPU, requires torch)\n      2. MLX   (Apple Silicon via MLX framework)\n      3. CPU   (fallback)\n    \"\"\"\n    global DEVICE, CHAT_ONLY\n    CHAT_ONLY = True  # reset -- only CUDA sets it to False\n\n    # --- CUDA: try PyTorch ---\n    if _has_torch():\n        import torch\n\n        if torch.cuda.is_available():\n            DEVICE = DeviceType.CUDA\n            CHAT_ONLY = False\n            device_name = torch.cuda.get_device_properties(0).name\n            print(f\"Hardware detected: CUDA — {device_name}\")\n            return DEVICE\n\n    # --- MLX: Apple Silicon ---\n    if is_apple_silicon() and _has_mlx():\n        DEVICE = DeviceType.MLX\n        chip = platform.processor() or platform.machine()\n        print(f\"Hardware detected: MLX — Apple Silicon ({chip})\")\n        return DEVICE\n\n    # --- Fallback ---\n    DEVICE = DeviceType.CPU\n    print(\"Hardware detected: CPU (no GPU backend available)\")\n    return DEVICE\n\n\n# ========== Convenience helpers ==========\n\n\ndef get_device() -> DeviceType:\n    \"\"\"\n    Return the detected device. 
Auto-detects if detect_hardware() hasn't been called yet.\n    Prefer calling detect_hardware() explicitly at startup instead.\n    \"\"\"\n    global DEVICE\n    if DEVICE is None:\n        detect_hardware()\n    return DEVICE\n\n\ndef clear_gpu_cache():\n    \"\"\"\n    Clear GPU memory cache for the current device.\n    Safe to call on any platform — no-ops gracefully.\n    \"\"\"\n    import gc\n\n    gc.collect()\n\n    device = get_device()\n\n    if device == DeviceType.CUDA:\n        import torch\n\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        torch.cuda.ipc_collect()\n    elif device == DeviceType.MLX:\n        # MLX manages memory automatically; no explicit cache clear needed.\n        # mlx.core has no empty_cache equivalent — gc.collect() above is enough.\n        pass\n\n\ndef get_gpu_memory_info() -> Dict[str, Any]:\n    \"\"\"\n    Get GPU memory information.\n    Supports CUDA (NVIDIA), MLX (Apple Silicon), and CPU-only environments.\n    \"\"\"\n    device = get_device()\n\n    # ---- CUDA path ----\n    if device == DeviceType.CUDA:\n        try:\n            import torch\n\n            idx = torch.cuda.current_device()\n            props = torch.cuda.get_device_properties(idx)\n\n            total = props.total_memory\n            allocated = torch.cuda.memory_allocated(idx)\n            reserved = torch.cuda.memory_reserved(idx)\n\n            return {\n                \"available\": True,\n                \"backend\": device.value,\n                \"device\": idx,\n                \"device_name\": props.name,\n                \"total_gb\": total / (1024**3),\n                \"allocated_gb\": allocated / (1024**3),\n                \"reserved_gb\": reserved / (1024**3),\n                \"free_gb\": (total - allocated) / (1024**3),\n                \"utilization_pct\": (allocated / total) * 100,\n            }\n        except Exception as e:\n            logger.error(f\"Error getting CUDA GPU info: {e}\")\n            return {\"available\": False, \"backend\": device.value, \"error\": str(e)}\n\n    # ---- MLX path (Apple Silicon) ----\n    if device == DeviceType.MLX:\n        try:\n            import mlx.core as mx\n            import psutil\n\n            # MLX uses unified memory — report system memory as the pool\n            total = psutil.virtual_memory().total\n            # MLX doesn't expose per-process GPU allocation; report 0 as allocated\n            allocated = 0\n\n            return {\n                \"available\": True,\n                \"backend\": device.value,\n                \"device\": 0,\n                \"device_name\": f\"Apple Silicon ({platform.processor() or platform.machine()})\",\n                \"total_gb\": total / (1024**3),\n                \"allocated_gb\": allocated / (1024**3),\n                \"reserved_gb\": 0,\n                \"free_gb\": (total - allocated) / (1024**3),\n                \"utilization_pct\": (allocated / total) * 100 if total else 0,\n            }\n        except Exception as e:\n            logger.error(f\"Error getting MLX GPU info: {e}\")\n            return {\"available\": False, \"backend\": device.value, \"error\": str(e)}\n\n    # ---- CPU-only ----\n    return {\"available\": False, \"backend\": \"cpu\"}\n\n\ndef log_gpu_memory(context: str):\n    \"\"\"Log GPU memory usage with context.\"\"\"\n    memory_info = get_gpu_memory_info()\n    if memory_info.get(\"available\"):\n        backend = memory_info.get(\"backend\", \"unknown\").upper()\n        device_name = 
memory_info.get(\"device_name\", \"\")\n        label = f\"{backend}\" + (f\" ({device_name})\" if device_name else \"\")\n        logger.info(\n            f\"GPU Memory [{context}] {label}: \"\n            f\"{memory_info['allocated_gb']:.2f}GB/{memory_info['total_gb']:.2f}GB \"\n            f\"({memory_info['utilization_pct']:.1f}% used, \"\n            f\"{memory_info['free_gb']:.2f}GB free)\"\n        )\n    else:\n        logger.info(f\"GPU Memory [{context}]: No GPU available (CPU-only)\")\n\n\n# ========== GPU Summary & Package Versions ==========\n\n\ndef get_gpu_summary() -> Dict[str, Any]:\n    \"\"\"\n    Return a compact summary of the primary GPU.\n\n    Returns dict with keys:\n        gpu_name      – e.g. \"NVIDIA L4\" (or None)\n        vram_total_gb – e.g. 22.17       (or None)\n    \"\"\"\n    mem = get_gpu_memory_info()\n    if mem.get(\"available\"):\n        return {\n            \"gpu_name\": mem.get(\"device_name\"),\n            \"vram_total_gb\": round(mem.get(\"total_gb\", 0), 2),\n            \"vram_free_gb\": round(mem.get(\"free_gb\", 0), 2),\n        }\n    return {\"gpu_name\": None, \"vram_total_gb\": None, \"vram_free_gb\": None}\n\n\ndef get_package_versions() -> Dict[str, Optional[str]]:\n    \"\"\"\n    Return the installed versions of key ML packages.\n\n    Uses importlib.metadata (stdlib) so no subprocess is needed.\n    CUDA version comes from torch.version.cuda.\n\n    Returns dict with keys: unsloth, torch, transformers, cuda.\n    Missing packages yield None.\n    \"\"\"\n    from importlib.metadata import version as pkg_version, PackageNotFoundError\n\n    packages = (\"unsloth\", \"torch\", \"transformers\")\n    versions: Dict[str, Optional[str]] = {}\n\n    for name in packages:\n        try:\n            versions[name] = pkg_version(name)\n        except PackageNotFoundError:\n            versions[name] = None\n\n    # CUDA toolkit version bundled with torch\n    try:\n        import torch\n\n        versions[\"cuda\"] = getattr(torch.version, \"cuda\", None)\n    except Exception:\n        versions[\"cuda\"] = None\n\n    return versions\n\n\n# ========== Live GPU Utilization (nvidia-smi) ==========\n\n\ndef get_gpu_utilization() -> Dict[str, Any]:\n    \"\"\"\n    Return a live snapshot of GPU utilization via ``nvidia-smi``.\n\n    Designed to be polled by the frontend during training (not streaming).\n    Uses ``nvidia-smi --query-gpu`` which is the most accurate source for\n    utilization %, temperature, and power draw – stats that PyTorch does\n    not expose.\n\n    Returns dict with keys:\n        available          – bool, whether stats could be retrieved\n        gpu_utilization_pct – GPU core utilization %\n        temperature_c      – GPU temperature in °C\n        vram_used_gb       – VRAM currently used (GiB)\n        vram_total_gb      – VRAM total (GiB)\n        vram_utilization_pct – VRAM used / total * 100\n        power_draw_w       – current power draw (W)\n        power_limit_w      – power limit (W)\n        power_utilization_pct – power draw / limit * 100\n    \"\"\"\n    device = get_device()\n\n    if device != DeviceType.CUDA:\n        return {\"available\": False, \"backend\": device.value}\n\n    def _parse_smi_value(raw: str):\n        \"\"\"Parse a single nvidia-smi CSV value. 
Returns float or None for [N/A].\"\"\"\n        raw = raw.strip()\n        if not raw or raw == \"[N/A]\":\n            return None\n        try:\n            return float(raw)\n        except (ValueError, TypeError):\n            return None\n\n    # ── nvidia-smi (most complete source) ───────────────────────\n    smi_data = {}\n    try:\n        import subprocess\n\n        result = subprocess.run(\n            [\n                \"nvidia-smi\",\n                \"--query-gpu=utilization.gpu,temperature.gpu,\"\n                \"memory.used,memory.total,power.draw,power.limit\",\n                \"--format=csv,noheader,nounits\",\n            ],\n            capture_output = True,\n            text = True,\n            timeout = 5,\n        )\n\n        if result.returncode == 0 and result.stdout.strip():\n            # nvidia-smi outputs one line per GPU; take GPU 0\n            first_line = result.stdout.strip().splitlines()[0]\n            parts = [p.strip() for p in first_line.split(\",\")]\n            if len(parts) >= 6:\n                smi_data = {\n                    \"gpu_util\": _parse_smi_value(parts[0]),\n                    \"temp\": _parse_smi_value(parts[1]),\n                    \"vram_used_mb\": _parse_smi_value(parts[2]),\n                    \"vram_total_mb\": _parse_smi_value(parts[3]),\n                    \"power_draw\": _parse_smi_value(parts[4]),\n                    \"power_limit\": _parse_smi_value(parts[5]),\n                }\n\n    except FileNotFoundError:\n        logger.debug(\"nvidia-smi not found, falling back to torch.cuda\")\n    except Exception as e:\n        logger.warning(f\"nvidia-smi query failed: {e}\")\n\n    # ── Backfill VRAM from torch.cuda if nvidia-smi returned [N/A] ──\n    vram_used_mb = smi_data.get(\"vram_used_mb\")\n    vram_total_mb = smi_data.get(\"vram_total_mb\")\n\n    if vram_used_mb is None or vram_total_mb is None:\n        try:\n            import torch\n\n            idx = torch.cuda.current_device()\n            props = torch.cuda.get_device_properties(idx)\n            if vram_total_mb is None:\n                vram_total_mb = props.total_memory / (1024**2)  # bytes → MiB\n            if vram_used_mb is None:\n                vram_used_mb = torch.cuda.memory_allocated(idx) / (1024**2)\n        except Exception as e:\n            logger.debug(f\"torch.cuda VRAM backfill failed: {e}\")\n\n    # ── Build response ──────────────────────────────────────────\n    gpu_util = smi_data.get(\"gpu_util\")\n    temp = smi_data.get(\"temp\")\n    power_draw = smi_data.get(\"power_draw\")\n    power_limit = smi_data.get(\"power_limit\")\n\n    vram_used_gb = round(vram_used_mb / 1024, 2) if vram_used_mb is not None else None\n    vram_total_gb = (\n        round(vram_total_mb / 1024, 2) if vram_total_mb is not None else None\n    )\n    vram_pct = (\n        round((vram_used_mb / vram_total_mb) * 100, 1)\n        if vram_used_mb is not None and vram_total_mb and vram_total_mb > 0\n        else None\n    )\n    power_pct = (\n        round((power_draw / power_limit) * 100, 1)\n        if power_draw is not None and power_limit and power_limit > 0\n        else None\n    )\n\n    # If we got at least something useful, report available\n    has_any = any(v is not None for v in [gpu_util, temp, vram_used_gb, power_draw])\n    if not has_any:\n        return {\"available\": False, \"backend\": device.value}\n\n    return {\n        \"available\": True,\n        \"backend\": device.value,\n        \"gpu_utilization_pct\": gpu_util,\n        
\"temperature_c\": temp,\n        \"vram_used_gb\": vram_used_gb,\n        \"vram_total_gb\": vram_total_gb,\n        \"vram_utilization_pct\": vram_pct,\n        \"power_draw_w\": power_draw,\n        \"power_limit_w\": power_limit,\n        \"power_utilization_pct\": power_pct,\n    }\n\n\n# ========== Multi-GPU Detection & Safe num_proc ==========\n\n_physical_gpu_count: Optional[int] = None\n_visible_gpu_count: Optional[int] = None\n\n\ndef get_physical_gpu_count() -> int:\n    \"\"\"\n    Return the number of physical NVIDIA GPUs on the machine.\n\n    Uses ``nvidia-smi -L`` which is NOT affected by CUDA_VISIBLE_DEVICES,\n    so it always reflects the true hardware count.\n    Result is cached after the first call.\n    \"\"\"\n    global _physical_gpu_count\n    if _physical_gpu_count is not None:\n        return _physical_gpu_count\n\n    try:\n        import subprocess\n\n        result = subprocess.run(\n            [\"nvidia-smi\", \"-L\"],\n            capture_output = True,\n            text = True,\n            timeout = 5,\n        )\n        if result.returncode == 0 and result.stdout.strip():\n            _physical_gpu_count = len(result.stdout.strip().splitlines())\n        else:\n            _physical_gpu_count = 1\n    except Exception:\n        _physical_gpu_count = 1\n\n    return _physical_gpu_count\n\n\ndef get_visible_gpu_count() -> int:\n    \"\"\"\n    Return the number of GPUs visible to this process.\n\n    Respects ``CUDA_VISIBLE_DEVICES`` -- if set, only those GPUs count.\n    Falls back to physical count if the env var is unset or torch is\n    unavailable.  Result is cached after the first call.\n    \"\"\"\n    global _visible_gpu_count\n    if _visible_gpu_count is not None:\n        return _visible_gpu_count\n\n    import os\n\n    cuda_visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n    if cuda_visible is not None:\n        # \"\" means zero GPUs, \"0\" means 1, \"0,1,2\" means 3\n        cuda_visible = cuda_visible.strip()\n        if cuda_visible == \"\" or cuda_visible == \"-1\":\n            _visible_gpu_count = 0\n        else:\n            _visible_gpu_count = len([x for x in cuda_visible.split(\",\") if x.strip()])\n        return _visible_gpu_count\n\n    # CUDA_VISIBLE_DEVICES not set -- try torch, fall back to physical count\n    try:\n        import torch\n\n        _visible_gpu_count = torch.cuda.device_count()\n    except Exception:\n        _visible_gpu_count = get_physical_gpu_count()\n\n    return _visible_gpu_count\n\n\ndef safe_num_proc(desired: Optional[int] = None) -> int:\n    \"\"\"\n    Return a safe ``num_proc`` for ``dataset.map()`` calls.\n\n    On Windows, always returns 1 because Python uses ``spawn`` instead of\n    ``fork`` for multiprocessing -- the overhead of re-importing torch,\n    transformers, unsloth etc. per worker is typically slower than\n    single-process for normal dataset sizes.\n\n    On multi-GPU machines (where multiple GPUs are *visible* to this\n    process) the NVIDIA driver spawns extra background threads, making\n    ``os.fork()`` prone to deadlocks when many workers are created.\n    This helper caps ``num_proc`` to 4 on such machines.\n\n    When ``CUDA_VISIBLE_DEVICES`` restricts to a single GPU, the cap\n    does not apply.\n\n    Args:\n        desired: The num_proc you *want*. 
If None, auto-computes from\n                 ``os.cpu_count()``.\n\n    Returns:\n        A safe integer ≥ 1.\n    \"\"\"\n    import os\n    import sys\n\n    # Windows uses 'spawn' for multiprocessing -- the overhead of re-importing\n    # torch/transformers/unsloth per worker is typically slower than single-process.\n    if sys.platform == \"win32\":\n        return 1\n\n    if desired is None or not isinstance(desired, int):\n        # os.cpu_count() may return None; fall back to 1 so the division never crashes\n        desired = max(1, (os.cpu_count() or 1) // 3)\n\n    # Enforce the documented contract: num_proc is always >= 1\n    desired = max(1, desired)\n\n    visible = get_visible_gpu_count()\n    if visible > 1:\n        capped = min(4, desired)\n        logger.info(\n            f\"Multi-GPU detected ({visible} visible GPUs) -- \"\n            f\"capping num_proc {desired} -> {capped} to avoid fork deadlocks\"\n        )\n        return capped\n\n    return desired\n"
  },
  {
    "path": "studio/backend/utils/inference/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference utility functions\n\"\"\"\n\nfrom utils.inference.inference_config import load_inference_config\n\n__all__ = [\"load_inference_config\"]\n"
  },
  {
    "path": "studio/backend/utils/inference/inference_config.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nInference configuration loading utilities.\n\nThis module provides functions to load inference parameters (temperature, top_p, top_k, min_p)\nfrom model YAML configuration files, with fallback to default.yaml.\nIncludes family-based lookup from inference_defaults.json for GGUF models.\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import Dict, Any, Optional\nimport json\nimport yaml\nimport structlog\nfrom loggers import get_logger\n\nfrom utils.models.model_config import load_model_defaults\n\nlogger = get_logger(__name__)\n\n# ── Family-based inference defaults (loaded once, cached) ──────────────\n\n_FAMILY_DEFAULTS: Optional[Dict[str, Any]] = None\n_FAMILY_PATTERNS: Optional[list] = None\n\n\ndef _load_family_defaults():\n    \"\"\"Load and cache inference_defaults.json.\"\"\"\n    global _FAMILY_DEFAULTS, _FAMILY_PATTERNS\n    if _FAMILY_DEFAULTS is not None:\n        return\n\n    json_path = (\n        Path(__file__).parent.parent.parent\n        / \"assets\"\n        / \"configs\"\n        / \"inference_defaults.json\"\n    )\n    try:\n        with open(json_path, \"r\", encoding = \"utf-8\") as f:\n            data = json.load(f)\n        _FAMILY_DEFAULTS = data.get(\"families\", {})\n        _FAMILY_PATTERNS = data.get(\"patterns\", [])\n    except Exception as e:\n        logger.warning(f\"Failed to load inference_defaults.json: {e}\")\n        _FAMILY_DEFAULTS = {}\n        _FAMILY_PATTERNS = []\n\n\ndef get_family_inference_params(model_id: str) -> Dict[str, Any]:\n    \"\"\"\n    Look up recommended inference parameters by model family.\n\n    Extracts the model family from the identifier (e.g. \"unsloth/Qwen3.5-9B-GGUF\" -> \"qwen3.5\")\n    and returns the matching parameters from inference_defaults.json.\n\n    Args:\n        model_id: Model identifier (e.g. \"unsloth/Qwen3.5-9B-GGUF\")\n\n    Returns:\n        Dict with inference params, or empty dict if no family match.\n    \"\"\"\n    _load_family_defaults()\n\n    if not _FAMILY_PATTERNS or not _FAMILY_DEFAULTS:\n        return {}\n\n    # Normalize: lowercase, strip org prefix\n    normalized = model_id.lower()\n    if \"/\" in normalized:\n        normalized = normalized.split(\"/\", 1)[1]\n\n    # Match against patterns (ordered longest-match-first in the JSON)\n    for pattern in _FAMILY_PATTERNS:\n        if pattern in normalized:\n            params = _FAMILY_DEFAULTS.get(pattern, {})\n            if params:\n                return dict(params)\n\n    return {}\n\n\ndef _has_specific_yaml(model_identifier: str) -> bool:\n    \"\"\"Check if a model has its own YAML config (not just default.yaml).\"\"\"\n    from utils.models.model_config import _REVERSE_MODEL_MAPPING\n\n    script_dir = Path(__file__).parent.parent.parent\n    defaults_dir = script_dir / \"assets\" / \"configs\" / \"model_defaults\"\n\n    # Check the mapping\n    if model_identifier.lower() in _REVERSE_MODEL_MAPPING:\n        return True\n\n    # Check for exact filename match\n    model_filename = model_identifier.replace(\"/\", \"_\") + \".yaml\"\n    for config_path in defaults_dir.rglob(model_filename):\n        if config_path.is_file():\n            return True\n\n    return False\n\n\ndef load_inference_config(model_identifier: str) -> Dict[str, Any]:\n    \"\"\"\n    Load inference configuration parameters for a model.\n\n    Priority chain:\n    1. 
Model-specific YAML (if it exists and has inference params)\n    2. Family-based defaults from inference_defaults.json\n    3. default.yaml fallback\n\n    Args:\n        model_identifier: Model identifier (e.g., \"unsloth/llama-3-8b-bnb-4bit\")\n\n    Returns:\n        Dictionary containing inference parameters:\n        {\n            \"temperature\": float,\n            \"top_p\": float,\n            \"top_k\": int,\n            \"min_p\": float\n        }\n    \"\"\"\n    # Load model defaults to get inference parameters\n    model_defaults = load_model_defaults(model_identifier)\n\n    # Load default.yaml for fallback values\n    script_dir = Path(__file__).parent.parent.parent\n    defaults_dir = script_dir / \"assets\" / \"configs\" / \"model_defaults\"\n    default_config_path = defaults_dir / \"default.yaml\"\n\n    default_inference = {}\n    if default_config_path.exists():\n        try:\n            with open(default_config_path, \"r\", encoding = \"utf-8\") as f:\n                default_config = yaml.safe_load(f) or {}\n                default_inference = default_config.get(\"inference\", {})\n        except Exception as e:\n            logger.warning(f\"Failed to load default.yaml: {e}\")\n\n    # Family-based defaults from inference_defaults.json\n    family_params = get_family_inference_params(model_identifier)\n\n    model_inference = model_defaults.get(\"inference\", {})\n\n    # If the model has its own YAML config, those values take priority over family defaults.\n    # If it only fell back to default.yaml, family defaults take priority.\n    has_own_yaml = _has_specific_yaml(model_identifier)\n\n    def _get_param(key, hardcoded_default):\n        if has_own_yaml:\n            # Model-specific YAML wins, then family fills gaps, then default.yaml\n            val = model_inference.get(key)\n            if val is not None and isinstance(val, (int, float)):\n                return val\n            if key in family_params:\n                return family_params[key]\n            return default_inference.get(key, hardcoded_default)\n        else:\n            # No model-specific YAML: family wins, then default.yaml\n            if key in family_params:\n                return family_params[key]\n            return default_inference.get(key, hardcoded_default)\n\n    inference_config = {\n        \"temperature\": _get_param(\"temperature\", 0.7),\n        \"top_p\": _get_param(\"top_p\", 0.95),\n        \"top_k\": _get_param(\"top_k\", -1),\n        \"min_p\": _get_param(\"min_p\", 0.01),\n        \"presence_penalty\": _get_param(\"presence_penalty\", 0.0),\n        \"trust_remote_code\": model_inference.get(\n            \"trust_remote_code\", default_inference.get(\"trust_remote_code\", False)\n        ),\n    }\n\n    return inference_config\n"
  },
  {
    "path": "studio/backend/utils/models/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nModel and LoRA configuration handling\n\"\"\"\n\nfrom .model_config import (\n    ModelConfig,\n    GgufVariantInfo,\n    is_vision_model,\n    is_embedding_model,\n    detect_audio_type,\n    is_audio_input_type,\n    VALID_AUDIO_TYPES,\n    scan_trained_loras,\n    scan_exported_models,\n    load_model_defaults,\n    get_base_model_from_lora,\n    load_model_config,\n    list_gguf_variants,\n    MODEL_NAME_MAPPING,\n    UI_STATUS_INDICATORS,\n)\nfrom .checkpoints import scan_checkpoints\n\n__all__ = [\n    \"ModelConfig\",\n    \"GgufVariantInfo\",\n    \"is_vision_model\",\n    \"is_embedding_model\",\n    \"detect_audio_type\",\n    \"is_audio_input_type\",\n    \"VALID_AUDIO_TYPES\",\n    \"scan_trained_loras\",\n    \"scan_exported_models\",\n    \"load_model_defaults\",\n    \"get_base_model_from_lora\",\n    \"load_model_config\",\n    \"list_gguf_variants\",\n    \"MODEL_NAME_MAPPING\",\n    \"UI_STATUS_INDICATORS\",\n    \"scan_checkpoints\",\n]\n"
  },
  {
    "path": "studio/backend/utils/models/checkpoints.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nCheckpoint scanning utilities for discovering training runs and their checkpoints.\n\"\"\"\n\nimport json\nimport structlog\nfrom loggers import get_logger\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\nfrom utils.paths import outputs_root, resolve_output_dir\n\nlogger = get_logger(__name__)\n\n\ndef _read_checkpoint_loss(checkpoint_path: Path) -> Optional[float]:\n    \"\"\"\n    Read the training loss from a checkpoint's trainer_state.json.\n\n    Returns the loss from the last log_history entry, or None if unavailable.\n    \"\"\"\n    trainer_state = checkpoint_path / \"trainer_state.json\"\n    if not trainer_state.exists():\n        return None\n    try:\n        with open(trainer_state) as f:\n            state = json.load(f)\n        log_history = state.get(\"log_history\", [])\n        if log_history:\n            return log_history[-1].get(\"loss\")\n    except Exception as e:\n        logger.debug(f\"Could not read loss from {trainer_state}: {e}\")\n    return None\n\n\ndef scan_checkpoints(\n    outputs_dir: str = str(outputs_root()),\n) -> List[Tuple[str, List[Tuple[str, str, Optional[float]]], dict]]:\n    \"\"\"\n    Scan outputs folder for training runs and their checkpoints.\n\n    Returns:\n        List of tuples: [(model_name, [(display_name, checkpoint_path, loss), ...], metadata), ...]\n        metadata keys: base_model, peft_type, lora_rank (all optional)\n        The first entry in each checkpoint list is the main adapter; its loss is\n        set to the loss of the last (highest-step) intermediate checkpoint.\n    \"\"\"\n    models = []\n    outputs_path = resolve_output_dir(outputs_dir)\n\n    if not outputs_path.exists():\n        logger.warning(f\"Outputs directory not found: {outputs_dir}\")\n        return models\n\n    try:\n        for item in outputs_path.iterdir():\n            if not item.is_dir():\n                continue\n\n            config_file = item / \"config.json\"\n            adapter_config = item / \"adapter_config.json\"\n\n            if not (config_file.exists() or adapter_config.exists()):\n                continue\n\n            # Extract training metadata from adapter_config.json / config.json\n            metadata: dict = {}\n            try:\n                if adapter_config.exists():\n                    cfg = json.loads(adapter_config.read_text())\n                    metadata[\"base_model\"] = cfg.get(\"base_model_name_or_path\")\n                    metadata[\"peft_type\"] = cfg.get(\"peft_type\")\n                    metadata[\"lora_rank\"] = cfg.get(\"r\")\n                elif config_file.exists():\n                    cfg = json.loads(config_file.read_text())\n                    metadata[\"base_model\"] = cfg.get(\"_name_or_path\")\n            except Exception:\n                pass\n\n            # Fallback: extract base model name from folder name\n            # e.g. 
\"unsloth_Llama-3.2-3B-Instruct_1771227800\" → \"unsloth/Llama-3.2-3B-Instruct\"\n            if not metadata.get(\"base_model\"):\n                parts = item.name.rsplit(\"_\", 1)\n                if len(parts) == 2 and parts[1].isdigit():\n                    name_part = parts[0]\n                    idx = name_part.find(\"_\")\n                    if idx > 0:\n                        metadata[\"base_model\"] = (\n                            name_part[:idx] + \"/\" + name_part[idx + 1 :]\n                        )\n                    else:\n                        metadata[\"base_model\"] = name_part\n\n            # This is a valid training run\n            checkpoints = []\n\n            # Placeholder for the main adapter — loss filled from last checkpoint below\n            checkpoints.append((item.name, str(item), None))\n\n            # Scan for intermediate checkpoints (checkpoint-N subdirs)\n            for sub in sorted(item.iterdir()):\n                if not sub.is_dir() or not sub.name.startswith(\"checkpoint-\"):\n                    continue\n                sub_config = sub / \"config.json\"\n                sub_adapter = sub / \"adapter_config.json\"\n                if sub_config.exists() or sub_adapter.exists():\n                    loss = _read_checkpoint_loss(sub)\n                    checkpoints.append((sub.name, str(sub), loss))\n\n            # Assign the last checkpoint's loss to the main adapter entry\n            if len(checkpoints) > 1:\n                last_checkpoint_loss = checkpoints[-1][2]\n                checkpoints[0] = (\n                    checkpoints[0][0],\n                    checkpoints[0][1],\n                    last_checkpoint_loss,\n                )\n\n            models.append((item.name, checkpoints, metadata))\n            logger.debug(\n                f\"Found model: {item.name} with {len(checkpoints)} checkpoint(s)\"\n            )\n\n        # Sort by modification time (newest first)\n        models.sort(key = lambda x: Path(x[1][0][1]).stat().st_mtime, reverse = True)\n\n        logger.info(f\"Found {len(models)} training runs in {outputs_dir}\")\n        return models\n\n    except Exception as e:\n        logger.error(f\"Error scanning checkpoints: {e}\")\n        return []\n"
  },
  {
    "path": "studio/backend/utils/models/model_config.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nModel and LoRA configuration handling\n\"\"\"\n\nfrom transformers import AutoConfig\nfrom dataclasses import dataclass\nfrom typing import Optional, Dict, Any\nfrom utils.paths import (\n    normalize_path,\n    is_local_path,\n    is_model_cached,\n    outputs_root,\n    exports_root,\n    resolve_output_dir,\n    resolve_export_dir,\n)\nfrom utils.utils import without_hf_auth\nimport structlog\nfrom loggers import get_logger\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import List, Tuple\nimport json\nimport yaml\n\n\nlogger = get_logger(__name__)\n\n# Model name mapping: maps all equivalent model names to their canonical YAML config file\n# Format: \"canonical_model_name.yaml\": [list of all equivalent model names]\n# Based on the model mapper provided - canonical filename is based on the first model name in the mapper\nMODEL_NAME_MAPPING = {\n    # ── Embedding models ──\n    \"unsloth_all-MiniLM-L6-v2.yaml\": [\n        \"unsloth/all-MiniLM-L6-v2\",\n        \"sentence-transformers/all-MiniLM-L6-v2\",\n    ],\n    \"unsloth_bge-m3.yaml\": [\n        \"unsloth/bge-m3\",\n        \"BAAI/bge-m3\",\n    ],\n    \"unsloth_embeddinggemma-300m.yaml\": [\n        \"unsloth/embeddinggemma-300m\",\n        \"google/embeddinggemma-300m\",\n    ],\n    \"unsloth_gte-modernbert-base.yaml\": [\n        \"unsloth/gte-modernbert-base\",\n        \"Alibaba-NLP/gte-modernbert-base\",\n    ],\n    \"unsloth_Qwen3-Embedding-0.6B.yaml\": [\n        \"unsloth/Qwen3-Embedding-0.6B\",\n        \"Qwen/Qwen3-Embedding-0.6B\",\n        \"unsloth/Qwen3-Embedding-4B\",\n        \"Qwen/Qwen3-Embedding-4B\",\n    ],\n    # ── Other models ──\n    \"unsloth_answerdotai_ModernBERT-large.yaml\": [\n        \"answerdotai/ModernBERT-large\",\n    ],\n    \"unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml\": [\n        \"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-7B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n    ],\n    \"unsloth_codegemma-7b-bnb-4bit.yaml\": [\n        \"unsloth/codegemma-7b-bnb-4bit\",\n        \"unsloth/codegemma-7b\",\n        \"google/codegemma-7b\",\n    ],\n    \"unsloth_ERNIE-4.5-21B-A3B-PT.yaml\": [\n        \"unsloth/ERNIE-4.5-21B-A3B-PT\",\n    ],\n    \"unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml\": [\n        \"unsloth/ERNIE-4.5-VL-28B-A3B-PT\",\n    ],\n    \"tiiuae_Falcon-H1-0.5B-Instruct.yaml\": [\n        \"tiiuae/Falcon-H1-0.5B-Instruct\",\n        \"unsloth/Falcon-H1-0.5B-Instruct\",\n    ],\n    \"unsloth_functiongemma-270m-it.yaml\": [\n        \"unsloth/functiongemma-270m-it-unsloth-bnb-4bit\",\n        \"google/functiongemma-270m-it\",\n        \"unsloth/functiongemma-270m-it-unsloth-bnb-4bit\",\n    ],\n    \"unsloth_gemma-2-2b.yaml\": [\n        \"unsloth/gemma-2-2b-bnb-4bit\",\n        \"google/gemma-2-2b\",\n    ],\n    \"unsloth_gemma-2-27b-bnb-4bit.yaml\": [\n        \"unsloth/gemma-2-9b-bnb-4bit\",\n        \"unsloth/gemma-2-9b\",\n        \"google/gemma-2-9b\",\n        \"unsloth/gemma-2-27b\",\n        \"google/gemma-2-27b\",\n    ],\n    \"unsloth_gemma-3-4b-pt.yaml\": [\n        \"unsloth/gemma-3-4b-pt-unsloth-bnb-4bit\",\n        \"google/gemma-3-4b-pt\",\n        \"unsloth/gemma-3-4b-pt-bnb-4bit\",\n    ],\n    \"unsloth_gemma-3-4b-it.yaml\": [\n        \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n        
\"google/gemma-3-4b-it\",\n        \"unsloth/gemma-3-4b-it-bnb-4bit\",\n    ],\n    \"unsloth_gemma-3-27b-it.yaml\": [\n        \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n        \"google/gemma-3-27b-it\",\n        \"unsloth/gemma-3-27b-it-bnb-4bit\",\n    ],\n    \"unsloth_gemma-3-270m-it.yaml\": [\n        \"unsloth/gemma-3-270m-it-unsloth-bnb-4bit\",\n        \"google/gemma-3-270m-it\",\n        \"unsloth/gemma-3-270m-it-bnb-4bit\",\n    ],\n    \"unsloth_gemma-3n-E4B-it.yaml\": [\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n        \"google/gemma-3n-E4B-it\",\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n    ],\n    \"unsloth_gemma-3n-E4B.yaml\": [\n        \"unsloth/gemma-3n-E4B-unsloth-bnb-4bit\",\n        \"google/gemma-3n-E4B\",\n    ],\n    \"unsloth_gpt-oss-20b.yaml\": [\n        \"openai/gpt-oss-20b\",\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-20b-BF16\",\n    ],\n    \"unsloth_gpt-oss-120b.yaml\": [\n        \"openai/gpt-oss-120b\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n    ],\n    \"unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml\": [\n        \"unsloth/granite-4.0-350m\",\n        \"ibm-granite/granite-4.0-350m\",\n        \"unsloth/granite-4.0-350m-bnb-4bit\",\n    ],\n    \"unsloth_granite-4.0-h-micro.yaml\": [\n        \"ibm-granite/granite-4.0-h-micro\",\n        \"unsloth/granite-4.0-h-micro-bnb-4bit\",\n        \"unsloth/granite-4.0-h-micro-unsloth-bnb-4bit\",\n    ],\n    \"unsloth_LFM2-1.2B.yaml\": [\n        \"unsloth/LFM2-1.2B\",\n    ],\n    \"unsloth_llama-3-8b-bnb-4bit.yaml\": [\n        \"unsloth/llama-3-8b\",\n        \"meta-llama/Meta-Llama-3-8B\",\n    ],\n    \"unsloth_llama-3-8b-Instruct-bnb-4bit.yaml\": [\n        \"unsloth/llama-3-8b-Instruct\",\n        \"meta-llama/Meta-Llama-3-8B-Instruct\",\n    ],\n    \"unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml\": [\n        \"unsloth/Meta-Llama-3.1-8B-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3.1-8B\",\n        \"unsloth/Meta-Llama-3.1-70B-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-8B\",\n        \"unsloth/Meta-Llama-3.1-70B\",\n        \"meta-llama/Meta-Llama-3.1-70B\",\n        \"unsloth/Meta-Llama-3.1-405B-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3.1-405B\",\n    ],\n    \"unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml\": [\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct\",\n        \"RedHatAI/Llama-3.1-8B-Instruct-FP8\",\n        \"unsloth/Llama-3.1-8B-Instruct-FP8-Block\",\n        \"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic\",\n    ],\n    \"unsloth_Llama-3.2-3B-Instruct.yaml\": [\n        \"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit\",\n        \"meta-llama/Llama-3.2-3B-Instruct\",\n        \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\",\n        \"RedHatAI/Llama-3.2-3B-Instruct-FP8\",\n        \"unsloth/Llama-3.2-3B-Instruct-FP8-Block\",\n        \"unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic\",\n    ],\n    \"unsloth_Llama-3.2-1B-Instruct.yaml\": [\n        \"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit\",\n        \"meta-llama/Llama-3.2-1B-Instruct\",\n        \"unsloth/Llama-3.2-1B-Instruct-bnb-4bit\",\n        \"RedHatAI/Llama-3.2-1B-Instruct-FP8\",\n        \"unsloth/Llama-3.2-1B-Instruct-FP8-Block\",\n        \"unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic\",\n    ],\n    
\"unsloth_Llama-3.2-11B-Vision-Instruct.yaml\": [\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit\",\n        \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n    ],\n    \"unsloth_Llama-3.3-70B-Instruct.yaml\": [\n        \"unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit\",\n        \"meta-llama/Llama-3.3-70B-Instruct\",\n        \"unsloth/Llama-3.3-70B-Instruct-bnb-4bit\",\n        \"RedHatAI/Llama-3.3-70B-Instruct-FP8\",\n        \"unsloth/Llama-3.3-70B-Instruct-FP8-Block\",\n        \"unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic\",\n    ],\n    \"unsloth_Llasa-3B.yaml\": [\n        \"HKUSTAudio/Llasa-1B\",\n        \"unsloth/Llasa-3B\",\n    ],\n    \"unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml\": [\n        \"unsloth/Magistral-Small-2509\",\n        \"mistralai/Magistral-Small-2509\",\n        \"unsloth/Magistral-Small-2509-bnb-4bit\",\n    ],\n    \"unsloth_Ministral-3-3B-Instruct-2512.yaml\": [\n        \"unsloth/Ministral-3-3B-Instruct-2512\",\n    ],\n    \"unsloth_mistral-7b-v0.3-bnb-4bit.yaml\": [\n        \"unsloth/mistral-7b-v0.3-bnb-4bit\",\n        \"unsloth/mistral-7b-v0.3\",\n        \"mistralai/Mistral-7B-v0.3\",\n    ],\n    \"unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml\": [\n        \"unsloth/Mistral-Nemo-Base-2407-bnb-4bit\",\n        \"unsloth/Mistral-Nemo-Base-2407\",\n        \"mistralai/Mistral-Nemo-Base-2407\",\n        \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\",\n        \"unsloth/Mistral-Nemo-Instruct-2407\",\n        \"mistralai/Mistral-Nemo-Instruct-2407\",\n    ],\n    \"unsloth_Mistral-Small-Instruct-2409.yaml\": [\n        \"unsloth/Mistral-Small-Instruct-2409-bnb-4bit\",\n        \"mistralai/Mistral-Small-Instruct-2409\",\n    ],\n    \"unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml\": [\n        \"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.3\",\n        \"mistralai/Mistral-7B-Instruct-v0.3\",\n    ],\n    \"unsloth_Qwen2.5-1.5B-Instruct.yaml\": [\n        \"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit\",\n        \"Qwen/Qwen2.5-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n    ],\n    \"unsloth_Nemotron-3-Nano-30B-A3B.yaml\": [\n        \"unsloth/Nemotron-3-Nano-30B-A3B\",\n    ],\n    \"unsloth_orpheus-3b-0.1-ft.yaml\": [\n        \"unsloth/orpheus-3b-0.1-ft\",\n        \"unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit\",\n        \"canopylabs/orpheus-3b-0.1-ft\",\n        \"unsloth/orpheus-3b-0.1-ft-bnb-4bit\",\n    ],\n    \"OuteAI_Llama-OuteTTS-1.0-1B.yaml\": [\n        \"OuteAI/Llama-OuteTTS-1.0-1B\",\n        \"unsloth/Llama-OuteTTS-1.0-1B\",\n        \"unsloth/llama-outetts-1.0-1b\",\n        \"OuteAI/OuteTTS-1.0-0.6B\",\n        \"unsloth/OuteTTS-1.0-0.6B\",\n        \"unsloth/outetts-1.0-0.6b\",\n    ],\n    \"unsloth_PaddleOCR-VL.yaml\": [\n        \"unsloth/PaddleOCR-VL\",\n    ],\n    \"unsloth_Phi-3-medium-4k-instruct.yaml\": [\n        \"unsloth/Phi-3-medium-4k-instruct-bnb-4bit\",\n        \"microsoft/Phi-3-medium-4k-instruct\",\n    ],\n    \"unsloth_Phi-3.5-mini-instruct.yaml\": [\n        \"unsloth/Phi-3.5-mini-instruct-bnb-4bit\",\n        \"microsoft/Phi-3.5-mini-instruct\",\n    ],\n    \"unsloth_Phi-4.yaml\": [\n        \"unsloth/phi-4-unsloth-bnb-4bit\",\n        \"microsoft/phi-4\",\n        \"unsloth/phi-4-bnb-4bit\",\n    ],\n    \"unsloth_Pixtral-12B-2409.yaml\": [\n        \"unsloth/Pixtral-12B-2409-unsloth-bnb-4bit\",\n        \"mistralai/Pixtral-12B-2409\",\n        
\"unsloth/Pixtral-12B-2409-bnb-4bit\",\n    ],\n    \"unsloth_Qwen2-7B.yaml\": [\n        \"unsloth/Qwen2-7B-bnb-4bit\",\n        \"Qwen/Qwen2-7B\",\n    ],\n    \"unsloth_Qwen2-VL-7B-Instruct.yaml\": [\n        \"unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit\",\n        \"Qwen/Qwen2-VL-7B-Instruct\",\n        \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n    ],\n    \"unsloth_Qwen2.5-7B.yaml\": [\n        \"unsloth/Qwen2.5-7B-unsloth-bnb-4bit\",\n        \"Qwen/Qwen2.5-7B\",\n        \"unsloth/Qwen2.5-7B-bnb-4bit\",\n    ],\n    \"unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml\": [\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit\",\n        \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n    ],\n    \"unsloth_Qwen2.5-Coder-14B-Instruct.yaml\": [\n        \"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit\",\n        \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n    ],\n    \"unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml\": [\n        \"unsloth/Qwen2.5-VL-7B-Instruct\",\n        \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit\",\n    ],\n    \"unsloth_Qwen3-0.6B.yaml\": [\n        \"unsloth/Qwen3-0.6B-unsloth-bnb-4bit\",\n        \"Qwen/Qwen3-0.6B\",\n        \"unsloth/Qwen3-0.6B-bnb-4bit\",\n        \"Qwen/Qwen3-0.6B-FP8\",\n        \"unsloth/Qwen3-0.6B-FP8\",\n    ],\n    \"unsloth_Qwen3-4B-Instruct-2507.yaml\": [\n        \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\",\n        \"Qwen/Qwen3-4B-Instruct-2507\",\n        \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n        \"Qwen/Qwen3-4B-Instruct-2507-FP8\",\n        \"unsloth/Qwen3-4B-Instruct-2507-FP8\",\n    ],\n    \"unsloth_Qwen3-4B-Thinking-2507.yaml\": [\n        \"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit\",\n        \"Qwen/Qwen3-4B-Thinking-2507\",\n        \"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit\",\n        \"Qwen/Qwen3-4B-Thinking-2507-FP8\",\n        \"unsloth/Qwen3-4B-Thinking-2507-FP8\",\n    ],\n    \"unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml\": [\n        \"unsloth/Qwen3-14B-Base\",\n        \"Qwen/Qwen3-14B-Base\",\n        \"unsloth/Qwen3-14B-Base-bnb-4bit\",\n    ],\n    \"unsloth_Qwen3-14B.yaml\": [\n        \"unsloth/Qwen3-14B-unsloth-bnb-4bit\",\n        \"Qwen/Qwen3-14B\",\n        \"unsloth/Qwen3-14B-bnb-4bit\",\n        \"Qwen/Qwen3-14B-FP8\",\n        \"unsloth/Qwen3-14B-FP8\",\n    ],\n    \"unsloth_Qwen3-32B.yaml\": [\n        \"unsloth/Qwen3-32B-unsloth-bnb-4bit\",\n        \"Qwen/Qwen3-32B\",\n        \"unsloth/Qwen3-32B-bnb-4bit\",\n        \"Qwen/Qwen3-32B-FP8\",\n        \"unsloth/Qwen3-32B-FP8\",\n    ],\n    \"unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml\": [\n        \"Qwen/Qwen3-VL-8B-Instruct-FP8\",\n        \"unsloth/Qwen3-VL-8B-Instruct-FP8\",\n        \"unsloth/Qwen3-VL-8B-Instruct\",\n        \"Qwen/Qwen3-VL-8B-Instruct\",\n        \"unsloth/Qwen3-VL-8B-Instruct-bnb-4bit\",\n    ],\n    \"sesame_csm-1b.yaml\": [\n        \"sesame/csm-1b\",\n        \"unsloth/csm-1b\",\n    ],\n    \"Spark-TTS-0.5B_LLM.yaml\": [\n        \"Spark-TTS-0.5B/LLM\",\n        \"unsloth/Spark-TTS-0.5B\",\n    ],\n    \"unsloth_tinyllama-bnb-4bit.yaml\": [\n        \"unsloth/tinyllama\",\n        \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\",\n    ],\n    \"unsloth_whisper-large-v3.yaml\": [\n        \"unsloth/whisper-large-v3\",\n        \"openai/whisper-large-v3\",\n    ],\n}\n\n# Reverse mapping for quick lookup: model_name -> canonical_filename\n_REVERSE_MODEL_MAPPING = {}\nfor canonical_file, model_names in MODEL_NAME_MAPPING.items():\n    for model_name in 
model_names:\n        _REVERSE_MODEL_MAPPING[model_name.lower()] = canonical_file\n\n\ndef load_model_config(\n    model_name: str,\n    use_auth: bool = False,\n    token: Optional[str] = None,\n    trust_remote_code: bool = True,\n):\n    \"\"\"\n    Load model config with optional authentication control.\n    \"\"\"\n\n    if token:\n        # Explicit token provided - use it\n        return AutoConfig.from_pretrained(\n            model_name, trust_remote_code = trust_remote_code, token = token\n        )\n\n    if not use_auth:\n        # Load without any authentication (for public model checks)\n        with without_hf_auth():\n            return AutoConfig.from_pretrained(\n                model_name,\n                trust_remote_code = trust_remote_code,\n                token = None,\n            )\n\n    # Use default authentication (cached tokens)\n    return AutoConfig.from_pretrained(\n        model_name,\n        trust_remote_code = trust_remote_code,\n    )\n\n\n# VLM architecture suffixes and known VLM model_type values.\n_VLM_ARCH_SUFFIXES = (\"ForConditionalGeneration\", \"ForVisionText2Text\")\n_VLM_MODEL_TYPES = {\n    \"phi3_v\",\n    \"llava\",\n    \"llava_next\",\n    \"llava_onevision\",\n    \"internvl_chat\",\n    \"cogvlm2\",\n    \"minicpmv\",\n}\n\n# Pre-computed .venv_t5 path and backend dir for subprocess version switching.\n_VENV_T5_DIR = str(Path.home() / \".unsloth\" / \"studio\" / \".venv_t5\")\n_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent)\n\n# Inline script executed in a subprocess with transformers 5.x activated.\n# Receives model_name and token via argv, prints JSON result to stdout.\n_VISION_CHECK_SCRIPT = r\"\"\"\nimport sys, os, json\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n# Activate transformers 5.x\nvenv_t5 = sys.argv[1]\nbackend_dir = sys.argv[2]\nmodel_name = sys.argv[3]\ntoken = sys.argv[4] if len(sys.argv) > 4 and sys.argv[4] != \"\" else None\n\nsys.path.insert(0, venv_t5)\nif backend_dir not in sys.path:\n    sys.path.insert(0, backend_dir)\n\ntry:\n    from transformers import AutoConfig\n    kwargs = {\"trust_remote_code\": True}\n    if token:\n        kwargs[\"token\"] = token\n    config = AutoConfig.from_pretrained(model_name, **kwargs)\n\n    is_vlm = False\n    if hasattr(config, \"architectures\"):\n        is_vlm = any(\n            x.endswith((\"ForConditionalGeneration\", \"ForVisionText2Text\"))\n            for x in config.architectures\n        )\n    if not is_vlm and hasattr(config, \"vision_config\"):\n        is_vlm = True\n    if not is_vlm and hasattr(config, \"img_processor\"):\n        is_vlm = True\n    if not is_vlm and hasattr(config, \"image_token_index\"):\n        is_vlm = True\n    if not is_vlm and hasattr(config, \"model_type\"):\n        vlm_types = {\"phi3_v\",\"llava\",\"llava_next\",\"llava_onevision\",\n                      \"internvl_chat\",\"cogvlm2\",\"minicpmv\"}\n        if config.model_type in vlm_types:\n            is_vlm = True\n\n    model_type = getattr(config, \"model_type\", \"unknown\")\n    archs = getattr(config, \"architectures\", [])\n    print(json.dumps({\"is_vision\": is_vlm, \"model_type\": model_type,\n                       \"architectures\": archs}))\nexcept Exception as exc:\n    print(json.dumps({\"error\": str(exc)}))\n    sys.exit(1)\n\"\"\"\n\n\ndef _is_vision_model_subprocess(\n    model_name: str, hf_token: Optional[str] = None\n) -> bool:\n    \"\"\"Run is_vision_model check in a subprocess with transformers 5.x.\n\n    Same pattern 
as training/inference workers: spawn a clean subprocess\n    with .venv_t5/ prepended to sys.path so AutoConfig recognizes newer\n    architectures (glm4_moe_lite, etc.).\n    \"\"\"\n    token_arg = hf_token or \"\"\n\n    try:\n        result = subprocess.run(\n            [\n                sys.executable,\n                \"-c\",\n                _VISION_CHECK_SCRIPT,\n                _VENV_T5_DIR,\n                _BACKEND_DIR,\n                model_name,\n                token_arg,\n            ],\n            capture_output = True,\n            text = True,\n            timeout = 60,\n        )\n\n        if result.returncode != 0:\n            stderr = result.stderr.strip()\n            logger.warning(\n                \"Vision check subprocess failed for '%s': %s\",\n                model_name,\n                stderr or result.stdout.strip(),\n            )\n            return False\n\n        data = json.loads(result.stdout.strip())\n        if \"error\" in data:\n            logger.warning(\n                \"Vision check subprocess error for '%s': %s\",\n                model_name,\n                data[\"error\"],\n            )\n            return False\n\n        is_vlm = data[\"is_vision\"]\n        logger.info(\n            \"Vision check (subprocess, transformers 5.x) for '%s': \"\n            \"model_type=%s, architectures=%s, is_vision=%s\",\n            model_name,\n            data.get(\"model_type\"),\n            data.get(\"architectures\"),\n            is_vlm,\n        )\n        return is_vlm\n\n    except subprocess.TimeoutExpired:\n        logger.warning(\"Vision check subprocess timed out for '%s'\", model_name)\n        return False\n    except Exception as exc:\n        logger.warning(\"Vision check subprocess failed for '%s': %s\", model_name, exc)\n        return False\n\n\ndef is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:\n    \"\"\"\n    Detect vision-language models (VLMs) by checking architecture in config.\n    Works for fine-tuned models since they inherit the base architecture.\n\n    For models that require transformers 5.x (e.g. GLM-4.7-Flash), the check\n    runs in a subprocess with .venv_t5/ activated — same pattern as the\n    training and inference workers.\n\n    Args:\n        model_name: Model identifier (HF repo or local path)\n        hf_token: Optional HF token for accessing gated/private models\n    \"\"\"\n    # Models that need transformers 5.x must be checked in a subprocess\n    # because AutoConfig in the main process (transformers 4.57.x) doesn't\n    # recognize their architectures.\n    from utils.transformers_version import needs_transformers_5\n\n    if needs_transformers_5(model_name):\n        logger.info(\n            \"Model '%s' needs transformers 5.x — checking vision via subprocess\",\n            model_name,\n        )\n        return _is_vision_model_subprocess(model_name, hf_token = hf_token)\n\n    try:\n        config = load_model_config(model_name, use_auth = True, token = hf_token)\n\n        # Exclude audio-only models that share ForConditionalGeneration suffix\n        # (e.g. 
CsmForConditionalGeneration, WhisperForConditionalGeneration)\n        _audio_only_model_types = {\"csm\", \"whisper\"}\n        model_type = getattr(config, \"model_type\", None)\n        if model_type in _audio_only_model_types:\n            return False\n\n        # Check 1: Architecture class name patterns\n        if hasattr(config, \"architectures\"):\n            is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)\n            if is_vlm:\n                logger.info(\n                    f\"Model {model_name} detected as VLM: architecture {config.architectures}\"\n                )\n                return True\n\n        # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)\n        if hasattr(config, \"vision_config\"):\n            logger.info(f\"Model {model_name} detected as VLM: has vision_config\")\n            return True\n\n        # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)\n        if hasattr(config, \"img_processor\"):\n            logger.info(f\"Model {model_name} detected as VLM: has img_processor\")\n            return True\n\n        # Check 4: Has image_token_index (common in VLMs for image placeholder tokens)\n        if hasattr(config, \"image_token_index\"):\n            logger.info(f\"Model {model_name} detected as VLM: has image_token_index\")\n            return True\n\n        # Check 5: Known VLM model_type values that may not match above checks\n        if hasattr(config, \"model_type\"):\n            if config.model_type in _VLM_MODEL_TYPES:\n                logger.info(\n                    f\"Model {model_name} detected as VLM: model_type={config.model_type}\"\n                )\n                return True\n\n        return False\n\n    except Exception as e:\n        logger.warning(f\"Could not determine if {model_name} is vision model: {e}\")\n        return False\n\n\nVALID_AUDIO_TYPES = (\"snac\", \"csm\", \"bicodec\", \"dac\", \"whisper\", \"audio_vlm\")\n\n# Cache detection results per session to avoid repeated API calls\n_audio_detection_cache: Dict[str, Optional[str]] = {}\n\n# Tokenizer token patterns → audio_type (all 6 types detected from tokenizer_config.json)\n_AUDIO_TOKEN_PATTERNS = {\n    \"csm\": lambda tokens: \"<|AUDIO|>\" in tokens and \"<|audio_eos|>\" in tokens,\n    \"whisper\": lambda tokens: \"<|startoftranscript|>\" in tokens,\n    \"audio_vlm\": lambda tokens: \"<audio_soft_token>\" in tokens,\n    \"bicodec\": lambda tokens: any(t.startswith(\"<|bicodec_\") for t in tokens),\n    \"dac\": lambda tokens: \"<|audio_start|>\" in tokens\n    and \"<|audio_end|>\" in tokens\n    and \"<|text_start|>\" in tokens\n    and \"<|text_end|>\" in tokens,\n    \"snac\": lambda tokens: sum(1 for t in tokens if t.startswith(\"<custom_token_\"))\n    > 10000,\n}\n\n\ndef detect_audio_type(model_name: str, hf_token: Optional[str] = None) -> Optional[str]:\n    \"\"\"\n    Dynamically detect if a model is an audio model and return its type.\n\n    Fully dynamic — works for any model, not just known ones.\n    Uses tokenizer_config.json special tokens to detect all 6 audio types.\n\n    Returns: audio_type string ('snac', 'csm', 'bicodec', 'dac', 'whisper', 'audio_vlm') or None.\n    \"\"\"\n    if model_name in _audio_detection_cache:\n        return _audio_detection_cache[model_name]\n\n    result = _detect_audio_from_tokenizer(model_name, hf_token)\n\n    _audio_detection_cache[model_name] = result\n    if result:\n        logger.info(f\"Model {model_name} detected 
as audio model: audio_type={result}\")\n    return result\n\n\ndef _detect_audio_from_tokenizer(\n    model_name: str, hf_token: Optional[str] = None\n) -> Optional[str]:\n    \"\"\"Detect audio type from tokenizer special tokens (for LLM-based audio models).\n\n    First checks local HF cache, then fetches tokenizer_config.json from HuggingFace.\n    Checks added_tokens_decoder for distinctive patterns.\n    \"\"\"\n\n    def _check_token_patterns(tok_config: dict) -> Optional[str]:\n        added = tok_config.get(\"added_tokens_decoder\", {})\n        if not added:\n            return None\n        token_contents = [v.get(\"content\", \"\") for v in added.values()]\n        for audio_type, check_fn in _AUDIO_TOKEN_PATTERNS.items():\n            if check_fn(token_contents):\n                return audio_type\n        return None\n\n    # 1) Check local HF cache first (works for gated/offline models)\n    try:\n        from huggingface_hub.constants import HF_HUB_CACHE\n\n        cache_dir = Path(HF_HUB_CACHE)\n        repo_dir_name = f\"models--{model_name.replace('/', '--')}\"\n        repo_dir = cache_dir / repo_dir_name\n        if repo_dir.exists():\n            snapshots_dir = repo_dir / \"snapshots\"\n            if snapshots_dir.exists():\n                for snapshot in snapshots_dir.iterdir():\n                    for tok_path in [\n                        \"tokenizer_config.json\",\n                        \"LLM/tokenizer_config.json\",\n                    ]:\n                        tok_file = snapshot / tok_path\n                        if tok_file.exists():\n                            tok_config = json.loads(tok_file.read_text())\n                            result = _check_token_patterns(tok_config)\n                            if result:\n                                return result\n    except Exception as e:\n        logger.debug(f\"Could not check local cache for {model_name}: {e}\")\n\n    # 2) Fall back to HuggingFace API\n    try:\n        import requests\n        import os\n\n        paths_to_try = [\"tokenizer_config.json\", \"LLM/tokenizer_config.json\"]\n        # Use provided token, or fall back to env\n        token = hf_token or os.environ.get(\"HF_TOKEN\")\n        headers = {}\n        if token:\n            headers[\"Authorization\"] = f\"Bearer {token}\"\n\n        for tok_path in paths_to_try:\n            url = f\"https://huggingface.co/{model_name}/resolve/main/{tok_path}\"\n            resp = requests.get(url, headers = headers, timeout = 15)\n            if not resp.ok:\n                continue\n\n            tok_config = resp.json()\n            result = _check_token_patterns(tok_config)\n            if result:\n                return result\n\n        return None\n    except Exception as e:\n        logger.debug(\n            f\"Could not detect audio type from tokenizer for {model_name}: {e}\"\n        )\n        return None\n\n\ndef is_audio_input_type(audio_type: Optional[str]) -> bool:\n    \"\"\"Check if an audio_type accepts audio input (ASR/speech understanding).\n\n    Whisper (ASR) and audio_vlm (Gemma3n) accept audio input.\n    \"\"\"\n    return audio_type in (\"whisper\", \"audio_vlm\")\n\n\ndef _is_mmproj(filename: str) -> bool:\n    \"\"\"Check if a GGUF filename is a vision projection (mmproj) file.\"\"\"\n    return \"mmproj\" in filename.lower()\n\n\ndef detect_mmproj_file(path: str) -> Optional[str]:\n    \"\"\"\n    Find the mmproj (vision projection) GGUF file in a directory.\n\n    Args:\n        path: Directory to search — 
or a .gguf file (uses its parent dir).\n\n    Returns:\n        Full path to the mmproj .gguf file, or None if not found.\n    \"\"\"\n    p = Path(path)\n    search_dir = p.parent if p.is_file() else p\n    if not search_dir.is_dir():\n        return None\n\n    for f in search_dir.glob(\"*.gguf\"):\n        if _is_mmproj(f.name):\n            return str(f.resolve())\n    return None\n\n\ndef detect_gguf_model(path: str) -> Optional[str]:\n    \"\"\"\n    Check if the given local path is or contains a GGUF model file.\n\n    Handles two cases:\n    1. path is a direct .gguf file path\n    2. path is a directory containing .gguf files\n\n    Skips mmproj (vision projection) files — those must be passed via\n    ``--mmproj``, not ``-m``.  Use :func:`detect_mmproj_file` instead.\n\n    Returns the full path to the .gguf file if found, None otherwise.\n    For HuggingFace repo detection, use detect_gguf_model_remote() instead.\n    \"\"\"\n    p = Path(path)\n\n    # Case 1: direct .gguf file\n    if p.suffix == \".gguf\" and p.is_file():\n        if _is_mmproj(p.name):\n            return None\n        return str(p.resolve())\n\n    # Case 2: directory containing .gguf files (skip mmproj)\n    if p.is_dir():\n        gguf_files = sorted(\n            (f for f in p.glob(\"*.gguf\") if not _is_mmproj(f.name)),\n            key = lambda f: f.stat().st_size,\n            reverse = True,\n        )\n        if gguf_files:\n            return str(gguf_files[0].resolve())\n\n    return None\n\n\n# Preferred GGUF quantization levels, in descending priority.\n# Q4_K_M is a good default: small, fast, acceptable quality.\n# UD (Unsloth Dynamic) variants are always preferred over standard quants\n# because they provide better quality per bit. If the repo has no UD variants\n# (e.g., bartowski repos), the standard quants are used as fallback.\n# Ordered by best size/quality tradeoff, not raw quality.\n_GGUF_QUANT_PREFERENCE = [\n    # UD variants (best quality per bit) -- Q4 is the sweet spot\n    \"UD-Q4_K_XL\",\n    \"UD-Q4_K_L\",\n    \"UD-Q5_K_XL\",\n    \"UD-Q3_K_XL\",\n    \"UD-Q6_K_XL\",\n    \"UD-Q6_K_S\",\n    \"UD-Q8_K_XL\",\n    \"UD-Q2_K_XL\",\n    \"UD-IQ4_NL\",\n    \"UD-IQ4_XS\",\n    \"UD-IQ3_S\",\n    \"UD-IQ3_XXS\",\n    \"UD-IQ2_M\",\n    \"UD-IQ2_XXS\",\n    \"UD-IQ1_M\",\n    \"UD-IQ1_S\",\n    # Standard quants (fallback for non-Unsloth repos)\n    \"Q4_K_M\",\n    \"Q4_K_S\",\n    \"Q5_K_M\",\n    \"Q5_K_S\",\n    \"Q6_K\",\n    \"Q8_0\",\n    \"Q3_K_M\",\n    \"Q3_K_L\",\n    \"Q3_K_S\",\n    \"Q2_K\",\n    \"Q2_K_L\",\n    \"IQ4_NL\",\n    \"IQ4_XS\",\n    \"IQ3_M\",\n    \"IQ3_XXS\",\n    \"IQ2_M\",\n    \"IQ1_M\",\n    \"F16\",\n    \"BF16\",\n    \"F32\",\n]\n\n\ndef _pick_best_gguf(filenames: list[str]) -> Optional[str]:\n    \"\"\"\n    Pick the best GGUF file from a list of filenames.\n\n    Prefers quantization levels in _GGUF_QUANT_PREFERENCE order.\n    Falls back to the first .gguf file found.\n    \"\"\"\n    gguf_files = [f for f in filenames if f.endswith(\".gguf\")]\n    if not gguf_files:\n        return None\n\n    # Try preferred quantization levels\n    for quant in _GGUF_QUANT_PREFERENCE:\n        for f in gguf_files:\n            if quant in f:\n                return f\n\n    # Fallback: first GGUF file\n    return gguf_files[0]\n\n\n@dataclass\nclass GgufVariantInfo:\n    \"\"\"A single GGUF quantization variant from a HuggingFace repo.\"\"\"\n\n    filename: str  # e.g., \"gemma-3-4b-it-Q4_K_M.gguf\"\n    quant: str  # e.g., \"Q4_K_M\" (extracted from 
filename)\n    size_bytes: int  # file size\n\n\ndef _extract_quant_label(filename: str) -> str:\n    \"\"\"\n    Extract quantization label like Q4_K_M, IQ4_XS, BF16 from a GGUF filename.\n\n    Examples:\n        \"gemma-3-4b-it-Q4_K_M.gguf\"          → \"Q4_K_M\"\n        \"model-IQ4_NL.gguf\"                   → \"IQ4_NL\"\n        \"model-BF16.gguf\"                     → \"BF16\"\n        \"model-UD-IQ1_S.gguf\"                 → \"UD-IQ1_S\"\n        \"model-UD-TQ1_0.gguf\"                 → \"UD-TQ1_0\"\n        \"MXFP4_MOE/model-MXFP4_MOE-0001.gguf\"→ \"MXFP4_MOE\"\n    \"\"\"\n    import re\n\n    # Use only the basename (rfilename may include directory)\n    basename = filename.rsplit(\"/\", 1)[-1]\n    # Strip .gguf and any shard suffix (-00001-of-00010)\n    stem = re.sub(r\"-\\d{3,}-of-\\d{3,}\", \"\", basename.rsplit(\".\", 1)[0])\n    # Match known quantization patterns\n    match = re.search(\n        r\"(UD-)?\"  # Optional UD- prefix (Ultra Discrete)\n        r\"(MXFP[0-9]+(?:_[A-Z0-9]+)*\"  # MXFP variants: MXFP4, MXFP4_MOE\n        r\"|IQ[0-9]+_[A-Z]+(?:_[A-Z0-9]+)?\"  # IQ variants: IQ4_XS, IQ4_NL, IQ1_S\n        r\"|TQ[0-9]+_[0-9]+\"  # Ternary quant: TQ1_0, TQ2_0\n        r\"|Q[0-9]+_K_[A-Z]+\"  # K-quant: Q4_K_M, Q3_K_S\n        r\"|Q[0-9]+_[0-9]+\"  # Standard: Q8_0, Q5_1\n        r\"|Q[0-9]+_K\"  # Short K-quant: Q6_K\n        r\"|BF16|F16|F32)\",  # Full precision\n        stem,\n        re.IGNORECASE,\n    )\n    if match:\n        prefix = match.group(1) or \"\"\n        return f\"{prefix}{match.group(2)}\"\n    # Fallback: last segment after hyphen\n    return stem.split(\"-\")[-1]\n\n\ndef list_gguf_variants(\n    repo_id: str,\n    hf_token: Optional[str] = None,\n) -> tuple[list[GgufVariantInfo], bool]:\n    \"\"\"\n    List all GGUF quantization variants in a HuggingFace repo.\n\n    Separates main model files from mmproj (vision projection) files.\n    The presence of mmproj files indicates a vision-capable model.\n\n    Returns:\n        (variants, has_vision): list of non-mmproj GGUF variants + vision flag.\n    \"\"\"\n    from huggingface_hub import model_info as hf_model_info\n\n    info = hf_model_info(repo_id, token = hf_token, files_metadata = True)\n    variants: list[GgufVariantInfo] = []\n    has_vision = False\n\n    quant_totals: dict[str, int] = {}  # quant -> total bytes\n    quant_first_file: dict[str, str] = {}  # quant -> first filename (for display)\n\n    for sibling in info.siblings:\n        fname = sibling.rfilename\n        if not fname.endswith(\".gguf\"):\n            continue\n        size = sibling.size or 0\n\n        # mmproj files are vision projection models, not main model files\n        if \"mmproj\" in fname.lower():\n            has_vision = True\n            continue\n\n        quant = _extract_quant_label(fname)\n        quant_totals[quant] = quant_totals.get(quant, 0) + size\n        if quant not in quant_first_file:\n            quant_first_file[quant] = fname\n\n    for quant, total_size in quant_totals.items():\n        variants.append(\n            GgufVariantInfo(\n                filename = quant_first_file[quant],\n                quant = quant,\n                size_bytes = total_size,\n            )\n        )\n\n    # Sort by size descending (largest = best quality first).\n    # Recommended pinning and OOM demotion are handled client-side\n    # where GPU VRAM info is available.\n    variants.sort(key = lambda v: -v.size_bytes)\n\n    return variants, has_vision\n\n\ndef detect_gguf_model_remote(\n    
repo_id: str,\n    hf_token: Optional[str] = None,\n) -> Optional[str]:\n    \"\"\"\n    Check if a HuggingFace repo contains GGUF files.\n\n    Returns the filename of the best GGUF file in the repo, or None.\n    \"\"\"\n    try:\n        from huggingface_hub import model_info as hf_model_info\n\n        info = hf_model_info(repo_id, token = hf_token)\n        repo_files = [s.rfilename for s in info.siblings]\n        return _pick_best_gguf(repo_files)\n    except Exception as e:\n        logger.debug(f\"Could not check GGUF files for '{repo_id}': {e}\")\n        return None\n\n\ndef download_gguf_file(\n    repo_id: str,\n    filename: str,\n    hf_token: Optional[str] = None,\n) -> str:\n    \"\"\"\n    Download a specific GGUF file from a HuggingFace repo.\n\n    Returns the local path to the downloaded file.\n    \"\"\"\n    from huggingface_hub import hf_hub_download\n\n    local_path = hf_hub_download(\n        repo_id = repo_id,\n        filename = filename,\n        token = hf_token,\n    )\n    return local_path\n\n\n# Cache embedding detection results per session to avoid repeated HF API calls\n_embedding_detection_cache: Dict[tuple, bool] = {}\n\n\ndef is_embedding_model(model_name: str, hf_token: Optional[str] = None) -> bool:\n    \"\"\"\n    Detect embedding/sentence-transformer models using HuggingFace model metadata.\n\n    Uses a belt-and-suspenders approach combining three signals:\n      1. \"sentence-transformers\" in model tags\n      2. \"feature-extraction\" in model tags\n      3. pipeline_tag is \"sentence-similarity\" or \"feature-extraction\"\n\n    This catches all known embedding models including those like gte-modernbert\n    whose library_name is \"transformers\" rather than \"sentence-transformers\".\n\n    Args:\n        model_name: Model identifier (HF repo or local path)\n        hf_token: Optional HF token for accessing gated/private models\n\n    Returns:\n        True if the model is an embedding model, False otherwise.\n        Defaults to False for local paths or on errors.\n    \"\"\"\n    cache_key = (model_name, hf_token)\n    if cache_key in _embedding_detection_cache:\n        return _embedding_detection_cache[cache_key]\n\n    # Local paths: check for sentence-transformer marker file (modules.json)\n    if is_local_path(model_name):\n        local_dir = normalize_path(model_name)\n        is_emb = os.path.isfile(os.path.join(local_dir, \"modules.json\"))\n        _embedding_detection_cache[cache_key] = is_emb\n        return is_emb\n\n    try:\n        from huggingface_hub import model_info as hf_model_info\n\n        info = hf_model_info(model_name, token = hf_token)\n        tags = set(info.tags or [])\n        pipeline_tag = info.pipeline_tag or \"\"\n\n        is_emb = (\n            \"sentence-transformers\" in tags\n            or \"feature-extraction\" in tags\n            or pipeline_tag in (\"sentence-similarity\", \"feature-extraction\")\n        )\n\n        _embedding_detection_cache[cache_key] = is_emb\n        if is_emb:\n            logger.info(\n                f\"Model {model_name} detected as embedding model: \"\n                f\"pipeline_tag={pipeline_tag}, \"\n                f\"sentence-transformers in tags={('sentence-transformers' in tags)}, \"\n                f\"feature-extraction in tags={('feature-extraction' in tags)}\"\n            )\n        return is_emb\n\n    except Exception as e:\n        logger.warning(f\"Could not determine if {model_name} is embedding model: {e}\")\n        
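# Cache the negative result as well so repeated checks don't re-query the Hub API.\n        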
_embedding_detection_cache[cache_key] = False\n        return False\n\n\ndef scan_trained_loras(outputs_dir: str = str(outputs_root())) -> List[Tuple[str, str]]:\n    \"\"\"\n    Scan outputs folder for trained LoRA adapters.\n\n    Returns:\n        List of tuples: [(display_name, adapter_path), ...]\n\n    Example:\n        [\n            (\"unsloth_Meta-Llama-3.1_...\", \"./outputs/unsloth_Meta-Llama-3.1_.../\"),\n            (\"my_finetuned_model\", \"./outputs/my_finetuned_model/\"),\n        ]\n    \"\"\"\n    trained_loras = []\n    outputs_path = resolve_output_dir(outputs_dir)\n\n    if not outputs_path.exists():\n        logger.warning(f\"Outputs directory not found: {outputs_dir}\")\n        return trained_loras\n\n    try:\n        for item in outputs_path.iterdir():\n            if item.is_dir():\n                # Check if this directory contains a LoRA adapter\n                adapter_config = item / \"adapter_config.json\"\n                adapter_model = item / \"adapter_model.safetensors\"\n\n                if adapter_config.exists() or adapter_model.exists():\n                    display_name = item.name\n                    adapter_path = str(item)\n                    trained_loras.append((display_name, adapter_path))\n                    logger.debug(f\"Found trained LoRA: {display_name}\")\n\n        # Sort by modification time (newest first)\n        trained_loras.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)\n\n        logger.info(\n            f\"Found {len(trained_loras)} trained LoRA adapters in {outputs_dir}\"\n        )\n        return trained_loras\n\n    except Exception as e:\n        logger.error(f\"Error scanning outputs folder: {e}\")\n        return []\n\n\ndef scan_exported_models(\n    exports_dir: str = str(exports_root()),\n) -> List[Tuple[str, str, str, Optional[str]]]:\n    \"\"\"\n    Scan exports folder for exported models (merged, LoRA, GGUF).\n\n    Supports two directory layouts:\n      - Two-level: {run}/{checkpoint}/  (merged & LoRA exports)\n      - Flat:      {name}-finetune-gguf/  (GGUF exports)\n\n    Returns:\n        List of tuples: [(display_name, model_path, export_type, base_model), ...]\n        export_type: \"lora\" | \"merged\" | \"gguf\"\n    \"\"\"\n    results = []\n    exports_path = resolve_export_dir(exports_dir)\n\n    if not exports_path.exists():\n        return results\n\n    try:\n        for run_dir in exports_path.iterdir():\n            if not run_dir.is_dir():\n                continue\n\n            # Check for flat GGUF export (e.g. 
exports/gemma-3-4b-it-finetune-gguf/)\n            # Filter out mmproj (vision projection) files — they aren't loadable as main models\n            gguf_files = [f for f in run_dir.glob(\"*.gguf\") if not _is_mmproj(f.name)]\n            if gguf_files:\n                base_model = None\n                export_meta = run_dir / \"export_metadata.json\"\n                try:\n                    if export_meta.exists():\n                        meta = json.loads(export_meta.read_text())\n                        base_model = meta.get(\"base_model\")\n                except Exception:\n                    pass\n\n                display_name = run_dir.name\n                model_path = str(gguf_files[0])  # path to the .gguf file\n                results.append((display_name, model_path, \"gguf\", base_model))\n                logger.debug(f\"Found GGUF export: {display_name}\")\n                continue\n\n            # Two-level: {run}/{checkpoint}/\n            for checkpoint_dir in run_dir.iterdir():\n                if not checkpoint_dir.is_dir():\n                    continue\n\n                adapter_config = checkpoint_dir / \"adapter_config.json\"\n                config_file = checkpoint_dir / \"config.json\"\n                has_weights = any(checkpoint_dir.glob(\"*.safetensors\")) or any(\n                    checkpoint_dir.glob(\"*.bin\")\n                )\n                has_gguf = any(checkpoint_dir.glob(\"*.gguf\"))\n\n                base_model = None\n                export_type = None\n\n                if adapter_config.exists():\n                    export_type = \"lora\"\n                    try:\n                        cfg = json.loads(adapter_config.read_text())\n                        base_model = cfg.get(\"base_model_name_or_path\")\n                    except Exception:\n                        pass\n                elif config_file.exists() and has_weights:\n                    export_type = \"merged\"\n                    export_meta = checkpoint_dir / \"export_metadata.json\"\n                    try:\n                        if export_meta.exists():\n                            meta = json.loads(export_meta.read_text())\n                            base_model = meta.get(\"base_model\")\n                    except Exception:\n                        pass\n                elif has_gguf:\n                    export_type = \"gguf\"\n                    gguf_list = list(checkpoint_dir.glob(\"*.gguf\"))\n                    # Check checkpoint_dir first, then fall back to parent run_dir\n                    # (export.py writes metadata to the top-level export directory)\n                    for meta_dir in (checkpoint_dir, run_dir):\n                        export_meta = meta_dir / \"export_metadata.json\"\n                        try:\n                            if export_meta.exists():\n                                meta = json.loads(export_meta.read_text())\n                                base_model = meta.get(\"base_model\")\n                                if base_model:\n                                    break\n                        except Exception:\n                            pass\n\n                    display_name = f\"{run_dir.name} / {checkpoint_dir.name}\"\n                    model_path = str(gguf_list[0]) if gguf_list else str(checkpoint_dir)\n                    results.append((display_name, model_path, export_type, base_model))\n                    logger.debug(f\"Found GGUF export: {display_name}\")\n                    continue\n         
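       # No recognised export artifacts here (adapter config, merged weights, or GGUF) - skip this directory.\n         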
       else:\n                    continue\n\n                # Fallback: read base model from the original training run's\n                # adapter_config.json in ./outputs/{run_name}/\n                if not base_model:\n                    outputs_adapter_cfg = (\n                        resolve_output_dir(run_dir.name) / \"adapter_config.json\"\n                    )\n                    try:\n                        if outputs_adapter_cfg.exists():\n                            cfg = json.loads(outputs_adapter_cfg.read_text())\n                            base_model = cfg.get(\"base_model_name_or_path\")\n                    except Exception:\n                        pass\n\n                display_name = f\"{run_dir.name} / {checkpoint_dir.name}\"\n                model_path = str(checkpoint_dir)\n                results.append((display_name, model_path, export_type, base_model))\n                logger.debug(f\"Found exported model: {display_name} ({export_type})\")\n\n        results.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)\n        logger.info(f\"Found {len(results)} exported models in {exports_dir}\")\n        return results\n\n    except Exception as e:\n        logger.error(f\"Error scanning exports folder: {e}\")\n        return []\n\n\ndef get_base_model_from_lora(lora_path: str) -> Optional[str]:\n    \"\"\"\n    Read the base model name from a LoRA adapter's config.\n\n    Args:\n        lora_path: Path to the LoRA adapter directory\n\n    Returns:\n        Base model identifier (e.g., \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\")\n        or None if not found\n\n    Example:\n        >>> get_base_model_from_lora(\"./outputs/unsloth_Meta-Llama-3.1_.../\")\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n    \"\"\"\n    try:\n        lora_path_obj = Path(lora_path)\n\n        # Try adapter_config.json first\n        adapter_config_path = lora_path_obj / \"adapter_config.json\"\n        if adapter_config_path.exists():\n            with open(adapter_config_path, \"r\") as f:\n                config = json.load(f)\n                base_model = config.get(\"base_model_name_or_path\")\n                if base_model:\n                    logger.info(\n                        f\"Detected base model from adapter_config.json: {base_model}\"\n                    )\n                    return base_model\n\n        # Fallback: try training_args.bin (requires torch)\n        training_args_path = lora_path_obj / \"training_args.bin\"\n        if training_args_path.exists():\n            try:\n                import torch\n\n                training_args = torch.load(training_args_path)\n                if hasattr(training_args, \"model_name_or_path\"):\n                    base_model = training_args.model_name_or_path\n                    logger.info(\n                        f\"Detected base model from training_args.bin: {base_model}\"\n                    )\n                    return base_model\n            except Exception as e:\n                logger.warning(f\"Could not load training_args.bin: {e}\")\n\n        # Last resort: parse from directory name\n        # Format: unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit_timestamp\n        dir_name = lora_path_obj.name\n        if dir_name.startswith(\"unsloth_\"):\n            # Remove timestamp suffix (usually _1234567890)\n            parts = dir_name.split(\"_\")\n            # Reconstruct model name\n            if len(parts) >= 2:\n                model_parts = parts[1:-1]  # Skip \"unsloth\" and 
timestamp\n                base_model = \"unsloth/\" + \"_\".join(model_parts)\n                logger.info(f\"Detected base model from directory name: {base_model}\")\n                return base_model\n\n        logger.warning(f\"Could not detect base model for LoRA: {lora_path}\")\n        return None\n\n    except Exception as e:\n        logger.error(f\"Error reading base model from LoRA config: {e}\")\n        return None\n\n\n# Status indicators that appear in UI dropdowns\nUI_STATUS_INDICATORS = [\" (Ready)\", \" (Loading...)\", \" (Active)\", \"↓ \"]\n\n\ndef load_model_defaults(model_name: str) -> Dict[str, Any]:\n    \"\"\"\n    Load default training parameters for a model from YAML file.\n\n    Args:\n        model_name: Model identifier (e.g., \"unsloth/Meta-Llama-3.1-8B-bnb-4bit\")\n\n    Returns:\n        Dictionary with default parameters from YAML file, or empty dict if not found\n\n    The function looks for a YAML file in configs/model_defaults/ (including subfolders)\n    based on the model name or its aliases from MODEL_NAME_MAPPING.\n    If no specific file exists, it falls back to default.yaml.\n    \"\"\"\n    try:\n        # Get the script directory to locate configs\n        script_dir = Path(__file__).parent.parent.parent\n        defaults_dir = script_dir / \"assets\" / \"configs\" / \"model_defaults\"\n\n        # First, check if model is in the mapping\n        if model_name.lower() in _REVERSE_MODEL_MAPPING:\n            canonical_file = _REVERSE_MODEL_MAPPING[model_name.lower()]\n            # Search in subfolders and root\n            for config_path in defaults_dir.rglob(canonical_file):\n                if config_path.is_file():\n                    with open(config_path, \"r\", encoding = \"utf-8\") as f:\n                        config = yaml.safe_load(f) or {}\n                        logger.info(\n                            f\"Loaded model defaults from {config_path} (via mapping)\"\n                        )\n                        return config\n\n        # If model_name is a local path (e.g. /home/.../Spark-TTS-0.5B/LLM from\n        # adapter_config.json), try matching the last 1-2 path components against\n        # the registry (e.g. 
\"Spark-TTS-0.5B/LLM\").\n        if model_name not in _REVERSE_MODEL_MAPPING and (\n            model_name.startswith(\"/\") or model_name.startswith(\".\")\n        ):\n            parts = Path(model_name).parts\n            for depth in [2, 1]:\n                if len(parts) >= depth:\n                    suffix = \"/\".join(parts[-depth:])\n                    if suffix in _REVERSE_MODEL_MAPPING:\n                        canonical_file = _REVERSE_MODEL_MAPPING[suffix]\n                        for config_path in defaults_dir.rglob(canonical_file):\n                            if config_path.is_file():\n                                with open(config_path, \"r\", encoding = \"utf-8\") as f:\n                                    config = yaml.safe_load(f) or {}\n                                    logger.info(\n                                        f\"Loaded model defaults from {config_path} (via path suffix '{suffix}')\"\n                                    )\n                                    return config\n\n        # Try exact model name match (for backward compatibility)\n        model_filename = model_name.replace(\"/\", \"_\") + \".yaml\"\n        # Search in subfolders and root\n        for config_path in defaults_dir.rglob(model_filename):\n            if config_path.is_file():\n                with open(config_path, \"r\", encoding = \"utf-8\") as f:\n                    config = yaml.safe_load(f) or {}\n                    logger.info(f\"Loaded model defaults from {config_path}\")\n                    return config\n\n        # Fall back to default.yaml\n        default_config_path = defaults_dir / \"default.yaml\"\n        if default_config_path.exists():\n            with open(default_config_path, \"r\", encoding = \"utf-8\") as f:\n                config = yaml.safe_load(f) or {}\n                logger.info(f\"Loaded default model defaults from {default_config_path}\")\n                return config\n\n        logger.warning(f\"No default config found for model {model_name}\")\n        return {}\n\n    except Exception as e:\n        logger.error(f\"Error loading model defaults for {model_name}: {e}\")\n        return {}\n\n\n@dataclass\nclass ModelConfig:\n    \"\"\"Configuration for a model to load\"\"\"\n\n    identifier: str  # Clean model identifier (org/name or path)\n    display_name: str  # Original UI display name\n    path: str  # Normalized filesystem path\n    is_local: bool  # Is this a local file vs HF model?\n    is_cached: bool  # Is this already in HF cache?\n    is_vision: bool  # Is this a vision model?\n    is_lora: bool  # Is this a lora adapter?\n    is_gguf: bool = False  # Is this a GGUF model?\n    is_audio: bool = False  # Is this a TTS audio model?\n    audio_type: Optional[str] = (\n        None  # Audio codec type: 'snac', 'csm', 'bicodec', 'dac'\n    )\n    has_audio_input: bool = False  # Accepts audio input (ASR/speech understanding)\n    gguf_file: Optional[str] = None  # Full path to the .gguf file (local mode)\n    gguf_mmproj_file: Optional[str] = (\n        None  # Full path to the mmproj .gguf file (vision projection)\n    )\n    gguf_hf_repo: Optional[str] = (\n        None  # HF repo ID for -hf mode (e.g. \"unsloth/gemma-3-4b-it-GGUF\")\n    )\n    gguf_variant: Optional[str] = None  # Quantization variant (e.g. 
\"Q4_K_M\")\n    base_model: Optional[str] = None  # Base model (for LoRAs)\n\n    @classmethod\n    def from_lora_path(\n        cls, lora_path: str, hf_token: Optional[str] = None\n    ) -> Optional[\"ModelConfig\"]:\n        \"\"\"\n        Create ModelConfig from a local LoRA adapter path.\n\n        Automatically detects the base model from adapter config.\n\n        Args:\n            lora_path: Path to LoRA adapter (e.g., \"./outputs/unsloth_Meta-Llama-3.1_.../\")\n            hf_token: HF token for vision detection\n\n        Returns:\n            ModelConfig for the LoRA adapter\n        \"\"\"\n        try:\n            lora_path_obj = Path(lora_path)\n\n            if not lora_path_obj.exists():\n                logger.error(f\"LoRA path does not exist: {lora_path}\")\n                return None\n\n            # Get base model\n            base_model = get_base_model_from_lora(lora_path)\n            if not base_model:\n                logger.error(f\"Could not determine base model for LoRA: {lora_path}\")\n                return None\n\n            # Check if base model is vision\n            is_vision = is_vision_model(base_model, hf_token = hf_token)\n\n            # Check if base model is audio\n            audio_type = detect_audio_type(base_model, hf_token = hf_token)\n\n            display_name = lora_path_obj.name\n            identifier = lora_path  # Use path as identifier for local LoRAs\n\n            return cls(\n                identifier = identifier,\n                display_name = display_name,\n                path = lora_path,\n                is_local = True,\n                is_cached = True,  # Local LoRAs are always \"cached\"\n                is_vision = is_vision,\n                is_lora = True,\n                is_audio = audio_type is not None and audio_type != \"audio_vlm\",\n                audio_type = audio_type,\n                has_audio_input = is_audio_input_type(audio_type),\n                base_model = base_model,\n            )\n\n        except Exception as e:\n            logger.error(f\"Error creating ModelConfig from LoRA path: {e}\")\n            return None\n\n    @classmethod\n    def from_identifier(\n        cls,\n        model_id: str,\n        hf_token: Optional[str] = None,\n        is_lora: bool = False,\n        gguf_variant: Optional[str] = None,\n    ) -> Optional[\"ModelConfig\"]:\n        \"\"\"\n        Create ModelConfig from a clean model identifier.\n\n        For FastAPI routes where the frontend sends sanitized model paths.\n        No Gradio dropdown parsing - expects clean identifiers like:\n        - \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\"\n        - \"./outputs/my_lora_adapter\"\n        - \"/absolute/path/to/model\"\n\n        Args:\n            model_id: Clean model identifier (HF repo name or local path)\n            hf_token: Optional HF token for vision detection on gated models\n            is_lora: Whether this is a LoRA adapter\n            gguf_variant: Optional GGUF quantization variant (e.g. 
\"Q4_K_M\").\n                For remote GGUF repos, specifies which quant to load via -hf.\n                If None, auto-selects using _pick_best_gguf().\n\n        Returns:\n            ModelConfig or None if configuration cannot be created\n        \"\"\"\n        if not model_id or not model_id.strip():\n            return None\n\n        identifier = model_id.strip()\n        is_local = is_local_path(identifier)\n        path = normalize_path(identifier) if is_local else identifier\n\n        # Add unsloth/ prefix for shorthand HF models\n        if not is_local and \"/\" not in identifier:\n            identifier = f\"unsloth/{identifier}\"\n            path = identifier\n\n        # Enforce lowercase for remote Hugging Face identifiers to prevent cache duplication\n        # Hugging Face Hub APIs are case-insensitive remotely, but case-sensitive locally (repo_folder_name).\n        if not is_local:\n            identifier = identifier.lower()\n            path = path.lower()\n\n        # Auto-detect GGUF models (check before LoRA/vision detection)\n        if is_local:\n            gguf_file = detect_gguf_model(path)\n            if gguf_file:\n                display_name = Path(gguf_file).stem\n                logger.info(f\"Detected local GGUF model: {gguf_file}\")\n\n                # Detect vision: check if base model is vision, then look for mmproj\n                mmproj_file = None\n                gguf_is_vision = False\n                gguf_dir = Path(gguf_file).parent\n\n                # Determine if this is a vision model from export metadata\n                base_is_vision = False\n                meta_path = gguf_dir / \"export_metadata.json\"\n                if meta_path.exists():\n                    try:\n                        meta = json.loads(meta_path.read_text())\n                        base = meta.get(\"base_model\")\n                        if base and is_vision_model(base, hf_token = hf_token):\n                            base_is_vision = True\n                            logger.info(f\"GGUF base model '{base}' is a vision model\")\n                    except Exception as e:\n                        logger.debug(f\"Could not read export metadata: {e}\")\n\n                # If vision (or mmproj happens to exist), find the mmproj file\n                mmproj_file = detect_mmproj_file(gguf_file)\n                if mmproj_file:\n                    gguf_is_vision = True\n                    logger.info(f\"Detected mmproj for vision: {mmproj_file}\")\n                elif base_is_vision:\n                    logger.warning(\n                        f\"Base model is vision but no mmproj file found in {gguf_dir}\"\n                    )\n\n                return cls(\n                    identifier = identifier,\n                    display_name = display_name,\n                    path = path,\n                    is_local = True,\n                    is_cached = True,\n                    is_vision = gguf_is_vision,\n                    is_lora = False,\n                    is_gguf = True,\n                    gguf_file = gguf_file,\n                    gguf_mmproj_file = mmproj_file,\n                )\n        else:\n            # Check if the HF repo contains GGUF files\n            gguf_filename = detect_gguf_model_remote(identifier, hf_token = hf_token)\n            if gguf_filename:\n                # Preflight: verify llama-server binary exists BEFORE user waits\n                # for a multi-GB download that llama-server handles natively\n         
       from core.inference.llama_cpp import LlamaCppBackend\n\n                if not LlamaCppBackend._find_llama_server_binary():\n                    raise RuntimeError(\n                        \"llama-server binary not found — cannot load GGUF models. \"\n                        \"Run setup.sh to build it, or set LLAMA_SERVER_PATH.\"\n                    )\n\n                # Use list_gguf_variants() to detect vision & resolve variant\n                variants, has_vision = list_gguf_variants(identifier, hf_token = hf_token)\n                variant = gguf_variant\n                if not variant:\n                    # Auto-select best quantization\n                    variant_filenames = [v.filename for v in variants]\n                    best = _pick_best_gguf(variant_filenames)\n                    if best:\n                        variant = _extract_quant_label(best)\n                    else:\n                        variant = \"Q4_K_M\"  # Fallback — llama-server's own default\n\n                display_name = f\"{identifier.split('/')[-1]} ({variant})\"\n                logger.info(\n                    f\"Detected remote GGUF repo '{identifier}', \"\n                    f\"variant={variant}, vision={has_vision}\"\n                )\n                return cls(\n                    identifier = identifier,\n                    display_name = display_name,\n                    path = identifier,\n                    is_local = False,\n                    is_cached = False,\n                    is_vision = has_vision,\n                    is_lora = False,\n                    is_gguf = True,\n                    gguf_file = None,\n                    gguf_hf_repo = identifier,\n                    gguf_variant = variant,\n                )\n\n        # Auto-detect LoRA for local paths (check adapter_config.json on disk)\n        if not is_lora and is_local:\n            detected_base = get_base_model_from_lora(path)\n            if detected_base:\n                is_lora = True\n                logger.info(\n                    f\"Auto-detected local LoRA adapter at '{path}' (base: {detected_base})\"\n                )\n\n        # Auto-detect LoRA for remote HF models (check repo file listing)\n        if not is_lora and not is_local:\n            try:\n                from huggingface_hub import model_info as hf_model_info\n\n                info = hf_model_info(identifier, token = hf_token)\n                repo_files = [s.rfilename for s in info.siblings]\n                if \"adapter_config.json\" in repo_files:\n                    is_lora = True\n                    logger.info(f\"Auto-detected remote LoRA adapter: '{identifier}'\")\n            except Exception as e:\n                logger.debug(\n                    f\"Could not check remote LoRA status for '{identifier}': {e}\"\n                )\n\n        # Handle LoRA adapters\n        base_model = None\n        if is_lora:\n            if is_local:\n                # Local LoRA: read adapter_config.json from disk\n                base_model = get_base_model_from_lora(path)\n            else:\n                # Remote LoRA: download adapter_config.json from HF\n                try:\n                    from huggingface_hub import hf_hub_download\n\n                    config_path = hf_hub_download(\n                        identifier, \"adapter_config.json\", token = hf_token\n                    )\n                    with open(config_path, \"r\") as f:\n                        adapter_config = json.load(f)\n       
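             # Resolve the LoRA's base checkpoint from the downloaded adapter config.\n       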
             base_model = adapter_config.get(\"base_model_name_or_path\")\n                    if base_model:\n                        logger.info(f\"Resolved remote LoRA base model: '{base_model}'\")\n                except Exception as e:\n                    logger.warning(\n                        f\"Could not download adapter_config.json for '{identifier}': {e}\"\n                    )\n\n            if not base_model:\n                logger.warning(f\"Could not determine base model for LoRA '{path}'\")\n                return None\n            check_model = base_model\n        else:\n            check_model = identifier\n\n        vision = is_vision_model(check_model, hf_token = hf_token)\n        audio_type_val = detect_audio_type(check_model, hf_token = hf_token)\n        has_audio_in = is_audio_input_type(audio_type_val)\n\n        display_name = Path(path).name if is_local else identifier.split(\"/\")[-1]\n\n        return cls(\n            identifier = identifier,\n            display_name = display_name,\n            path = path,\n            is_local = is_local,\n            is_cached = is_model_cached(identifier) if not is_local else True,\n            is_vision = vision,\n            is_lora = is_lora,\n            is_audio = audio_type_val is not None and audio_type_val != \"audio_vlm\",\n            audio_type = audio_type_val,\n            has_audio_input = has_audio_in,\n            base_model = base_model,\n        )\n\n    @classmethod\n    def from_ui_selection(\n        cls,\n        dropdown_value: Optional[str],\n        search_value: Optional[str],\n        local_models: list = None,\n        hf_token: Optional[str] = None,\n        is_lora: bool = False,\n    ) -> Optional[\"ModelConfig\"]:\n        \"\"\"\n        Create a universal ModelConfig from UI dropdown/search selections.\n        Handles base models and LoRA adapters.\n        \"\"\"\n        selected = None\n        if search_value and search_value.strip():\n            selected = search_value.strip()\n        elif dropdown_value:\n            selected = dropdown_value\n\n        if not selected:\n            return None\n\n        display_name = selected\n\n        #  Use the correct 'local_models' parameter to resolve display names\n        if \" (Active)\" in selected or \" (Ready)\" in selected:\n            clean_display_name = selected.replace(\" (Active)\", \"\").replace(\n                \" (Ready)\", \"\"\n            )\n            if local_models:\n                for local_display, local_path in local_models:\n                    if local_display == clean_display_name:\n                        selected = local_path\n                        break\n\n        # Clean all UI status indicators to get the final identifier\n        identifier = selected\n        for status in UI_STATUS_INDICATORS:\n            identifier = identifier.replace(status, \"\")\n        identifier = identifier.strip()\n\n        is_local = is_local_path(identifier)\n        path = normalize_path(identifier) if is_local else identifier\n\n        # Add unsloth/ prefix for shorthand HF models\n        if not is_local and \"/\" not in identifier:\n            identifier = f\"unsloth/{identifier}\"\n            path = identifier\n\n        # --- Logic for Base Model and Vision Detection ---\n        base_model = None\n        is_vision = False\n\n        if is_lora:\n            # For a LoRA, we MUST find its base model.\n            base_model = get_base_model_from_lora(path)\n            if not base_model:\n              
  logger.warning(\n                    f\"Could not determine base model for LoRA '{path}'. Cannot create config.\"\n                )\n                return None  # Cannot proceed without a base model\n\n            # A LoRA's vision capability is determined by its base model.\n            is_vision = is_vision_model(base_model, hf_token = hf_token)\n        else:\n            # For a base model, just check its own vision status.\n            is_vision = is_vision_model(identifier, hf_token = hf_token)\n\n        from utils.paths import is_model_cached\n\n        is_cached = is_model_cached(identifier) if not is_local else True\n\n        return cls(\n            identifier = identifier,\n            display_name = display_name,\n            path = path,\n            is_local = is_local,\n            is_cached = is_cached,\n            is_vision = is_vision,\n            is_lora = is_lora,\n            base_model = base_model,  # This will be None for base models, and populated for LoRAs\n        )\n"
  },
  {
    "path": "studio/backend/utils/paths/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPath utilities for model and dataset handling\n\"\"\"\n\nfrom .path_utils import normalize_path, is_local_path, is_model_cached, get_cache_path\nfrom .storage_roots import (\n    studio_root,\n    assets_root,\n    datasets_root,\n    dataset_uploads_root,\n    recipe_datasets_root,\n    outputs_root,\n    exports_root,\n    auth_root,\n    auth_db_path,\n    tmp_root,\n    seed_uploads_root,\n    unstructured_seed_cache_root,\n    oxc_validator_tmp_root,\n    tensorboard_root,\n    ensure_dir,\n    ensure_studio_directories,\n    resolve_under_root,\n    resolve_output_dir,\n    resolve_export_dir,\n    resolve_tensorboard_dir,\n    resolve_dataset_path,\n)\n\n__all__ = [\n    \"normalize_path\",\n    \"is_local_path\",\n    \"is_model_cached\",\n    \"get_cache_path\",\n    \"studio_root\",\n    \"assets_root\",\n    \"datasets_root\",\n    \"dataset_uploads_root\",\n    \"recipe_datasets_root\",\n    \"outputs_root\",\n    \"exports_root\",\n    \"auth_root\",\n    \"auth_db_path\",\n    \"tmp_root\",\n    \"seed_uploads_root\",\n    \"unstructured_seed_cache_root\",\n    \"oxc_validator_tmp_root\",\n    \"tensorboard_root\",\n    \"ensure_dir\",\n    \"ensure_studio_directories\",\n    \"resolve_under_root\",\n    \"resolve_output_dir\",\n    \"resolve_export_dir\",\n    \"resolve_tensorboard_dir\",\n    \"resolve_dataset_path\",\n]\n"
  },
  {
    "path": "studio/backend/utils/paths/path_utils.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nPath utilities for model and dataset handling\n\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Optional\nimport structlog\nfrom loggers import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef normalize_path(path: str) -> str:\n    \"\"\"\n    Convert Windows paths to WSL format if needed.\n\n    Examples:\n        C:\\\\Users\\\\... -> /mnt/c/Users/...\n        /home/user/... -> /home/user/... (unchanged)\n    \"\"\"\n    if not path:\n        return path\n\n    # Handle Windows drive letters (C:\\\\ or c:\\\\)\n    if len(path) >= 3 and path[1] == \":\" and path[2] in (\"\\\\\", \"/\"):\n        drive = path[0].lower()\n        rest = path[3:].replace(\"\\\\\", \"/\")\n        return f\"/mnt/{drive}/{rest}\"\n\n    # Already Unix-style or relative\n    return path.replace(\"\\\\\", \"/\")\n\n\ndef is_local_path(path: str) -> bool:\n    \"\"\"\n    Check if path is a local filesystem path vs HuggingFace model identifier.\n\n    Examples:\n        True: /home/user/model, C:\\\\models, ./model, ~/model\n        False: unsloth/llama-3.1-8b, microsoft/phi-2\n    \"\"\"\n    if not path:\n        return False\n\n    # If it exists on disk, treat as local (covers relative paths like \"outputs/foo\").\n    try:\n        if Path(normalize_path(path)).expanduser().exists():\n            return True\n    except Exception:\n        pass\n\n    # Obvious HF patterns\n    if path.count(\"/\") == 1 and not path.startswith((\"/\", \".\", \"~\")):\n        return False  # Looks like org/model format\n\n    # Filesystem indicators\n    return (\n        path.startswith((\"/\", \".\", \"~\"))  # Unix absolute/relative\n        or \":\" in path  # Windows drive or URL\n        or \"\\\\\" in path  # Windows separator\n        or os.path.isabs(path)  # System-absolute\n    )\n\n\ndef get_cache_path(model_name: str) -> Optional[Path]:\n    \"\"\"Get HuggingFace cache path for a model if it exists.\"\"\"\n    cache_dir = Path.home() / \".cache\" / \"huggingface\" / \"hub\"\n    model_cache_name = model_name.replace(\"/\", \"--\")\n    model_cache_path = cache_dir / f\"models--{model_cache_name}\"\n\n    return model_cache_path if model_cache_path.exists() else None\n\n\ndef is_model_cached(model_name: str) -> bool:\n    \"\"\"Check if model is downloaded in HuggingFace cache.\"\"\"\n    cache_path = get_cache_path(model_name)\n    if not cache_path:\n        return False\n\n    # Check for actual model files\n    for suffix in [\".safetensors\", \".bin\", \".json\"]:\n        if list(cache_path.rglob(f\"*{suffix}\")):\n            return True\n\n    return False\n"
  },
  {
    "path": "studio/backend/utils/paths/storage_roots.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom __future__ import annotations\n\nimport os\nfrom pathlib import Path\nimport tempfile\n\n\ndef studio_root() -> Path:\n    return Path.home() / \".unsloth\" / \"studio\"\n\n\ndef cache_root() -> Path:\n    \"\"\"Central cache directory for all studio downloads (models, datasets, etc.).\"\"\"\n    return Path.home() / \".unsloth\" / \"studio\" / \"cache\"\n\n\ndef assets_root() -> Path:\n    return studio_root() / \"assets\"\n\n\ndef datasets_root() -> Path:\n    return assets_root() / \"datasets\"\n\n\ndef dataset_uploads_root() -> Path:\n    return datasets_root() / \"uploads\"\n\n\ndef recipe_datasets_root() -> Path:\n    return datasets_root() / \"recipes\"\n\n\ndef outputs_root() -> Path:\n    return studio_root() / \"outputs\"\n\n\ndef exports_root() -> Path:\n    return studio_root() / \"exports\"\n\n\ndef auth_root() -> Path:\n    return studio_root() / \"auth\"\n\n\ndef auth_db_path() -> Path:\n    return auth_root() / \"auth.db\"\n\n\ndef tmp_root() -> Path:\n    return Path(tempfile.gettempdir()) / \"unsloth-studio\"\n\n\ndef seed_uploads_root() -> Path:\n    return tmp_root() / \"seed-uploads\"\n\n\ndef unstructured_seed_cache_root() -> Path:\n    return tmp_root() / \"unstructured-seed-cache\"\n\n\ndef oxc_validator_tmp_root() -> Path:\n    return tmp_root() / \"oxc-validator\"\n\n\ndef tensorboard_root() -> Path:\n    return studio_root() / \"runs\"\n\n\ndef ensure_dir(path: Path) -> Path:\n    path.mkdir(parents = True, exist_ok = True)\n    return path\n\n\ndef _setup_cache_env() -> None:\n    \"\"\"Set cache environment variables for HuggingFace, uv, and vLLM.\n\n    Only sets variables that are not already set by the user, so\n    explicit overrides (e.g. HF_HOME=/data/hf) are respected.\n    Works on Linux, macOS, and Windows.\n    \"\"\"\n    root = cache_root()\n    hf_dir = root / \"huggingface\"\n    defaults = {\n        \"HF_HOME\": str(hf_dir),\n        \"HF_HUB_CACHE\": str(hf_dir / \"hub\"),\n        \"HF_XET_CACHE\": str(hf_dir / \"xet\"),\n        \"UV_CACHE_DIR\": str(root / \"uv\"),\n        \"VLLM_CACHE_ROOT\": str(root / \"vllm\"),\n    }\n    for key, value in defaults.items():\n        if key not in os.environ:\n            os.environ[key] = value\n            Path(value).mkdir(parents = True, exist_ok = True)\n\n\ndef ensure_studio_directories() -> None:\n    \"\"\"Create all standard studio directories on startup.\"\"\"\n    for dir_fn in (\n        studio_root,\n        assets_root,\n        datasets_root,\n        dataset_uploads_root,\n        recipe_datasets_root,\n        outputs_root,\n        exports_root,\n        auth_root,\n        tensorboard_root,\n    ):\n        ensure_dir(dir_fn())\n    _setup_cache_env()\n\n\ndef _clean_relative_path(\n    path_value: str, *, strip_prefixes: tuple[str, ...] = ()\n) -> Path:\n    path = Path(path_value).expanduser()\n    parts = [part for part in path.parts if part not in (\"\", \".\")]\n    while parts and parts[0] in strip_prefixes:\n        parts = parts[1:]\n    return Path(*parts) if parts else Path()\n\n\ndef resolve_under_root(\n    path_value: str | None,\n    *,\n    root: Path,\n    strip_prefixes: tuple[str, ...] 
= (),\n) -> Path:\n    if not path_value or not str(path_value).strip():\n        return root\n\n    path = Path(str(path_value).strip()).expanduser()\n    if path.is_absolute():\n        return path\n\n    cleaned = _clean_relative_path(str(path), strip_prefixes = strip_prefixes)\n    return root / cleaned\n\n\ndef resolve_output_dir(path_value: str | None = None) -> Path:\n    return resolve_under_root(\n        path_value,\n        root = outputs_root(),\n        strip_prefixes = (\"outputs\",),\n    )\n\n\ndef resolve_export_dir(path_value: str | None = None) -> Path:\n    return resolve_under_root(\n        path_value,\n        root = exports_root(),\n        strip_prefixes = (\"exports\",),\n    )\n\n\ndef resolve_tensorboard_dir(path_value: str | None = None) -> Path:\n    return resolve_under_root(\n        path_value,\n        root = tensorboard_root(),\n        strip_prefixes = (\"runs\", \"tensorboard\"),\n    )\n\n\ndef resolve_dataset_path(path_value: str) -> Path:\n    path = Path(path_value).expanduser()\n    if path.is_absolute():\n        return path\n\n    parts = [part for part in Path(path_value).parts if part not in (\"\", \".\")]\n    if parts[:2] == [\"assets\", \"datasets\"]:\n        parts = parts[2:]\n    if parts and parts[0] == \"uploads\":\n        cleaned = Path(*parts[1:]) if len(parts) > 1 else Path()\n        return dataset_uploads_root() / cleaned\n    if parts and parts[0] == \"recipes\":\n        cleaned = Path(*parts[1:]) if len(parts) > 1 else Path()\n        return recipe_datasets_root() / cleaned\n\n    cleaned = Path(*parts) if parts else Path()\n    candidates = [\n        dataset_uploads_root() / cleaned,\n        recipe_datasets_root() / cleaned,\n        datasets_root() / cleaned,\n        dataset_uploads_root() / cleaned.name,\n        recipe_datasets_root() / cleaned.name,\n    ]\n    for candidate in candidates:\n        if candidate.exists():\n            return candidate\n    return candidates[0]\n"
  },
  {
    "path": "studio/backend/utils/transformers_version.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nAutomatic transformers version switching.\n\nSome newer model architectures (Ministral-3, GLM-4.7-Flash, Qwen3-30B-A3B MoE,\ntiny_qwen3_moe) require transformers>=5.3.0, while everything else needs the\ndefault 4.57.x that ships with Unsloth.\n\nWhen loading a LoRA adapter with a custom name, we resolve the base model from\n``adapter_config.json`` and check *that* against the model list.\n\nStrategy:\n  Training and inference run in subprocesses that activate the correct version\n  via sys.path (prepending .venv_t5/ for 5.x models). See:\n    - core/training/worker.py\n    - core/inference/worker.py\n\n  For export (still in-process), ensure_transformers_version() does a lightweight\n  sys.path swap using the same .venv_t5/ directory pre-installed by setup.sh.\n\"\"\"\n\nimport importlib\nimport json\nimport structlog\nfrom loggers import get_logger\nimport os\nimport shutil\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nlogger = get_logger(__name__)\n\n\n# ---------------------------------------------------------------------------\n# Detection\n# ---------------------------------------------------------------------------\n\n# Lowercase substrings — if ANY appears anywhere in the lowered model name,\n# we need transformers 5.x.\nTRANSFORMERS_5_MODEL_SUBSTRINGS: tuple[str, ...] = (\n    \"ministral-3-\",  # Ministral-3-{3,8,14}B-{Instruct,Reasoning,Base}-2512\n    \"glm-4.7-flash\",  # GLM-4.7-Flash\n    \"qwen3-30b-a3b\",  # Qwen3-30B-A3B-Instruct-2507 and variants\n    \"qwen3.5\",  # Qwen3.5 family (35B-A3B, etc.)\n    \"qwen3-next\",  # Qwen3-Next and variants\n    \"tiny_qwen3_moe\",  # imdatta0/tiny_qwen3_moe_2.8B_0.7B\n)\n\n# Tokenizer classes that only exist in transformers>=5.x\n_TRANSFORMERS_5_TOKENIZER_CLASSES: set[str] = {\n    \"TokenizersBackend\",\n}\n\n# Cache for dynamic tokenizer_config.json lookups to avoid repeated fetches\n_tokenizer_class_cache: dict[str, bool] = {}\n\n# Versions\nTRANSFORMERS_5_VERSION = \"5.3.0\"\nTRANSFORMERS_DEFAULT_VERSION = \"4.57.6\"\n\n# Pre-installed directory for transformers 5.x — created by setup.sh / setup.ps1\n_VENV_T5_DIR = str(Path.home() / \".unsloth\" / \"studio\" / \".venv_t5\")\n\n\ndef _resolve_base_model(model_name: str) -> str:\n    \"\"\"If *model_name* points to a LoRA adapter, return its base model.\n\n    Checks for ``adapter_config.json`` locally first.  
Only calls the heavier\n    ``get_base_model_from_lora`` for paths that are actual local directories\n    (avoids noisy warnings for plain HF model IDs).\n\n    Returns the original *model_name* unchanged if it is not a LoRA adapter.\n    \"\"\"\n    # --- Fast local check ---------------------------------------------------\n    local_path = Path(model_name)\n    adapter_cfg_path = local_path / \"adapter_config.json\"\n    if adapter_cfg_path.is_file():\n        try:\n            with open(adapter_cfg_path) as f:\n                cfg = json.load(f)\n            base = cfg.get(\"base_model_name_or_path\")\n            if base:\n                logger.info(\n                    \"Resolved LoRA adapter '%s' → base model '%s'\",\n                    model_name,\n                    base,\n                )\n                return base\n        except Exception as exc:\n            logger.debug(\"Could not read %s: %s\", adapter_cfg_path, exc)\n\n    # --- Only try the heavier fallback for local directories ----------------\n    if local_path.is_dir():\n        try:\n            from utils.models import get_base_model_from_lora\n\n            base = get_base_model_from_lora(model_name)\n            if base:\n                logger.info(\n                    \"Resolved LoRA adapter '%s' → base model '%s' \"\n                    \"(via get_base_model_from_lora)\",\n                    model_name,\n                    base,\n                )\n                return base\n        except Exception as exc:\n            logger.debug(\n                \"get_base_model_from_lora failed for '%s': %s\",\n                model_name,\n                exc,\n            )\n\n    return model_name\n\n\ndef _check_tokenizer_config_needs_v5(model_name: str) -> bool:\n    \"\"\"Fetch tokenizer_config.json from HuggingFace and check if the\n    tokenizer_class requires transformers 5.x.\n\n    Results are cached in ``_tokenizer_class_cache`` to avoid repeated fetches.\n    Returns False on any network/parse error (fail-open to default version).\n    \"\"\"\n    if model_name in _tokenizer_class_cache:\n        return _tokenizer_class_cache[model_name]\n\n    import urllib.request\n\n    url = f\"https://huggingface.co/{model_name}/raw/main/tokenizer_config.json\"\n    try:\n        req = urllib.request.Request(url, headers = {\"User-Agent\": \"unsloth-studio\"})\n        with urllib.request.urlopen(req, timeout = 10) as resp:\n            data = json.loads(resp.read().decode())\n        tokenizer_class = data.get(\"tokenizer_class\", \"\")\n        result = tokenizer_class in _TRANSFORMERS_5_TOKENIZER_CLASSES\n        if result:\n            logger.info(\n                \"Dynamic check: %s uses tokenizer_class=%s (requires transformers 5.x)\",\n                model_name,\n                tokenizer_class,\n            )\n        _tokenizer_class_cache[model_name] = result\n        return result\n    except Exception as exc:\n        logger.debug(\n            \"Could not fetch tokenizer_config.json for '%s': %s\", model_name, exc\n        )\n        _tokenizer_class_cache[model_name] = False\n        return False\n\n\ndef needs_transformers_5(model_name: str) -> bool:\n    \"\"\"Return True if *model_name* belongs to an architecture that requires\n    ``transformers>=5.3.0``.\n\n    First checks the hardcoded substring list for known models, then\n    dynamically fetches ``tokenizer_config.json`` from HuggingFace to check\n    if the tokenizer_class (e.g. 
``TokenizersBackend``) requires v5.\n    \"\"\"\n    lowered = model_name.lower()\n    if any(sub in lowered for sub in TRANSFORMERS_5_MODEL_SUBSTRINGS):\n        return True\n    return _check_tokenizer_config_needs_v5(model_name)\n\n\n# ---------------------------------------------------------------------------\n# Version switching (in-process — used only by export)\n# ---------------------------------------------------------------------------\n\n\ndef _get_in_memory_version() -> str | None:\n    \"\"\"Return the transformers version currently loaded in this process.\"\"\"\n    tf = sys.modules.get(\"transformers\")\n    if tf is not None:\n        return getattr(tf, \"__version__\", None)\n    return None\n\n\n# All top-level prefixes that hold references to transformers internals.\n_PURGE_PREFIXES = (\n    \"transformers\",\n    \"huggingface_hub\",\n    \"unsloth\",\n    \"unsloth_zoo\",\n    \"peft\",\n    \"trl\",\n    \"accelerate\",\n    \"auto_gptq\",\n    # NOTE: bitsandbytes is intentionally EXCLUDED — it registers torch custom\n    # operators at import time via torch.library.define(). Those registrations\n    # live in torch's global operator registry which survives module purge.\n    # Re-importing bitsandbytes after purge → duplicate registration → crash.\n    # Our own modules that import from transformers at module level\n    # (e.g. model_config.py: `from transformers import AutoConfig`)\n    \"utils.models\",\n    \"core.training\",\n    \"core.inference\",\n    \"core.export\",\n)\n\n\ndef _purge_modules() -> int:\n    \"\"\"Remove all cached modules for transformers and its dependents.\n\n    Returns the number of modules purged.\n    \"\"\"\n    importlib.invalidate_caches()\n    to_remove = [\n        k\n        for k in list(sys.modules.keys())\n        if any(k == p or k.startswith(p + \".\") for p in _PURGE_PREFIXES)\n    ]\n    for key in to_remove:\n        del sys.modules[key]\n    return len(to_remove)\n\n\n_VENV_T5_PACKAGES = (\n    f\"transformers=={TRANSFORMERS_5_VERSION}\",\n    \"huggingface_hub==1.7.1\",\n    \"hf_xet==1.4.2\",\n    \"tiktoken\",\n)\n\n\ndef _venv_t5_is_valid() -> bool:\n    \"\"\"Return True if .venv_t5/ has all required packages at the correct versions.\"\"\"\n    if not os.path.isdir(_VENV_T5_DIR) or not os.listdir(_VENV_T5_DIR):\n        return False\n    # Check that the key package directories exist AND match the required version\n    for pkg_spec in _VENV_T5_PACKAGES:\n        parts = pkg_spec.split(\"==\")\n        pkg_name = parts[0]\n        pkg_version = parts[1] if len(parts) > 1 else None\n        pkg_name_norm = pkg_name.replace(\"-\", \"_\")\n        # Check directory exists\n        if not any(\n            (Path(_VENV_T5_DIR) / d).is_dir()\n            for d in (pkg_name_norm, pkg_name_norm.replace(\"_\", \"-\"))\n        ):\n            return False\n        # For unpinned packages, existence is enough\n        if pkg_version is None:\n            continue\n        # Check version via .dist-info metadata\n        dist_info_found = False\n        for di in Path(_VENV_T5_DIR).glob(f\"{pkg_name_norm}-*.dist-info\"):\n            metadata = di / \"METADATA\"\n            if not metadata.is_file():\n                continue\n            for line in metadata.read_text(errors = \"replace\").splitlines():\n                if line.startswith(\"Version:\"):\n                    installed_ver = line.split(\":\", 1)[1].strip()\n                    if installed_ver != pkg_version:\n                        logger.info(\n             
               \".venv_t5 has %s==%s but need %s\",\n                            pkg_name,\n                            installed_ver,\n                            pkg_version,\n                        )\n                        return False\n                    dist_info_found = True\n                    break\n            if dist_info_found:\n                break\n        if not dist_info_found:\n            return False\n    return True\n\n\ndef _install_to_venv_t5(pkg: str) -> bool:\n    \"\"\"Install a single package into .venv_t5/, preferring uv then pip.\"\"\"\n    # Try uv first (faster) if already on PATH -- do NOT install uv at runtime\n    if shutil.which(\"uv\"):\n        result = subprocess.run(\n            [\n                \"uv\",\n                \"pip\",\n                \"install\",\n                \"--python\",\n                sys.executable,\n                \"--target\",\n                _VENV_T5_DIR,\n                \"--no-deps\",\n                \"--upgrade\",\n                pkg,\n            ],\n            stdout = subprocess.PIPE,\n            stderr = subprocess.STDOUT,\n            text = True,\n        )\n        if result.returncode == 0:\n            return True\n        logger.warning(\"uv install of %s failed, falling back to pip\", pkg)\n\n    # Fallback to pip\n    result = subprocess.run(\n        [\n            sys.executable,\n            \"-m\",\n            \"pip\",\n            \"install\",\n            \"--target\",\n            _VENV_T5_DIR,\n            \"--no-deps\",\n            \"--upgrade\",\n            pkg,\n        ],\n        stdout = subprocess.PIPE,\n        stderr = subprocess.STDOUT,\n        text = True,\n    )\n    if result.returncode != 0:\n        logger.error(\"install failed:\\n%s\", result.stdout)\n        return False\n    return True\n\n\ndef _ensure_venv_t5_exists() -> bool:\n    \"\"\"Ensure .venv_t5/ exists with all required packages. 
Install if missing.\"\"\"\n    if _venv_t5_is_valid():\n        return True\n\n    logger.warning(\n        \".venv_t5 not found or incomplete at %s -- installing at runtime\", _VENV_T5_DIR\n    )\n    shutil.rmtree(_VENV_T5_DIR, ignore_errors = True)\n    os.makedirs(_VENV_T5_DIR, exist_ok = True)\n    for pkg in _VENV_T5_PACKAGES:\n        if not _install_to_venv_t5(pkg):\n            return False\n    logger.info(\"Installed transformers 5.x to %s\", _VENV_T5_DIR)\n    return True\n\n\ndef _activate_5x() -> None:\n    \"\"\"Prepend .venv_t5/ to sys.path, purge stale modules, reimport.\"\"\"\n    if not _ensure_venv_t5_exists():\n        raise RuntimeError(\n            f\"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}\"\n        )\n\n    if _VENV_T5_DIR not in sys.path:\n        sys.path.insert(0, _VENV_T5_DIR)\n        logger.info(\"Prepended %s to sys.path\", _VENV_T5_DIR)\n\n    count = _purge_modules()\n    logger.info(\"Purged %d cached modules\", count)\n\n    import transformers\n\n    logger.info(\"Loaded transformers %s\", transformers.__version__)\n\n\ndef _deactivate_5x() -> None:\n    \"\"\"Remove .venv_t5/ from sys.path, purge stale modules, reimport.\"\"\"\n    while _VENV_T5_DIR in sys.path:\n        sys.path.remove(_VENV_T5_DIR)\n    logger.info(\"Removed %s from sys.path\", _VENV_T5_DIR)\n\n    count = _purge_modules()\n    logger.info(\"Purged %d cached modules\", count)\n\n    import transformers\n\n    logger.info(\"Reverted to transformers %s\", transformers.__version__)\n\n\ndef ensure_transformers_version(model_name: str) -> None:\n    \"\"\"Ensure the correct ``transformers`` version is active for *model_name*.\n\n    Uses sys.path with .venv_t5/ (pre-installed by setup.sh):\n      • Need 5.x → prepend .venv_t5/ to sys.path, purge modules.\n      • Need 4.x → remove .venv_t5/ from sys.path, purge modules.\n\n    For LoRA adapters with custom names, the base model is resolved from\n    ``adapter_config.json`` before checking.\n\n    NOTE: Training and inference use subprocess isolation instead of this\n    function. This is only used by the export path (routes/export.py).\n    \"\"\"\n    # Resolve LoRA adapters to their base model for accurate detection\n    resolved = _resolve_base_model(model_name)\n    want_5 = needs_transformers_5(resolved)\n    target_version = TRANSFORMERS_5_VERSION if want_5 else TRANSFORMERS_DEFAULT_VERSION\n    target_major = int(target_version.split(\".\")[0])\n\n    # Check what's actually loaded in memory\n    in_memory = _get_in_memory_version()\n\n    logger.info(\n        \"Version check for '%s' (resolved: '%s'): need=%s, in_memory=%s\",\n        model_name,\n        resolved,\n        target_version,\n        in_memory,\n    )\n\n    # --- Already correct? 
---------------------------------------------------\n    if in_memory is not None:\n        in_memory_major = int(in_memory.split(\".\")[0])\n        if in_memory_major == target_major:\n            logger.info(\n                \"transformers %s already loaded — correct for '%s'\",\n                in_memory,\n                model_name,\n            )\n            return\n\n    # --- Switch version -----------------------------------------------------\n    if want_5:\n        logger.info(\"Activating transformers %s via .venv_t5…\", TRANSFORMERS_5_VERSION)\n        _activate_5x()\n    else:\n        logger.info(\n            \"Reverting to default transformers %s…\", TRANSFORMERS_DEFAULT_VERSION\n        )\n        _deactivate_5x()\n\n    final = _get_in_memory_version()\n    logger.info(\"✓ transformers version is now %s\", final)\n"
  },
  {
    "path": "studio/backend/utils/utils.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"\nShared backend utilities\n\"\"\"\n\nimport os\nimport structlog\nfrom loggers import get_logger\nfrom contextlib import contextmanager\nfrom pathlib import Path\nimport shutil\nimport tempfile\n\n\nlogger = get_logger(__name__)\n\n\n@contextmanager\ndef without_hf_auth():\n    \"\"\"\n    Context manager to temporarily disable HuggingFace authentication.\n\n    Usage:\n        with without_hf_auth():\n            # Code that should run without cached tokens\n            model_info(model_name, token=None)\n    \"\"\"\n    # Save environment variables\n    saved_env = {}\n    env_vars = [\"HF_TOKEN\", \"HUGGINGFACE_HUB_TOKEN\", \"HF_HOME\"]\n    for var in env_vars:\n        if var in os.environ:\n            saved_env[var] = os.environ[var]\n            del os.environ[var]\n\n    # Save disable flag\n    saved_disable = os.environ.get(\"HF_HUB_DISABLE_IMPLICIT_TOKEN\")\n    os.environ[\"HF_HUB_DISABLE_IMPLICIT_TOKEN\"] = \"1\"\n\n    # Move token files temporarily\n    token_files = []\n    token_locations = [\n        Path.home() / \".cache\" / \"huggingface\" / \"token\",\n        Path.home() / \".huggingface\" / \"token\",\n    ]\n\n    for token_loc in token_locations:\n        if token_loc.exists():\n            temp = tempfile.NamedTemporaryFile(delete = False)\n            temp.close()\n            shutil.move(str(token_loc), temp.name)\n            token_files.append((token_loc, temp.name))\n\n    try:\n        yield\n    finally:\n        # Restore tokens\n        for original, temp in token_files:\n            try:\n                original.parent.mkdir(parents = True, exist_ok = True)\n                shutil.move(temp, str(original))\n            except Exception as e:\n                logger.error(f\"Failed to restore token {original}: {e}\")\n\n        # Restore environment\n        for var, value in saved_env.items():\n            os.environ[var] = value\n\n        if saved_disable is not None:\n            os.environ[\"HF_HUB_DISABLE_IMPLICIT_TOKEN\"] = saved_disable\n        else:\n            os.environ.pop(\"HF_HUB_DISABLE_IMPLICIT_TOKEN\", None)\n\n\ndef format_error_message(error: Exception, model_name: str) -> str:\n    \"\"\"\n    Format user-friendly error messages for common issues.\n\n    Args:\n        error: The exception that occurred\n        model_name: Name of the model being loaded\n\n    Returns:\n        User-friendly error string\n    \"\"\"\n    error_str = str(error).lower()\n    model_short = model_name.split(\"/\")[-1] if \"/\" in model_name else model_name\n\n    if \"repository not found\" in error_str or \"404\" in error_str:\n        return f\"Model '{model_short}' not found. Check the model name.\"\n\n    if \"401\" in error_str or \"unauthorized\" in error_str:\n        return f\"Authentication failed for '{model_short}'. Please provide a valid HF token.\"\n\n    if \"gated\" in error_str or \"access to model\" in error_str:\n        return f\"Model '{model_short}' requires authentication. Please provide a valid HF token.\"\n\n    if \"invalid user token\" in error_str:\n        return \"Invalid HF token. 
Please check your token and try again.\"\n\n    if (\n        \"memory\" in error_str\n        or \"cuda\" in error_str\n        or \"mlx\" in error_str\n        or \"out of memory\" in error_str\n    ):\n        from utils.hardware import get_device\n\n        device = get_device()\n        device_label = {\"cuda\": \"GPU\", \"mlx\": \"Apple Silicon GPU\", \"cpu\": \"system\"}.get(\n            device.value, \"GPU\"\n        )\n        return f\"Not enough {device_label} memory to load '{model_short}'. Try a smaller model or free memory.\"\n\n    # Generic fallback\n    return str(error)\n"
  },
  {
    "path": "studio/frontend/.gitignore",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n# Logs\nlogs\n*.log\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\npnpm-debug.log*\nlerna-debug.log*\n\nnode_modules\ndist\ndist-ssr\ntest/\n*.local\n.env\n.env.*\n.omx/\n\n# Editor directories and files\n.vscode/*\n!.vscode/extensions.json\n.idea\n.DS_Store\n._*\n*.suo\n*.ntvs*\n*.njsproj\n*.sln\n*.sw?\n/src/features/recipe-studio/AGENTS.md\n/docs\n"
  },
  {
    "path": "studio/frontend/.gitkeep",
    "content": ""
  },
  {
    "path": "studio/frontend/biome.json",
    "content": "{\n  \"$schema\": \"https://biomejs.dev/schemas/1.9.4/schema.json\",\n  \"files\": {\n    \"ignore\": [\n      \"dist\",\n      \"node_modules\",\n      \"test\",\n      \"test/**\",\n      \"**/._*\",\n      \"._*\",\n      \"**/.DS_Store\",\n      \"tsconfig*.json\"\n    ]\n  },\n  \"formatter\": {\n    \"enabled\": true,\n    \"indentStyle\": \"space\",\n    \"indentWidth\": 2\n  },\n  \"organizeImports\": {\n    \"enabled\": true\n  },\n  \"linter\": {\n    \"enabled\": true,\n    \"rules\": {\n      \"recommended\": true,\n      \"a11y\": { \"all\": true },\n      \"complexity\": { \"all\": true },\n      \"correctness\": { \"all\": true, \"useImportExtensions\": \"off\" },\n      \"performance\": { \"all\": true },\n      \"security\": { \"all\": true },\n      \"style\": {\n        \"all\": true,\n        \"useNamingConvention\": { \"options\": { \"strictCase\": false } }\n      },\n      \"suspicious\": { \"all\": true, \"noReactSpecificProps\": \"off\" }\n    }\n  },\n  \"overrides\": [\n    {\n      \"include\": [\"vite.config.ts\", \"eslint.config.js\"],\n      \"linter\": {\n        \"rules\": {\n          \"correctness\": { \"noNodejsModules\": \"off\" },\n          \"style\": { \"noDefaultExport\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/assistant-ui/reasoning.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/assistant-ui/attachment.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/assistant-ui/tool-fallback.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/component-example.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"noNamespaceImport\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/config/env.ts\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/layout/index.ts\"],\n      \"linter\": {\n        \"rules\": {\n          \"performance\": { \"noBarrelFile\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/features/**/index.ts\"],\n      \"linter\": {\n        \"rules\": {\n          \"performance\": { \"noBarrelFile\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/features/chat/thread-sidebar.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"a11y\": { \"useSemanticElements\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/features/chat/runtime-provider.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/components/assistant-ui/thread.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useNamingConvention\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": [\"src/features/onboarding/components/steps/summary-step.tsx\"],\n      \"linter\": {\n        \"rules\": {\n          \"style\": { \"useExplicitLengthCheck\": \"off\" }\n        }\n      }\n    },\n    {\n      \"include\": 
[\"src/components/ui/**\"],\n      \"linter\": {\n        \"enabled\": false\n      },\n      \"formatter\": {\n        \"enabled\": false\n      },\n      \"organizeImports\": {\n        \"enabled\": false\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "studio/frontend/components.json",
    "content": "{\n  \"$schema\": \"https://ui.shadcn.com/schema.json\",\n  \"style\": \"radix-maia\",\n  \"rsc\": false,\n  \"tsx\": true,\n  \"tailwind\": {\n    \"config\": \"\",\n    \"css\": \"src/index.css\",\n    \"baseColor\": \"neutral\",\n    \"cssVariables\": true,\n    \"prefix\": \"\"\n  },\n  \"iconLibrary\": \"hugeicons\",\n  \"menuColor\": \"default\",\n  \"menuAccent\": \"subtle\",\n  \"aliases\": {\n    \"components\": \"@/components\",\n    \"utils\": \"@/lib/utils\",\n    \"ui\": \"@/components/ui\",\n    \"lib\": \"@/lib\",\n    \"hooks\": \"@/hooks\"\n  },\n  \"registries\": {\n    \"@magicui\": \"https://magicui.design/r/{name}\"\n  }\n}\n"
  },
  {
    "path": "studio/frontend/data-designer.openapi (1).yaml",
    "content": "openapi: 3.1.0\ninfo:\n  title: NeMo Data Designer Microservice\n  description: Service for generating synthetic data.\n  version: 1.5.0\npaths:\n  /v1/data-designer/jobs:\n    post:\n      tags:\n      - Data Designer\n      summary: Create Job\n      operationId: create_job_v1_data_designer_jobs_post\n      requestBody:\n        required: true\n        content:\n          application/json:\n            schema:\n              $ref: '#/components/schemas/DataDesignerJobRequest'\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/DataDesignerJob'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n    get:\n      tags:\n      - Data Designer\n      summary: List Jobs\n      operationId: list_jobs_v1_data_designer_jobs_get\n      parameters:\n      - name: page\n        in: query\n        required: false\n        schema:\n          type: integer\n          exclusiveMinimum: 0\n          description: Page number.\n          default: 1\n          title: Page\n        description: Page number.\n      - name: page_size\n        in: query\n        required: false\n        schema:\n          type: integer\n          exclusiveMinimum: 0\n          description: Page size.\n          default: 10\n          title: Page Size\n        description: Page size.\n      - name: sort\n        in: query\n        required: false\n        schema:\n          allOf:\n          - $ref: '#/components/schemas/DataDesignerJobsSortField'\n          description: The field to sort by. To sort in decreasing order, use `-`\n            in front of the field name.\n          default: -created_at\n        description: The field to sort by. 
To sort in decreasing order, use `-` in\n          front of the field name.\n      - in: query\n        name: filter\n        style: deepObject\n        required: false\n        explode: true\n        schema:\n          $ref: '#/components/schemas/DataDesignerJobsListFilter'\n        description: Filter jobs on various criteria.\n      - in: query\n        name: search\n        style: deepObject\n        required: false\n        explode: true\n        schema:\n          $ref: '#/components/schemas/DataDesignerJobsSearch'\n        description: \"\\nSearch jobs using substring matching.\\nYou can combine multiple\\\n          \\ search fields and filters.\\n\\nFor example:\\n- `?search[name]=training`:\\\n          \\ searches all jobs with 'training' in the name.\\n- `?search[project]=my-project`:\\\n          \\ searches all jobs with 'my-project'\\n  in the project field.\\n- `?search[name]=training&search[name]=eval`:\\\n          \\ searches all jobs with\\n  'training' OR 'eval' in the name.\\n- `?search[name]=training&search[project]=my-project`:\\\n          \\ searches all\\n  jobs with 'training' in the name AND 'my-project' in the\\\n          \\ project.\\n\"\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/DataDesignerJobsPage'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}:\n    get:\n      tags:\n      - Data Designer\n      summary: Get Job\n      operationId: get_job_v1_data_designer_jobs__job_id__get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/DataDesignerJob'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n    delete:\n      tags:\n      - Data Designer\n      summary: Delete Job\n      operationId: delete_job_v1_data_designer_jobs__job_id__delete\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema: {}\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/cancel:\n    post:\n      tags:\n      - Data Designer\n      summary: Cancel Job\n      operationId: cancel_job_v1_data_designer_jobs__job_id__cancel_post\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/DataDesignerJob'\n        '422':\n     
     description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/logs:\n    get:\n      tags:\n      - Data Designer\n      summary: Get Job Logs\n      operationId: get_job_logs_v1_data_designer_jobs__job_id__logs_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      - name: limit\n        in: query\n        required: false\n        schema:\n          anyOf:\n          - type: integer\n          - type: 'null'\n          title: Limit\n      - name: page_cursor\n        in: query\n        required: false\n        schema:\n          anyOf:\n          - type: string\n          - type: 'null'\n          title: Page Cursor\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/PlatformJobLogPage'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/results:\n    get:\n      tags:\n      - Data Designer\n      summary: List Job Results\n      operationId: list_job_results_v1_data_designer_jobs__job_id__results_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/PlatformJobListResultResponse'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/results/analysis/download:\n    get:\n      tags:\n      - Data Designer\n      summary: Download Job Result Analysis\n      operationId: download_job_result_analysis_v1_data_designer_jobs__job_id__results_analysis_download_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema: {}\n        '404':\n          description: Not Found\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/results/dataset/download:\n    get:\n      tags:\n      - Data Designer\n      summary: Download Job Result Dataset\n      operationId: download_job_result_dataset_v1_data_designer_jobs__job_id__results_dataset_download_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/octet-stream:\n              schema:\n                type: string\n                format: binary\n        '404':\n          description: 
Not Found\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/results/{result_name}:\n    get:\n      tags:\n      - Data Designer\n      summary: Get Job Result\n      operationId: get_job_result_v1_data_designer_jobs__job_id__results__result_name__get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      - name: result_name\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Result Name\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/PlatformJobResultResponse'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/results/{result_name}/download:\n    get:\n      tags:\n      - Data Designer\n      summary: Download Job Result\n      operationId: download_job_result_v1_data_designer_jobs__job_id__results__result_name__download_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      - name: result_name\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Result Name\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/octet-stream:\n              schema:\n                type: string\n                format: binary\n        '404':\n          description: Not Found\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/jobs/{job_id}/status:\n    get:\n      tags:\n      - Data Designer\n      summary: Get Job Status\n      operationId: get_job_status_v1_data_designer_jobs__job_id__status_get\n      parameters:\n      - name: job_id\n        in: path\n        required: true\n        schema:\n          type: string\n          title: Job Id\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/PlatformJobStatusResponse'\n        '422':\n          description: Validation Error\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/preview:\n    post:\n      tags:\n      - Data Designer\n      summary: Generate preview Data Designer\n      operationId: preview_v1_data_designer_preview_post\n      requestBody:\n        content:\n          application/json:\n            schema:\n              $ref: '#/components/schemas/PreviewRequest'\n        required: true\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/jsonl:\n              schema:\n                $ref: '#/components/schemas/PreviewMessage'\n        '422':\n          description: Validation Error\n      
    content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/HTTPValidationError'\n  /v1/data-designer/settings:\n    get:\n      tags:\n      - Data Designer\n      summary: Get Data Designer settings\n      description: Returns the settings available for Data Designer.\n      operationId: get_settings_v1_data_designer_settings_get\n      responses:\n        '200':\n          description: Successful Response\n          content:\n            application/json:\n              schema:\n                $ref: '#/components/schemas/SettingsResponse'\ncomponents:\n  schemas:\n    BernoulliMixtureSamplerParams:\n      properties:\n        p:\n          type: number\n          maximum: 1.0\n          minimum: 0.0\n          title: P\n          description: Bernoulli distribution probability of success.\n        dist_name:\n          type: string\n          title: Dist Name\n          description: Mixture distribution name. Samples will be equal to the distribution\n            sample with probability `p`, otherwise equal to 0. Must be a valid scipy.stats\n            distribution name.\n        dist_params:\n          additionalProperties: true\n          type: object\n          title: Dist Params\n          description: Parameters of the scipy.stats distribution given in `dist_name`.\n        sampler_type:\n          type: string\n          const: bernoulli_mixture\n          title: Sampler Type\n          default: bernoulli_mixture\n      additionalProperties: false\n      type: object\n      required:\n      - p\n      - dist_name\n      - dist_params\n      title: BernoulliMixtureSamplerParams\n      description: \"Parameters for sampling from a Bernoulli mixture distribution.\\n\\\n        \\nCombines a Bernoulli distribution with another continuous distribution,\\\n        \\ creating a mixture\\nwhere values are either 0 (with probability 1-p) or\\\n        \\ sampled from the specified distribution\\n(with probability p). This is useful\\\n        \\ for modeling scenarios with many zero values mixed with\\na continuous distribution\\\n        \\ of non-zero values.\\n\\nCommon use cases include modeling sparse events,\\\n        \\ zero-inflated data, or situations where\\nan outcome either doesn't occur\\\n        \\ (0) or follows a specific distribution when it does occur.\\n\\nAttributes:\\n\\\n        \\    p: Probability of sampling from the mixture distribution (non-zero outcome).\\n\\\n        \\        Must be between 0.0 and 1.0 (inclusive). 
With probability 1-p, the\\\n        \\ sample is 0.\\n    dist_name: Name of the scipy.stats distribution to sample\\\n        \\ from when outcome is non-zero.\\n        Must be a valid scipy.stats distribution\\\n        \\ name (e.g., \\\"norm\\\", \\\"gamma\\\", \\\"expon\\\").\\n    dist_params: Parameters\\\n        \\ for the specified scipy.stats distribution.\"\n    BernoulliSamplerParams:\n      properties:\n        p:\n          type: number\n          maximum: 1.0\n          minimum: 0.0\n          title: P\n          description: Probability of success.\n        sampler_type:\n          type: string\n          const: bernoulli\n          title: Sampler Type\n          default: bernoulli\n      additionalProperties: false\n      type: object\n      required:\n      - p\n      title: BernoulliSamplerParams\n      description: \"Parameters for sampling from a Bernoulli distribution.\\n\\nSamples\\\n        \\ binary values (0 or 1) representing the outcome of a single trial with a\\\n        \\ fixed\\nprobability of success. This is the simplest discrete probability\\\n        \\ distribution, useful for\\nmodeling binary outcomes like success/failure,\\\n        \\ yes/no, or true/false.\\n\\nAttributes:\\n    p: Probability of success (sampling\\\n        \\ 1). Must be between 0.0 and 1.0 (inclusive).\\n        The probability of\\\n        \\ failure (sampling 0) is automatically 1 - p.\"\n    BinomialSamplerParams:\n      properties:\n        n:\n          type: integer\n          title: N\n          description: Number of trials.\n        p:\n          type: number\n          maximum: 1.0\n          minimum: 0.0\n          title: P\n          description: Probability of success on each trial.\n        sampler_type:\n          type: string\n          const: binomial\n          title: Sampler Type\n          default: binomial\n      additionalProperties: false\n      type: object\n      required:\n      - n\n      - p\n      title: BinomialSamplerParams\n      description: \"Parameters for sampling from a Binomial distribution.\\n\\nSamples\\\n        \\ integer values representing the number of successes in a fixed number of\\\n        \\ independent\\nBernoulli trials, each with the same probability of success.\\\n        \\ Commonly used to model the number\\nof successful outcomes in repeated experiments.\\n\\\n        \\nAttributes:\\n    n: Number of independent trials. Must be a positive integer.\\n\\\n        \\    p: Probability of success on each trial. Must be between 0.0 and 1.0\\\n        \\ (inclusive).\"\n    BuildStage:\n      type: string\n      enum:\n      - pre_batch\n      - post_batch\n      - pre_generation\n      - post_generation\n      title: BuildStage\n    CategorySamplerParams:\n      properties:\n        values:\n          items:\n            anyOf:\n            - type: string\n            - type: integer\n            - type: number\n          type: array\n          minItems: 1\n          title: Values\n          description: List of possible categorical values that can be sampled from.\n        weights:\n          type: array\n          items:\n            type: number\n          title: Weights\n          description: List of unnormalized probability weights to assigned to each\n            value, in order. 
Larger values will be sampled with higher probability.\n        sampler_type:\n          type: string\n          const: category\n          title: Sampler Type\n          default: category\n      additionalProperties: false\n      type: object\n      required:\n      - values\n      title: CategorySamplerParams\n      description: \"Parameters for categorical sampling with optional probability\\\n        \\ weighting.\\n\\nSamples values from a discrete set of categories. When weights\\\n        \\ are provided, values are\\nsampled according to their assigned probabilities.\\\n        \\ Without weights, uniform sampling is used.\\n\\nAttributes:\\n    values: List\\\n        \\ of possible categorical values to sample from. Can contain strings, integers,\\n\\\n        \\        or floats. Must contain at least one value.\\n    weights: Optional\\\n        \\ unnormalized probability weights for each value. If provided, must be\\n\\\n        \\        the same length as `values`. Weights are automatically normalized\\\n        \\ to sum to 1.0.\\n        Larger weights result in higher sampling probability\\\n        \\ for the corresponding value.\"\n    CodeLang:\n      type: string\n      enum:\n      - go\n      - javascript\n      - java\n      - kotlin\n      - python\n      - ruby\n      - rust\n      - scala\n      - swift\n      - typescript\n      - sql:sqlite\n      - sql:tsql\n      - sql:bigquery\n      - sql:mysql\n      - sql:postgres\n      - sql:ansi\n      title: CodeLang\n    CodeValidatorParams:\n      properties:\n        code_lang:\n          allOf:\n          - $ref: '#/components/schemas/CodeLang'\n          description: The language of the code to validate\n      additionalProperties: false\n      type: object\n      required:\n      - code_lang\n      title: CodeValidatorParams\n      description: \"Configuration for code validation. 
Supports Python and SQL code\\\n        \\ validation.\\n\\nAttributes:\\n    code_lang: The language of the code to validate.\\\n        \\ Supported values include: `python`,\\n        `sql:sqlite`, `sql:postgres`,\\\n        \\ `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`.\"\n    ColumnInequalityConstraint:\n      properties:\n        target_column:\n          type: string\n          title: Target Column\n        rhs:\n          type: string\n          title: Rhs\n        operator:\n          $ref: '#/components/schemas/InequalityOperator'\n      additionalProperties: false\n      type: object\n      required:\n      - target_column\n      - rhs\n      - operator\n      title: ColumnInequalityConstraint\n    DataDesignerConfig:\n      properties:\n        columns:\n          items:\n            oneOf:\n            - $ref: '#/components/schemas/ExpressionColumnConfig'\n            - $ref: '#/components/schemas/LLMCodeColumnConfig'\n            - $ref: '#/components/schemas/LLMJudgeColumnConfig'\n            - $ref: '#/components/schemas/LLMStructuredColumnConfig'\n            - $ref: '#/components/schemas/LLMTextColumnConfig'\n            - $ref: '#/components/schemas/SamplerColumnConfig'\n            - $ref: '#/components/schemas/SeedDatasetColumnConfig'\n            - $ref: '#/components/schemas/ValidationColumnConfig'\n            discriminator:\n              propertyName: column_type\n              mapping:\n                expression: '#/components/schemas/ExpressionColumnConfig'\n                llm-code: '#/components/schemas/LLMCodeColumnConfig'\n                llm-judge: '#/components/schemas/LLMJudgeColumnConfig'\n                llm-structured: '#/components/schemas/LLMStructuredColumnConfig'\n                llm-text: '#/components/schemas/LLMTextColumnConfig'\n                sampler: '#/components/schemas/SamplerColumnConfig'\n                seed-dataset: '#/components/schemas/SeedDatasetColumnConfig'\n                validation: '#/components/schemas/ValidationColumnConfig'\n          type: array\n          minItems: 1\n          title: Columns\n        model_configs:\n          type: array\n          items:\n            $ref: '#/components/schemas/ModelConfigInput'\n          title: Model Configs\n        seed_config:\n          $ref: '#/components/schemas/SeedConfig'\n        constraints:\n          type: array\n          items:\n            anyOf:\n            - $ref: '#/components/schemas/ScalarInequalityConstraint'\n            - $ref: '#/components/schemas/ColumnInequalityConstraint'\n          title: Constraints\n        profilers:\n          type: array\n          items:\n            $ref: '#/components/schemas/JudgeScoreProfilerConfig'\n          title: Profilers\n        processors:\n          type: array\n          items:\n            $ref: '#/components/schemas/ProcessorConfig'\n          title: Processors\n      additionalProperties: false\n      type: object\n      required:\n      - columns\n      title: DataDesignerConfig\n      description: \"Configuration for NeMo Data Designer.\\n\\nThis class defines the\\\n        \\ main configuration structure for NeMo Data Designer,\\nwhich orchestrates\\\n        \\ the generation of synthetic data.\\n\\nAttributes:\\n    columns: Required\\\n        \\ list of column configurations defining how each column\\n        should be\\\n        \\ generated. 
Must contain at least one column.\\n    model_configs: Optional\\\n        \\ list of model configurations for LLM-based generation.\\n        Each model\\\n        \\ config defines the model, provider, and inference parameters.\\n    seed_config:\\\n        \\ Optional seed dataset settings to use for generation.\\n    constraints:\\\n        \\ Optional list of column constraints.\\n    profilers: Optional list of column\\\n        \\ profilers for analyzing generated data characteristics.\"\n    DataDesignerJob:\n      properties:\n        id:\n          type: string\n          title: Id\n        name:\n          type: string\n          title: Name\n        description:\n          type: string\n          title: Description\n        project:\n          type: string\n          title: Project\n        namespace:\n          type: string\n          title: Namespace\n        created_at:\n          type: string\n          title: Created At\n        updated_at:\n          type: string\n          title: Updated At\n        spec:\n          $ref: '#/components/schemas/DataDesignerJobConfig'\n        status:\n          $ref: '#/components/schemas/PlatformJobStatus'\n        status_details:\n          type: object\n          additionalProperties: true\n          title: Status Details\n        error_details:\n          type: object\n          additionalProperties: true\n          title: Error Details\n        ownership:\n          type: object\n          additionalProperties: true\n          title: Ownership\n        custom_fields:\n          type: object\n          additionalProperties: true\n          title: Custom Fields\n      type: object\n      required:\n      - name\n      - spec\n      title: DataDesignerJob\n    DataDesignerJobConfig:\n      properties:\n        num_records:\n          type: integer\n          title: Num Records\n        config:\n          $ref: '#/components/schemas/DataDesignerConfig'\n      type: object\n      required:\n      - num_records\n      - config\n      title: DataDesignerJobConfig\n    DataDesignerJobRequest:\n      properties:\n        name:\n          type: string\n          title: Name\n        description:\n          type: string\n          title: Description\n        namespace:\n          type: string\n          title: Namespace\n        project:\n          type: string\n          title: Project\n        spec:\n          $ref: '#/components/schemas/DataDesignerJobConfig'\n        ownership:\n          type: object\n          additionalProperties: true\n          title: Ownership\n        custom_fields:\n          type: object\n          additionalProperties: true\n          title: Custom Fields\n      type: object\n      required:\n      - spec\n      title: DataDesignerJobRequest\n    DataDesignerJobsListFilter:\n      properties:\n        created_at:\n          allOf:\n          - $ref: '#/components/schemas/DatetimeFilter'\n          description: Jobs created at 'gte' datetime or 'lte' datetime.\n        name:\n          type: string\n          title: Name\n          description: Name of the job.\n        namespace:\n          type: string\n          title: Namespace\n          description: Namespace of the job.\n        project:\n          type: string\n          title: Project\n          description: Project containing the job.\n        status:\n          allOf:\n          - $ref: '#/components/schemas/PlatformJobStatus'\n          description: The current status.\n        updated_at:\n          allOf:\n          - $ref: 
'#/components/schemas/DatetimeFilter'\n          description: Jobs updated at 'gte' datetime or 'lte' datetime.\n      additionalProperties: false\n      type: object\n      title: DataDesignerJobsListFilter\n    DataDesignerJobsPage:\n      properties:\n        object:\n          type: string\n          title: Object\n          description: The type of object being returned.\n          default: list\n        data:\n          items:\n            $ref: '#/components/schemas/DataDesignerJob'\n          type: array\n          title: Data\n        pagination:\n          allOf:\n          - $ref: '#/components/schemas/PaginationData'\n          description: Pagination information.\n        sort:\n          type: string\n          title: Sort\n          description: The field on which the results are sorted.\n        filter:\n          allOf:\n          - $ref: '#/components/schemas/DataDesignerJobsListFilter'\n          description: Filtering information.\n        search:\n          allOf:\n          - $ref: '#/components/schemas/DataDesignerJobsSearch'\n          description: Search information.\n      type: object\n      required:\n      - data\n      title: DataDesignerJobsPage\n    DataDesignerJobsSearch:\n      properties:\n        name:\n          type: array\n          items:\n            type: string\n          title: Name\n          description: Search jobs where name contains any of these strings.\n        project:\n          type: array\n          items:\n            type: string\n          title: Project\n          description: Search jobs where project contains any of these strings.\n      type: object\n      title: DataDesignerJobsSearch\n    DataDesignerJobsSortField:\n      type: string\n      enum:\n      - created_at\n      - -created_at\n      - updated_at\n      - -updated_at\n      title: DataDesignerJobsSortField\n    DatetimeFilter:\n      properties:\n        gte:\n          type: string\n          title: Gte\n          description: Filter for results greater than or equal to this datetime.\n        lte:\n          type: string\n          title: Lte\n          description: Filter for results less than or equal to this datetime.\n      additionalProperties: false\n      type: object\n      title: DatetimeFilter\n    DatetimeSamplerParams:\n      properties:\n        start:\n          type: string\n          title: Start\n          description: Earliest possible datetime for sampling range, inclusive.\n        end:\n          type: string\n          title: End\n          description: Latest possible datetime for sampling range, inclusive.\n        unit:\n          type: string\n          enum:\n          - Y\n          - M\n          - D\n          - h\n          - m\n          - s\n          title: Unit\n          description: Sampling units, e.g. 
the smallest possible time interval between\n            samples.\n          default: D\n        sampler_type:\n          type: string\n          const: datetime\n          title: Sampler Type\n          default: datetime\n      additionalProperties: false\n      type: object\n      required:\n      - start\n      - end\n      title: DatetimeSamplerParams\n      description: \"Parameters for uniform datetime sampling within a specified range.\\n\\\n        \\nSamples datetime values uniformly between a start and end date with a specified\\\n        \\ granularity.\\nThe sampling unit determines the smallest possible time interval\\\n        \\ between consecutive samples.\\n\\nAttributes:\\n    start: Earliest possible\\\n        \\ datetime for the sampling range (inclusive). Must be a valid\\n        datetime\\\n        \\ string parseable by pandas.to_datetime().\\n    end: Latest possible datetime\\\n        \\ for the sampling range (inclusive). Must be a valid\\n        datetime string\\\n        \\ parseable by pandas.to_datetime().\\n    unit: Time unit for sampling granularity.\\\n        \\ Options:\\n        - \\\"Y\\\": Years\\n        - \\\"M\\\": Months\\n        - \\\"\\\n        D\\\": Days (default)\\n        - \\\"h\\\": Hours\\n        - \\\"m\\\": Minutes\\n  \\\n        \\      - \\\"s\\\": Seconds\"\n    DisplayModelProvider:\n      properties:\n        name:\n          type: string\n          title: Name\n        provider_type:\n          type: string\n          title: Provider Type\n          default: openai\n        extra_body:\n          type: object\n          additionalProperties: true\n          title: Extra Body\n        allowed_models:\n          type: array\n          items:\n            type: string\n          title: Allowed Models\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      title: DisplayModelProvider\n    DistributionType:\n      type: string\n      enum:\n      - uniform\n      - manual\n      title: DistributionType\n    ExpressionColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: expression\n          title: Column Type\n          default: expression\n        expr:\n          type: string\n          title: Expr\n        dtype:\n          type: string\n          enum:\n          - int\n          - float\n          - str\n          - bool\n          title: Dtype\n          default: str\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - expr\n      title: ExpressionColumnConfig\n      description: \"Configuration for derived columns using Jinja2 expressions.\\n\\n\\\n        Expression columns compute values by evaluating Jinja2 templates that reference\\\n        \\ other\\ncolumns. Useful for transformations, concatenations, conditional\\\n        \\ logic, and derived\\nfeatures without requiring LLM generation. The expression\\\n        \\ is evaluated row-by-row.\\n\\nAttributes:\\n    expr: Jinja2 expression to\\\n        \\ evaluate. Can reference other column values using\\n        {{ column_name\\\n        \\ }} syntax. Supports filters, conditionals, and arithmetic.\\n        Must\\\n        \\ be a valid, non-empty Jinja2 template.\\n    dtype: Data type to cast the\\\n        \\ result to. 
Must be one of \\\"int\\\", \\\"float\\\", \\\"str\\\", or \\\"bool\\\".\\n  \\\n        \\      Defaults to \\\"str\\\". Type conversion is applied after expression evaluation.\\n\\\n        \\    column_type: Discriminator field, always \\\"expression\\\" for this configuration\\\n        \\ type.\"\n    FileStorageType:\n      type: string\n      enum:\n      - nds\n      title: FileStorageType\n    GaussianSamplerParams:\n      properties:\n        mean:\n          type: number\n          title: Mean\n          description: Mean of the Gaussian distribution\n        stddev:\n          type: number\n          title: Stddev\n          description: Standard deviation of the Gaussian distribution\n        decimal_places:\n          type: integer\n          title: Decimal Places\n          description: Number of decimal places to round the sampled values to.\n        sampler_type:\n          type: string\n          const: gaussian\n          title: Sampler Type\n          default: gaussian\n      additionalProperties: false\n      type: object\n      required:\n      - mean\n      - stddev\n      title: GaussianSamplerParams\n      description: \"Parameters for sampling from a Gaussian (Normal) distribution.\\n\\\n        \\nSamples continuous values from a normal distribution characterized by its\\\n        \\ mean and standard\\ndeviation. The Gaussian distribution is one of the most\\\n        \\ commonly used probability distributions,\\nappearing naturally in many real-world\\\n        \\ phenomena due to the Central Limit Theorem.\\n\\nAttributes:\\n    mean: Mean\\\n        \\ (center) of the Gaussian distribution. This is the expected value and the\\n\\\n        \\        location of the distribution's peak.\\n    stddev: Standard deviation\\\n        \\ of the Gaussian distribution. Controls the spread or width\\n        of the\\\n        \\ distribution. Must be positive.\\n    decimal_places: Optional number of\\\n        \\ decimal places to round sampled values to. 
If None,\\n        values are\\\n        \\ not rounded.\"\n    HTTPValidationError:\n      properties:\n        detail:\n          items:\n            $ref: '#/components/schemas/ValidationError'\n          type: array\n          title: Detail\n      type: object\n      title: HTTPValidationError\n    ImageContext:\n      properties:\n        modality:\n          allOf:\n          - $ref: '#/components/schemas/Modality'\n          default: image\n        column_name:\n          type: string\n          title: Column Name\n        data_type:\n          $ref: '#/components/schemas/ModalityDataType'\n        image_format:\n          $ref: '#/components/schemas/ImageFormat'\n      type: object\n      required:\n      - column_name\n      - data_type\n      title: ImageContext\n    ImageFormat:\n      type: string\n      enum:\n      - png\n      - jpg\n      - jpeg\n      - gif\n      - webp\n      title: ImageFormat\n    IndexRange:\n      properties:\n        start:\n          type: integer\n          minimum: 0.0\n          title: Start\n          description: The start index of the index range (inclusive)\n        end:\n          type: integer\n          minimum: 0.0\n          title: End\n          description: The end index of the index range (inclusive)\n      additionalProperties: false\n      type: object\n      required:\n      - start\n      - end\n      title: IndexRange\n    InequalityOperator:\n      type: string\n      enum:\n      - lt\n      - le\n      - gt\n      - ge\n      title: InequalityOperator\n    InferenceParametersInput:\n      properties:\n        temperature:\n          anyOf:\n          - type: number\n          - $ref: '#/components/schemas/UniformDistribution'\n          - $ref: '#/components/schemas/ManualDistribution'\n          - type: 'null'\n          title: Temperature\n        top_p:\n          anyOf:\n          - type: number\n          - $ref: '#/components/schemas/UniformDistribution'\n          - $ref: '#/components/schemas/ManualDistribution'\n          - type: 'null'\n          title: Top P\n        max_tokens:\n          type: integer\n          title: Max Tokens\n        max_parallel_requests:\n          type: integer\n          minimum: 1.0\n          title: Max Parallel Requests\n          default: 4\n        timeout:\n          type: integer\n          title: Timeout\n        extra_body:\n          type: object\n          additionalProperties: true\n          title: Extra Body\n      additionalProperties: false\n      type: object\n      title: InferenceParametersInput\n    InferenceParametersOutput:\n      properties:\n        temperature:\n          anyOf:\n          - type: number\n          - $ref: '#/components/schemas/UniformDistribution'\n          - $ref: '#/components/schemas/ManualDistribution'\n          - type: 'null'\n          title: Temperature\n        top_p:\n          anyOf:\n          - type: number\n          - $ref: '#/components/schemas/UniformDistribution'\n          - $ref: '#/components/schemas/ManualDistribution'\n          - type: 'null'\n          title: Top P\n        max_tokens:\n          type: integer\n          title: Max Tokens\n        max_parallel_requests:\n          type: integer\n          minimum: 1.0\n          title: Max Parallel Requests\n          default: 4\n        timeout:\n          type: integer\n          title: Timeout\n        extra_body:\n          type: object\n          additionalProperties: true\n          title: Extra Body\n      additionalProperties: false\n      type: object\n      
title: InferenceParametersOutput\n    JudgeScoreProfilerConfig:\n      properties:\n        model_alias:\n          type: string\n          title: Model Alias\n        summary_score_sample_size:\n          type: integer\n          title: Summary Score Sample Size\n          default: 20\n      additionalProperties: false\n      type: object\n      required:\n      - model_alias\n      title: JudgeScoreProfilerConfig\n    LLMCodeColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: llm-code\n          title: Column Type\n          default: llm-code\n        prompt:\n          type: string\n          title: Prompt\n        model_alias:\n          type: string\n          title: Model Alias\n        system_prompt:\n          type: string\n          title: System Prompt\n        multi_modal_context:\n          type: array\n          items:\n            $ref: '#/components/schemas/ImageContext'\n          title: Multi Modal Context\n        code_lang:\n          $ref: '#/components/schemas/CodeLang'\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - prompt\n      - model_alias\n      - code_lang\n      title: LLMCodeColumnConfig\n      description: \"Configuration for code generation columns using Large Language\\\n        \\ Models.\\n\\nExtends LLMTextColumnConfig to generate code snippets in specific\\\n        \\ programming languages\\nor SQL dialects. The generated code is automatically\\\n        \\ extracted from markdown code blocks\\nfor the specified language. Inherits\\\n        \\ all prompt templating capabilities.\\n\\nAttributes:\\n    code_lang: Programming\\\n        \\ language or SQL dialect for code generation. Supported\\n        values include:\\\n        \\ \\\"python\\\", \\\"javascript\\\", \\\"typescript\\\", \\\"java\\\", \\\"kotlin\\\", \\\"go\\\"\\\n        ,\\n        \\\"rust\\\", \\\"ruby\\\", \\\"scala\\\", \\\"swift\\\", \\\"sql:sqlite\\\", \\\"sql:postgres\\\"\\\n        , \\\"sql:mysql\\\",\\n        \\\"sql:tsql\\\", \\\"sql:bigquery\\\", \\\"sql:ansi\\\". 
See\\\n        \\ CodeLang enum for complete list.\\n    column_type: Discriminator field,\\\n        \\ always \\\"llm-code\\\" for this configuration type.\"\n    LLMJudgeColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: llm-judge\n          title: Column Type\n          default: llm-judge\n        prompt:\n          type: string\n          title: Prompt\n        model_alias:\n          type: string\n          title: Model Alias\n        system_prompt:\n          type: string\n          title: System Prompt\n        multi_modal_context:\n          type: array\n          items:\n            $ref: '#/components/schemas/ImageContext'\n          title: Multi Modal Context\n        scores:\n          items:\n            $ref: '#/components/schemas/Score'\n          type: array\n          minItems: 1\n          title: Scores\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - prompt\n      - model_alias\n      - scores\n      title: LLMJudgeColumnConfig\n      description: \"Configuration for LLM-as-a-judge quality assessment and scoring\\\n        \\ columns.\\n\\nExtends LLMTextColumnConfig to create judge columns that evaluate\\\n        \\ and score other\\ngenerated content based on the defined criteria. Useful\\\n        \\ for quality assessment, preference\\nranking, and multi-dimensional evaluation\\\n        \\ of generated data.\\n\\nAttributes:\\n    scores: List of Score objects defining\\\n        \\ the evaluation dimensions. Each score\\n        represents a different aspect\\\n        \\ to evaluate (e.g., accuracy, relevance, fluency).\\n        Must contain\\\n        \\ at least one score.\\n    column_type: Discriminator field, always \\\"llm-judge\\\"\\\n        \\ for this configuration type.\"\n    LLMStructuredColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: llm-structured\n          title: Column Type\n          default: llm-structured\n        prompt:\n          type: string\n          title: Prompt\n        model_alias:\n          type: string\n          title: Model Alias\n        system_prompt:\n          type: string\n          title: System Prompt\n        multi_modal_context:\n          type: array\n          items:\n            $ref: '#/components/schemas/ImageContext'\n          title: Multi Modal Context\n        output_format:\n          anyOf:\n          - additionalProperties: true\n            type: object\n          - {}\n          title: Output Format\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - prompt\n      - model_alias\n      - output_format\n      title: LLMStructuredColumnConfig\n      description: \"Configuration for structured JSON generation columns using Large\\\n        \\ Language Models.\\n\\nExtends LLMTextColumnConfig to generate structured data\\\n        \\ conforming to a specified schema.\\nUses JSON schema or Pydantic models to\\\n        \\ define the expected output structure, enabling\\ntype-safe and validated\\\n        \\ structured output generation. 
Inherits prompt templating capabilities.\\n\\\n        \\nAttributes:\\n    output_format: The schema defining the expected output\\\n        \\ structure. Can be either:\\n        - A Pydantic BaseModel class (recommended)\\n\\\n        \\        - A JSON schema dictionary\\n    column_type: Discriminator field,\\\n        \\ always \\\"llm-structured\\\" for this configuration type.\"\n    LLMTextColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: llm-text\n          title: Column Type\n          default: llm-text\n        prompt:\n          type: string\n          title: Prompt\n        model_alias:\n          type: string\n          title: Model Alias\n        system_prompt:\n          type: string\n          title: System Prompt\n        multi_modal_context:\n          type: array\n          items:\n            $ref: '#/components/schemas/ImageContext'\n          title: Multi Modal Context\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - prompt\n      - model_alias\n      title: LLMTextColumnConfig\n      description: \"Configuration for text generation columns using Large Language\\\n        \\ Models.\\n\\nLLM text columns generate free-form text content using language\\\n        \\ models via LiteLLM.\\nPrompts support Jinja2 templating to reference values\\\n        \\ from other columns, enabling\\ncontext-aware generation. The generated text\\\n        \\ can optionally include reasoning traces\\nwhen models support extended thinking.\\n\\\n        \\nAttributes:\\n    prompt: Prompt template for text generation. Supports Jinja2\\\n        \\ syntax to\\n        reference other columns (e.g., \\\"Write a story about\\\n        \\ {{ character_name }}\\\").\\n        Must be a valid Jinja2 template.\\n   \\\n        \\ model_alias: Alias of the model configuration to use for generation.\\n \\\n        \\       Must match a model alias defined when initializing the DataDesignerConfigBuilder.\\n\\\n        \\    system_prompt: Optional system prompt to set model behavior and constraints.\\n\\\n        \\        Also supports Jinja2 templating. If provided, must be a valid Jinja2\\\n        \\ template.\\n        Do not put any output parsing instructions in the system\\\n        \\ prompt. 
Instead,\\n        use the appropriate column type for the output\\\n        \\ you want to generate - e.g.,\\n        `LLMStructuredColumnConfig` for structured\\\n        \\ output, `LLMCodeColumnConfig` for code.\\n    multi_modal_context: Optional\\\n        \\ list of image contexts for multi-modal generation.\\n        Enables vision-capable\\\n        \\ models to generate text based on image inputs.\\n    column_type: Discriminator\\\n        \\ field, always \\\"llm-text\\\" for this configuration type.\"\n    LocalCallableValidatorParams:\n      properties:\n        validation_function:\n          title: Validation Function\n          description: Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate\n            the data\n        output_schema:\n          type: object\n          additionalProperties: true\n          title: Output Schema\n          description: Expected schema for local callable validator's output\n      additionalProperties: false\n      type: object\n      required:\n      - validation_function\n      title: LocalCallableValidatorParams\n      description: \"Configuration for local callable validation. Expects a function\\\n        \\ to be passed that validates the data.\\n\\nAttributes:\\n    validation_function:\\\n        \\ Function (`Callable[[pd.DataFrame], pd.DataFrame]`) to validate the\\n  \\\n        \\      data. Output must contain a column `is_valid` of type `bool`.\\n   \\\n        \\ output_schema: The JSON schema for the local callable validator's output.\\\n        \\ If not provided,\\n        the output will not be validated.\"\n    ManualDistribution:\n      properties:\n        distribution_type:\n          allOf:\n          - $ref: '#/components/schemas/DistributionType'\n          default: manual\n        params:\n          $ref: '#/components/schemas/ManualDistributionParams'\n      additionalProperties: false\n      type: object\n      required:\n      - params\n      title: ManualDistribution\n    ManualDistributionParams:\n      properties:\n        values:\n          items:\n            type: number\n          type: array\n          minItems: 1\n          title: Values\n        weights:\n          type: array\n          items:\n            type: number\n          title: Weights\n      additionalProperties: false\n      type: object\n      required:\n      - values\n      title: ManualDistributionParams\n    MessageType:\n      type: string\n      enum:\n      - analysis\n      - dataset\n      - heartbeat\n      - log\n      title: MessageType\n    Modality:\n      type: string\n      enum:\n      - image\n      title: Modality\n    ModalityDataType:\n      type: string\n      enum:\n      - url\n      - base64\n      title: ModalityDataType\n    ModelConfigInput:\n      properties:\n        alias:\n          type: string\n          title: Alias\n        model:\n          type: string\n          title: Model\n        inference_parameters:\n          $ref: '#/components/schemas/InferenceParametersInput'\n        provider:\n          type: string\n          title: Provider\n      additionalProperties: false\n      type: object\n      required:\n      - alias\n      - model\n      title: ModelConfigInput\n    ModelConfigOutput:\n      properties:\n        alias:\n          type: string\n          title: Alias\n        model:\n          type: string\n          title: Model\n        inference_parameters:\n          $ref: '#/components/schemas/InferenceParametersOutput'\n        provider:\n          type: string\n          title: 
Provider\n      additionalProperties: false\n      type: object\n      required:\n      - alias\n      - model\n      title: ModelConfigOutput\n    PaginationData:\n      properties:\n        page:\n          type: integer\n          title: Page\n          description: The current page number.\n        page_size:\n          type: integer\n          title: Page Size\n          description: The page size used for the query.\n        current_page_size:\n          type: integer\n          title: Current Page Size\n          description: The size for the current page.\n        total_pages:\n          type: integer\n          title: Total Pages\n          description: The total number of pages.\n        total_results:\n          type: integer\n          title: Total Results\n          description: The total number of results.\n      type: object\n      required:\n      - page\n      - page_size\n      - current_page_size\n      - total_pages\n      - total_results\n      title: PaginationData\n    PartitionBlock:\n      properties:\n        index:\n          type: integer\n          minimum: 0.0\n          title: Index\n          description: The index of the partition to sample from\n          default: 0\n        num_partitions:\n          type: integer\n          minimum: 1.0\n          title: Num Partitions\n          description: The total number of partitions in the dataset\n          default: 1\n      additionalProperties: false\n      type: object\n      title: PartitionBlock\n    PersonFromFakerSamplerParams:\n      properties:\n        locale:\n          type: string\n          title: Locale\n          description: Locale string that determines the language and geographic locale\n            that a synthetic person will be sampled from. E.g., en_US, en_GB, fr_FR,\n            ...\n          default: en_US\n        sex:\n          type: string\n          title: Sex\n          description: If specified, then only synthetic people of the specified sex\n            will be sampled.\n        city:\n          anyOf:\n          - type: string\n          - items:\n              type: string\n            type: array\n          title: City\n          description: If specified, then only synthetic people from these cities\n            will be sampled.\n        age_range:\n          items:\n            type: integer\n          type: array\n          maxItems: 2\n          minItems: 2\n          title: Age Range\n          description: If specified, then only synthetic people within this age range\n            will be sampled.\n          default:\n          - 18\n          - 114\n        sampler_type:\n          type: string\n          const: person_from_faker\n          title: Sampler Type\n          default: person_from_faker\n      additionalProperties: false\n      type: object\n      title: PersonFromFakerSamplerParams\n    PersonSamplerParams:\n      properties:\n        locale:\n          type: string\n          title: Locale\n          description: 'Locale that determines the language and geographic location\n            that a synthetic person will be sampled from. Must be a locale supported\n            by a managed Nemotron Personas dataset. 
Managed datasets exist for the\n            following locales: en_US, ja_JP, en_IN, hi_IN.'\n          default: en_US\n        sex:\n          type: string\n          title: Sex\n          description: If specified, then only synthetic people of the specified sex\n            will be sampled.\n        city:\n          anyOf:\n          - type: string\n          - items:\n              type: string\n            type: array\n          title: City\n          description: If specified, then only synthetic people from these cities\n            will be sampled.\n        age_range:\n          items:\n            type: integer\n          type: array\n          maxItems: 2\n          minItems: 2\n          title: Age Range\n          description: If specified, then only synthetic people within this age range\n            will be sampled.\n          default:\n          - 18\n          - 114\n        select_field_values:\n          type: object\n          additionalProperties:\n            items:\n              type: string\n            type: array\n          title: Select Field Values\n          description: Sample synthetic people with the specified field values. This\n            is meant to be a flexible argument for selecting a subset of the population\n            from the managed dataset. Note that this sampler does not support rare\n            combinations of field values and will likely fail if your desired subset\n            is not well-represented in the managed Nemotron Personas dataset. We generally\n            recommend using the `sex`, `city`, and `age_range` arguments to filter\n            the population when possible.\n          examples:\n          - education_level:\n            - high_school\n            - some_college\n            - bachelors\n            state:\n            - NY\n            - CA\n            - OH\n            - TX\n            - NV\n        with_synthetic_personas:\n          type: boolean\n          title: With Synthetic Personas\n          description: If True, then append synthetic persona columns to each generated\n            person.\n          default: false\n        sampler_type:\n          type: string\n          const: person\n          title: Sampler Type\n          default: person\n      additionalProperties: false\n      type: object\n      title: PersonSamplerParams\n      description: \"Parameters for sampling synthetic person data with demographic\\\n        \\ attributes.\\n\\nGenerates realistic synthetic person data including names,\\\n        \\ addresses, phone numbers, and other\\ndemographic information. Data can be\\\n        \\ sampled from managed datasets (when available) or generated\\nusing Faker.\\\n        \\ The sampler supports filtering by locale, sex, age, geographic location,\\\n        \\ and can\\noptionally include synthetic persona descriptions.\\n\\nAttributes:\\n\\\n        \\    locale: Locale string determining the language and geographic region\\\n        \\ for synthetic people.\\n        Format: language_COUNTRY (e.g., \\\"en_US\\\"\\\n        , \\\"en_GB\\\", \\\"fr_FR\\\", \\\"de_DE\\\", \\\"es_ES\\\", \\\"ja_JP\\\").\\n        Defaults\\\n        \\ to \\\"en_US\\\".\\n    sex: If specified, filters to only sample people of the\\\n        \\ specified sex. Options: \\\"Male\\\" or\\n        \\\"Female\\\". If None, samples\\\n        \\ both sexes.\\n    city: If specified, filters to only sample people from\\\n        \\ the specified city or cities. 
Can be\\n        a single city name (string)\\\n        \\ or a list of city names.\\n    age_range: Two-element list [min_age, max_age]\\\n        \\ specifying the age range to sample from\\n        (inclusive). Defaults to\\\n        \\ a standard age range. Both values must be between minimum and\\n        maximum\\\n        \\ allowed ages.\\n    with_synthetic_personas: If True, appends additional\\\n        \\ synthetic persona columns including\\n        personality traits, interests,\\\n        \\ and background descriptions. Only supported for certain\\n        locales\\\n        \\ with managed datasets.\\n    sample_dataset_when_available: If True, samples\\\n        \\ from curated managed datasets when available\\n        for the specified\\\n        \\ locale. If False or unavailable, falls back to Faker-generated data.\\n \\\n        \\       Managed datasets typically provide more realistic and diverse synthetic\\\n        \\ people.\"\n    PlatformJobListResultResponse:\n      properties:\n        object:\n          type: string\n          title: Object\n          description: The type of object being returned.\n          default: list\n        data:\n          items:\n            $ref: '#/components/schemas/PlatformJobResultResponse'\n          type: array\n          title: Data\n      type: object\n      required:\n      - data\n      title: PlatformJobListResultResponse\n    PlatformJobLog:\n      properties:\n        timestamp:\n          type: string\n          format: date-time\n          title: Timestamp\n        job_id:\n          type: string\n          title: Job Id\n        job_step:\n          type: string\n          title: Job Step\n        job_task:\n          type: string\n          title: Job Task\n        message:\n          type: string\n          title: Message\n      type: object\n      required:\n      - timestamp\n      - job_id\n      - job_step\n      - job_task\n      - message\n      title: PlatformJobLog\n    PlatformJobLogPage:\n      properties:\n        object:\n          type: string\n          title: Object\n          description: The type of object being returned.\n          default: list\n        data:\n          items:\n            $ref: '#/components/schemas/PlatformJobLog'\n          type: array\n          title: Data\n        total:\n          type: integer\n          title: Total\n        next_page:\n          type: string\n          title: Next Page\n        prev_page:\n          type: string\n          title: Prev Page\n      type: object\n      required:\n      - data\n      - total\n      - next_page\n      - prev_page\n      title: PlatformJobLogPage\n    PlatformJobResultResponse:\n      properties:\n        result_name:\n          type: string\n          title: Result Name\n        job_id:\n          type: string\n          title: Job Id\n        namespace:\n          type: string\n          title: Namespace\n        project:\n          type: string\n          title: Project\n        created_at:\n          type: string\n          format: date-time\n          title: Created At\n        updated_at:\n          type: string\n          format: date-time\n          title: Updated At\n        artifact_url:\n          type: string\n          title: Artifact Url\n        artifact_storage_type:\n          $ref: '#/components/schemas/FileStorageType'\n      type: object\n      required:\n      - result_name\n      - job_id\n      - namespace\n      - artifact_url\n      - artifact_storage_type\n      title: PlatformJobResultResponse\n    
PlatformJobStatus:\n      type: string\n      enum:\n      - created\n      - pending\n      - active\n      - cancelled\n      - cancelling\n      - error\n      - completed\n      - paused\n      - pausing\n      - resuming\n      title: PlatformJobStatus\n      description: 'Enumeration of possible job statuses.\n\n\n        This enum represents the various states a job can be in during its lifecycle,\n\n        from creation to a terminal state.'\n    PlatformJobStatusResponse:\n      properties:\n        job_id:\n          type: string\n          title: Job Id\n        status:\n          $ref: '#/components/schemas/PlatformJobStatus'\n        status_details:\n          additionalProperties: true\n          type: object\n          title: Status Details\n        error_details:\n          type: object\n          additionalProperties: true\n          title: Error Details\n        steps:\n          items:\n            $ref: '#/components/schemas/PlatformJobStepStatusResponse'\n          type: array\n          title: Steps\n      type: object\n      required:\n      - job_id\n      - status\n      - status_details\n      - error_details\n      - steps\n      title: PlatformJobStatusResponse\n    PlatformJobStepStatusResponse:\n      properties:\n        name:\n          type: string\n          title: Name\n        status:\n          $ref: '#/components/schemas/PlatformJobStatus'\n        status_details:\n          additionalProperties: true\n          type: object\n          title: Status Details\n        error_details:\n          type: object\n          additionalProperties: true\n          title: Error Details\n        tasks:\n          items:\n            $ref: '#/components/schemas/PlatformJobTaskStatusResponse'\n          type: array\n          title: Tasks\n      type: object\n      required:\n      - name\n      - status\n      - status_details\n      - error_details\n      - tasks\n      title: PlatformJobStepStatusResponse\n    PlatformJobTaskStatusResponse:\n      properties:\n        id:\n          type: string\n          title: Id\n        status:\n          $ref: '#/components/schemas/PlatformJobStatus'\n        status_details:\n          additionalProperties: true\n          type: object\n          title: Status Details\n        error_details:\n          type: object\n          additionalProperties: true\n          title: Error Details\n        error_stack:\n          type: string\n          title: Error Stack\n      type: object\n      required:\n      - id\n      - status\n      - status_details\n      - error_details\n      - error_stack\n      title: PlatformJobTaskStatusResponse\n    PoissonSamplerParams:\n      properties:\n        mean:\n          type: number\n          title: Mean\n          description: Mean number of events in a fixed interval.\n        sampler_type:\n          type: string\n          const: poisson\n          title: Sampler Type\n          default: poisson\n      additionalProperties: false\n      type: object\n      required:\n      - mean\n      title: PoissonSamplerParams\n      description: \"Parameters for sampling from a Poisson distribution.\\n\\nSamples\\\n        \\ non-negative integer values representing the number of events occurring\\\n        \\ in a fixed\\ninterval of time or space. 
The Poisson distribution is commonly\\\n        \\ used to model count data\\nlike the number of arrivals, occurrences, or events\\\n        \\ per time period.\\n\\nThe distribution is characterized by a single parameter\\\n        \\ (mean/rate), and both the mean and\\nvariance equal this parameter value.\\n\\\n        \\nAttributes:\\n    mean: Mean number of events in the fixed interval (also\\\n        \\ called rate parameter \\u03BB).\\n        Must be positive. This represents\\\n        \\ both the expected value and the variance of the\\n        distribution.\"\n    PreviewMessage:\n      properties:\n        message:\n          type: string\n          title: Message\n        message_type:\n          $ref: '#/components/schemas/MessageType'\n        extra:\n          type: object\n          additionalProperties:\n            type: string\n          title: Extra\n      additionalProperties: false\n      type: object\n      required:\n      - message\n      - message_type\n      title: PreviewMessage\n    PreviewRequest:\n      properties:\n        config:\n          $ref: '#/components/schemas/DataDesignerConfig'\n        num_records:\n          type: integer\n          title: Num Records\n      type: object\n      required:\n      - config\n      title: PreviewRequest\n    ProcessorConfig:\n      properties:\n        build_stage:\n          allOf:\n          - $ref: '#/components/schemas/BuildStage'\n          description: 'The stage at which the processor will run. Supported stages:\n            post_batch'\n      additionalProperties: false\n      type: object\n      required:\n      - build_stage\n      title: ProcessorConfig\n    RemoteValidatorParams:\n      properties:\n        endpoint_url:\n          type: string\n          title: Endpoint Url\n          description: URL of the remote endpoint\n        output_schema:\n          type: object\n          additionalProperties: true\n          title: Output Schema\n          description: Expected schema for remote validator's output\n        timeout:\n          type: number\n          exclusiveMinimum: 0.0\n          title: Timeout\n          description: The timeout for the HTTP request\n          default: 30.0\n        max_retries:\n          type: integer\n          minimum: 0.0\n          title: Max Retries\n          description: The maximum number of retry attempts\n          default: 3\n        retry_backoff:\n          type: number\n          exclusiveMinimum: 1.0\n          title: Retry Backoff\n          description: The backoff factor for the retry delay\n          default: 2.0\n        max_parallel_requests:\n          type: integer\n          minimum: 1.0\n          title: Max Parallel Requests\n          description: The maximum number of parallel requests to make\n          default: 4\n      additionalProperties: false\n      type: object\n      required:\n      - endpoint_url\n      title: RemoteValidatorParams\n      description: \"Configuration for remote validation. Sends data to a remote endpoint\\\n        \\ for validation.\\n\\nAttributes:\\n    endpoint_url: The URL of the remote\\\n        \\ endpoint.\\n    output_schema: The JSON schema for the remote validator's\\\n        \\ output. If not provided,\\n        the output will not be validated.\\n  \\\n        \\  timeout: The timeout for the HTTP request in seconds. Defaults to 30.0.\\n\\\n        \\    max_retries: The maximum number of retry attempts. 
Defaults to 3.\\n \\\n        \\   retry_backoff: The backoff factor for the retry delay in seconds. Defaults\\\n        \\ to 2.0.\\n    max_parallel_requests: The maximum number of parallel requests\\\n        \\ to make. Defaults to 4.\"\n    SamplerColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: sampler\n          title: Column Type\n          default: sampler\n        sampler_type:\n          $ref: '#/components/schemas/SamplerType'\n        params:\n          oneOf:\n          - $ref: '#/components/schemas/SubcategorySamplerParams'\n          - $ref: '#/components/schemas/CategorySamplerParams'\n          - $ref: '#/components/schemas/DatetimeSamplerParams'\n          - $ref: '#/components/schemas/PersonSamplerParams'\n          - $ref: '#/components/schemas/PersonFromFakerSamplerParams'\n          - $ref: '#/components/schemas/TimeDeltaSamplerParams'\n          - $ref: '#/components/schemas/UUIDSamplerParams'\n          - $ref: '#/components/schemas/BernoulliSamplerParams'\n          - $ref: '#/components/schemas/BernoulliMixtureSamplerParams'\n          - $ref: '#/components/schemas/BinomialSamplerParams'\n          - $ref: '#/components/schemas/GaussianSamplerParams'\n          - $ref: '#/components/schemas/PoissonSamplerParams'\n          - $ref: '#/components/schemas/UniformSamplerParams'\n          - $ref: '#/components/schemas/ScipySamplerParams'\n          title: Params\n          discriminator:\n            propertyName: sampler_type\n            mapping:\n              bernoulli: '#/components/schemas/BernoulliSamplerParams'\n              bernoulli_mixture: '#/components/schemas/BernoulliMixtureSamplerParams'\n              binomial: '#/components/schemas/BinomialSamplerParams'\n              category: '#/components/schemas/CategorySamplerParams'\n              datetime: '#/components/schemas/DatetimeSamplerParams'\n              gaussian: '#/components/schemas/GaussianSamplerParams'\n              person: '#/components/schemas/PersonSamplerParams'\n              person_from_faker: '#/components/schemas/PersonFromFakerSamplerParams'\n              poisson: '#/components/schemas/PoissonSamplerParams'\n              scipy: '#/components/schemas/ScipySamplerParams'\n              subcategory: '#/components/schemas/SubcategorySamplerParams'\n              timedelta: '#/components/schemas/TimeDeltaSamplerParams'\n              uniform: '#/components/schemas/UniformSamplerParams'\n              uuid: '#/components/schemas/UUIDSamplerParams'\n        conditional_params:\n          additionalProperties:\n            oneOf:\n            - $ref: '#/components/schemas/SubcategorySamplerParams'\n            - $ref: '#/components/schemas/CategorySamplerParams'\n            - $ref: '#/components/schemas/DatetimeSamplerParams'\n            - $ref: '#/components/schemas/PersonSamplerParams'\n            - $ref: '#/components/schemas/PersonFromFakerSamplerParams'\n            - $ref: '#/components/schemas/TimeDeltaSamplerParams'\n            - $ref: '#/components/schemas/UUIDSamplerParams'\n            - $ref: '#/components/schemas/BernoulliSamplerParams'\n            - $ref: '#/components/schemas/BernoulliMixtureSamplerParams'\n            - $ref: '#/components/schemas/BinomialSamplerParams'\n            - $ref: '#/components/schemas/GaussianSamplerParams'\n            - $ref: 
'#/components/schemas/PoissonSamplerParams'\n            - $ref: '#/components/schemas/UniformSamplerParams'\n            - $ref: '#/components/schemas/ScipySamplerParams'\n            discriminator:\n              propertyName: sampler_type\n              mapping:\n                bernoulli: '#/components/schemas/BernoulliSamplerParams'\n                bernoulli_mixture: '#/components/schemas/BernoulliMixtureSamplerParams'\n                binomial: '#/components/schemas/BinomialSamplerParams'\n                category: '#/components/schemas/CategorySamplerParams'\n                datetime: '#/components/schemas/DatetimeSamplerParams'\n                gaussian: '#/components/schemas/GaussianSamplerParams'\n                person: '#/components/schemas/PersonSamplerParams'\n                person_from_faker: '#/components/schemas/PersonFromFakerSamplerParams'\n                poisson: '#/components/schemas/PoissonSamplerParams'\n                scipy: '#/components/schemas/ScipySamplerParams'\n                subcategory: '#/components/schemas/SubcategorySamplerParams'\n                timedelta: '#/components/schemas/TimeDeltaSamplerParams'\n                uniform: '#/components/schemas/UniformSamplerParams'\n                uuid: '#/components/schemas/UUIDSamplerParams'\n          type: object\n          title: Conditional Params\n          default: {}\n        convert_to:\n          type: string\n          title: Convert To\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - sampler_type\n      - params\n      title: SamplerColumnConfig\n      description: \"Configuration for columns generated using numerical samplers.\\n\\\n        \\nSampler columns provide efficient data generation using numerical samplers\\\n        \\ for\\ncommon data types and distributions. Supported samplers include UUID\\\n        \\ generation,\\ndatetime/timedelta sampling, person generation, category /\\\n        \\ subcategory sampling,\\nand various statistical distributions (uniform, gaussian,\\\n        \\ binomial, poisson, scipy).\\n\\nAttributes:\\n    sampler_type: Type of sampler\\\n        \\ to use. Available types include:\\n        \\\"uuid\\\", \\\"category\\\", \\\"subcategory\\\"\\\n        , \\\"uniform\\\", \\\"gaussian\\\", \\\"bernoulli\\\",\\n        \\\"bernoulli_mixture\\\"\\\n        , \\\"binomial\\\", \\\"poisson\\\", \\\"scipy\\\", \\\"person\\\", \\\"datetime\\\", \\\"timedelta\\\"\\\n        .\\n    params: Parameters specific to the chosen sampler type. Type varies\\\n        \\ based on the `sampler_type`\\n        (e.g., `CategorySamplerParams`, `UniformSamplerParams`,\\\n        \\ `PersonSamplerParams`).\\n    conditional_params: Optional dictionary for\\\n        \\ conditional parameters. The dict keys\\n        are the conditions that must\\\n        \\ be met (e.g., \\\"age > 21\\\") for the conditional parameters\\n        to be\\\n        \\ used. The values of dict are the parameters to use when the condition is\\\n        \\ met.\\n    convert_to: Optional type conversion to apply after sampling.\\\n        \\ Must be one of \\\"float\\\", \\\"int\\\", or \\\"str\\\".\\n        Useful for converting\\\n        \\ numerical samples to strings or other types.\\n    column_type: Discriminator\\\n        \\ field, always \\\"sampler\\\" for this configuration type.\\n\\n!!! 
tip \\\"Displaying\\\n        \\ available samplers and their parameters\\\"\\n    The config builder has an\\\n        \\ `info` attribute that can be used to display the\\n    available samplers\\\n        \\ and their parameters:\\n    ```python\\n    config_builder.info.display(\\\"\\\n        samplers\\\")\\n    ```\"\n    SamplerType:\n      type: string\n      enum:\n      - bernoulli\n      - bernoulli_mixture\n      - binomial\n      - category\n      - datetime\n      - gaussian\n      - person\n      - person_from_faker\n      - poisson\n      - scipy\n      - subcategory\n      - timedelta\n      - uniform\n      - uuid\n      title: SamplerType\n    SamplingStrategy:\n      type: string\n      enum:\n      - ordered\n      - shuffle\n      title: SamplingStrategy\n    ScalarInequalityConstraint:\n      properties:\n        target_column:\n          type: string\n          title: Target Column\n        rhs:\n          type: number\n          title: Rhs\n        operator:\n          $ref: '#/components/schemas/InequalityOperator'\n      additionalProperties: false\n      type: object\n      required:\n      - target_column\n      - rhs\n      - operator\n      title: ScalarInequalityConstraint\n    ScipySamplerParams:\n      properties:\n        dist_name:\n          type: string\n          title: Dist Name\n          description: Name of a scipy.stats distribution.\n        dist_params:\n          additionalProperties: true\n          type: object\n          title: Dist Params\n          description: Parameters of the scipy.stats distribution given in `dist_name`.\n        decimal_places:\n          type: integer\n          title: Decimal Places\n          description: Number of decimal places to round the sampled values to.\n        sampler_type:\n          type: string\n          const: scipy\n          title: Sampler Type\n          default: scipy\n      additionalProperties: false\n      type: object\n      required:\n      - dist_name\n      - dist_params\n      title: ScipySamplerParams\n      description: \"Parameters for sampling from any scipy.stats continuous or discrete\\\n        \\ distribution.\\n\\nProvides a flexible interface to sample from the wide range\\\n        \\ of probability distributions\\navailable in scipy.stats. This enables advanced\\\n        \\ statistical sampling beyond the built-in\\ndistribution types (Gaussian,\\\n        \\ Uniform, etc.).\\n\\nSee: [scipy.stats documentation](https://docs.scipy.org/doc/scipy/reference/stats.html)\\n\\\n        \\nAttributes:\\n    dist_name: Name of the scipy.stats distribution to sample\\\n        \\ from (e.g., \\\"beta\\\", \\\"gamma\\\",\\n        \\\"lognorm\\\", \\\"expon\\\"). Must\\\n        \\ be a valid distribution name from scipy.stats.\\n    dist_params: Dictionary\\\n        \\ of parameters for the specified distribution. Parameter names\\n        and\\\n        \\ values must match the scipy.stats distribution specification (e.g., {\\\"\\\n        a\\\": 2, \\\"b\\\": 5}\\n        for beta distribution, {\\\"scale\\\": 1.5} for exponential).\\n\\\n        \\    decimal_places: Optional number of decimal places to round sampled values\\\n        \\ to. 
If None,\\n        values are not rounded.\"\n    Score:\n      properties:\n        name:\n          type: string\n          title: Name\n          description: A clear name for this score.\n        description:\n          type: string\n          title: Description\n          description: An informative and detailed assessment guide for using this\n            score.\n        options:\n          additionalProperties:\n            type: string\n          type: object\n          title: Options\n          description: 'Score options in the format of {score: description}.'\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - description\n      - options\n      title: Score\n      description: \"Configuration for a \\\"score\\\" in an LLM judge evaluation.\\n\\n\\\n        Defines a single scoring criterion with its possible values and descriptions.\\\n        \\ Multiple\\nScore objects can be combined in an LLMJudgeColumnConfig to create\\\n        \\ multi-dimensional\\nquality assessments.\\n\\nAttributes:\\n    name: A clear,\\\n        \\ concise name for this scoring dimension (e.g., \\\"Relevance\\\", \\\"Fluency\\\"\\\n        ).\\n    description: An informative and detailed assessment guide explaining\\\n        \\ how to evaluate\\n        this dimension. Should provide clear criteria for\\\n        \\ scoring.\\n    options: Dictionary mapping score values to their descriptions.\\\n        \\ Keys can be integers\\n        (e.g., 1-5 scale) or strings (e.g., \\\"Poor\\\"\\\n        , \\\"Good\\\", \\\"Excellent\\\"). Values are\\n        descriptions explaining what\\\n        \\ each score level means.\"\n    SeedConfig:\n      properties:\n        dataset:\n          type: string\n          title: Dataset\n        sampling_strategy:\n          allOf:\n          - $ref: '#/components/schemas/SamplingStrategy'\n          default: ordered\n        selection_strategy:\n          anyOf:\n          - $ref: '#/components/schemas/IndexRange'\n          - $ref: '#/components/schemas/PartitionBlock'\n          title: Selection Strategy\n      additionalProperties: false\n      type: object\n      required:\n      - dataset\n      title: SeedConfig\n      description: \"Configuration for sampling data from a seed dataset.\\n\\nArgs:\\n\\\n        \\    dataset: Path or identifier for the seed dataset.\\n    sampling_strategy:\\\n        \\ Strategy for how to sample rows from the dataset.\\n        - ORDERED: Read\\\n        \\ rows sequentially in their original order.\\n        - SHUFFLE: Randomly\\\n        \\ shuffle rows before sampling. 
When used with\\n          selection_strategy,\\\n        \\ shuffling occurs within the selected range/partition.\\n    selection_strategy:\\\n        \\ Optional strategy to select a subset of the dataset.\\n        - IndexRange:\\\n        \\ Select a specific range of indices (e.g., rows 100-200).\\n        - PartitionBlock:\\\n        \\ Select a partition by splitting the dataset into N equal parts.\\n      \\\n        \\    Partition indices are zero-based (index=0 is the first partition, index=1\\\n        \\ is\\n          the second, etc.).\\n\\nExamples:\\n    Read rows sequentially\\\n        \\ from start to end:\\n        SeedConfig(dataset=\\\"my_data.parquet\\\", sampling_strategy=SamplingStrategy.ORDERED)\\n\\\n        \\n    Read rows in random order:\\n        SeedConfig(dataset=\\\"my_data.parquet\\\"\\\n        , sampling_strategy=SamplingStrategy.SHUFFLE)\\n\\n    Read specific index range\\\n        \\ (rows 100-199):\\n        SeedConfig(\\n            dataset=\\\"my_data.parquet\\\"\\\n        ,\\n            sampling_strategy=SamplingStrategy.ORDERED,\\n            selection_strategy=IndexRange(start=100,\\\n        \\ end=199)\\n        )\\n\\n    Read random rows from a specific index range\\\n        \\ (shuffles within rows 100-199):\\n        SeedConfig(\\n            dataset=\\\"\\\n        my_data.parquet\\\",\\n            sampling_strategy=SamplingStrategy.SHUFFLE,\\n\\\n        \\            selection_strategy=IndexRange(start=100, end=199)\\n        )\\n\\\n        \\n    Read from partition 2 (3rd partition, zero-based) of 5 partitions (20%\\\n        \\ of dataset):\\n        SeedConfig(\\n            dataset=\\\"my_data.parquet\\\"\\\n        ,\\n            sampling_strategy=SamplingStrategy.ORDERED,\\n            selection_strategy=PartitionBlock(index=2,\\\n        \\ num_partitions=5)\\n        )\\n\\n    Read shuffled rows from partition 0\\\n        \\ of 10 partitions (shuffles within the partition):\\n        SeedConfig(\\n\\\n        \\            dataset=\\\"my_data.parquet\\\",\\n            sampling_strategy=SamplingStrategy.SHUFFLE,\\n\\\n        \\            selection_strategy=PartitionBlock(index=0, num_partitions=10)\\n\\\n        \\        )\"\n    SeedDatasetColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: seed-dataset\n          title: Column Type\n          default: seed-dataset\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      title: SeedDatasetColumnConfig\n      description: \"Configuration for columns sourced from seed datasets.\\n\\nThis\\\n        \\ config marks columns that come from seed data. 
It is typically created\\n\\\n        automatically when calling `with_seed_dataset()` on the builder, rather than\\n\\\n        being instantiated directly by users.\\n\\nAttributes:\\n    column_type: Discriminator\\\n        \\ field, always \\\"seed-dataset\\\" for this configuration type.\"\n    SettingsDefaults:\n      properties:\n        model_configs:\n          items:\n            $ref: '#/components/schemas/ModelConfigOutput'\n          type: array\n          title: Model Configs\n        model_provider:\n          type: string\n          title: Model Provider\n      type: object\n      required:\n      - model_configs\n      - model_provider\n      title: SettingsDefaults\n    SettingsResponse:\n      properties:\n        defaults:\n          $ref: '#/components/schemas/SettingsDefaults'\n        model_providers:\n          items:\n            $ref: '#/components/schemas/DisplayModelProvider'\n          type: array\n          title: Model Providers\n      type: object\n      required:\n      - defaults\n      - model_providers\n      title: SettingsResponse\n    SubcategorySamplerParams:\n      properties:\n        category:\n          type: string\n          title: Category\n          description: Name of parent category to this subcategory.\n        values:\n          additionalProperties:\n            items:\n              anyOf:\n              - type: string\n              - type: integer\n              - type: number\n            type: array\n          type: object\n          title: Values\n          description: Mapping from each value of parent category to a list of subcategory\n            values.\n        sampler_type:\n          type: string\n          const: subcategory\n          title: Sampler Type\n          default: subcategory\n      additionalProperties: false\n      type: object\n      required:\n      - category\n      - values\n      title: SubcategorySamplerParams\n      description: \"Parameters for subcategory sampling conditioned on a parent category\\\n        \\ column.\\n\\nSamples subcategory values based on the value of a parent category\\\n        \\ column. 
Each parent\\ncategory value maps to its own list of possible subcategory\\\n        \\ values, enabling hierarchical\\nor conditional sampling patterns.\\n\\nAttributes:\\n\\\n        \\    category: Name of the parent category column that this subcategory depends\\\n        \\ on.\\n        The parent column must be generated before this subcategory\\\n        \\ column.\\n    values: Mapping from each parent category value to a list of\\\n        \\ possible subcategory values.\\n        Each key must correspond to a value\\\n        \\ that appears in the parent category column.\"\n    TimeDeltaSamplerParams:\n      properties:\n        dt_min:\n          type: integer\n          minimum: 0.0\n          title: Dt Min\n          description: Minimum possible time-delta for sampling range, inclusive.\n            Must be less than `dt_max`.\n        dt_max:\n          type: integer\n          exclusiveMinimum: 0.0\n          title: Dt Max\n          description: Maximum possible time-delta for sampling range, exclusive.\n            Must be greater than `dt_min`.\n        reference_column_name:\n          type: string\n          title: Reference Column Name\n          description: Name of an existing datetime column to condition time-delta\n            sampling on.\n        unit:\n          type: string\n          enum:\n          - D\n          - h\n          - m\n          - s\n          title: Unit\n          description: Sampling units, e.g. the smallest possible time interval between\n            samples.\n          default: D\n        sampler_type:\n          type: string\n          const: timedelta\n          title: Sampler Type\n          default: timedelta\n      additionalProperties: false\n      type: object\n      required:\n      - dt_min\n      - dt_max\n      - reference_column_name\n      title: TimeDeltaSamplerParams\n      description: \"Parameters for sampling time deltas relative to a reference datetime\\\n        \\ column.\\n\\nSamples time offsets within a specified range and adds them to\\\n        \\ values from a reference\\ndatetime column. This is useful for generating\\\n        \\ related datetime columns like order dates\\nand delivery dates, or event\\\n        \\ start times and end times.\\n\\nNote:\\n    Years and months are not supported\\\n        \\ as timedelta units because they have variable lengths.\\n    See: [pandas\\\n        \\ timedelta documentation](https://pandas.pydata.org/docs/user_guide/timedeltas.html)\\n\\\n        \\nAttributes:\\n    dt_min: Minimum time-delta value (inclusive). Must be non-negative\\\n        \\ and less than `dt_max`.\\n        Specified in units defined by the `unit`\\\n        \\ parameter.\\n    dt_max: Maximum time-delta value (exclusive). Must be positive\\\n        \\ and greater than `dt_min`.\\n        Specified in units defined by the `unit`\\\n        \\ parameter.\\n    reference_column_name: Name of an existing datetime column\\\n        \\ to add the time-delta to.\\n        This column must be generated before\\\n        \\ the timedelta column.\\n    unit: Time unit for the delta values. 
Options:\\n\\\n        \\        - \\\"D\\\": Days (default)\\n        - \\\"h\\\": Hours\\n        - \\\"m\\\"\\\n        : Minutes\\n        - \\\"s\\\": Seconds\"\n    UUIDSamplerParams:\n      properties:\n        prefix:\n          type: string\n          title: Prefix\n          description: String prepended to the front of the UUID.\n        short_form:\n          type: boolean\n          title: Short Form\n          description: If true, all UUIDs sampled will be truncated at 8 characters.\n          default: false\n        uppercase:\n          type: boolean\n          title: Uppercase\n          description: If true, all letters in the UUID will be capitalized.\n          default: false\n        sampler_type:\n          type: string\n          const: uuid\n          title: Sampler Type\n          default: uuid\n      additionalProperties: false\n      type: object\n      title: UUIDSamplerParams\n      description: \"Parameters for generating UUID (Universally Unique Identifier)\\\n        \\ values.\\n\\nGenerates UUID4 (random) identifiers with optional formatting\\\n        \\ options. UUIDs are useful\\nfor creating unique identifiers for records,\\\n        \\ entities, or transactions.\\n\\nAttributes:\\n    prefix: Optional string to\\\n        \\ prepend to each UUID. Useful for creating namespaced or\\n        typed identifiers\\\n        \\ (e.g., \\\"user-\\\", \\\"order-\\\", \\\"txn-\\\").\\n    short_form: If True, truncates\\\n        \\ UUIDs to 8 characters (first segment only). Default is False\\n        for\\\n        \\ full 32-character UUIDs (excluding hyphens).\\n    uppercase: If True, converts\\\n        \\ all hexadecimal letters to uppercase. Default is False for\\n        lowercase\\\n        \\ UUIDs.\"\n    UniformDistribution:\n      properties:\n        distribution_type:\n          allOf:\n          - $ref: '#/components/schemas/DistributionType'\n          default: uniform\n        params:\n          $ref: '#/components/schemas/UniformDistributionParams'\n      additionalProperties: false\n      type: object\n      required:\n      - params\n      title: UniformDistribution\n    UniformDistributionParams:\n      properties:\n        low:\n          type: number\n          title: Low\n        high:\n          type: number\n          title: High\n      additionalProperties: false\n      type: object\n      required:\n      - low\n      - high\n      title: UniformDistributionParams\n    UniformSamplerParams:\n      properties:\n        low:\n          type: number\n          title: Low\n          description: Lower bound of the uniform distribution, inclusive.\n        high:\n          type: number\n          title: High\n          description: Upper bound of the uniform distribution, inclusive.\n        decimal_places:\n          type: integer\n          title: Decimal Places\n          description: Number of decimal places to round the sampled values to.\n        sampler_type:\n          type: string\n          const: uniform\n          title: Sampler Type\n          default: uniform\n      additionalProperties: false\n      type: object\n      required:\n      - low\n      - high\n      title: UniformSamplerParams\n      description: \"Parameters for sampling from a continuous Uniform distribution.\\n\\\n        \\nSamples continuous values uniformly from a specified range, where every\\\n        \\ value in the range\\nhas equal probability of being sampled. 
This is useful\\\n        \\ when all values within a range are\\nequally likely, such as random percentages,\\\n        \\ proportions, or unbiased measurements.\\n\\nAttributes:\\n    low: Lower bound\\\n        \\ of the uniform distribution (inclusive). Can be any real number.\\n    high:\\\n        \\ Upper bound of the uniform distribution (inclusive). Must be greater than\\\n        \\ `low`.\\n    decimal_places: Optional number of decimal places to round sampled\\\n        \\ values to. If None,\\n        values are not rounded and may have many decimal\\\n        \\ places.\"\n    ValidationColumnConfig:\n      properties:\n        name:\n          type: string\n          title: Name\n        drop:\n          type: boolean\n          title: Drop\n          default: false\n        column_type:\n          type: string\n          const: validation\n          title: Column Type\n          default: validation\n        target_columns:\n          items:\n            type: string\n          type: array\n          title: Target Columns\n        validator_type:\n          $ref: '#/components/schemas/ValidatorType'\n        validator_params:\n          anyOf:\n          - $ref: '#/components/schemas/CodeValidatorParams'\n          - $ref: '#/components/schemas/LocalCallableValidatorParams'\n          - $ref: '#/components/schemas/RemoteValidatorParams'\n          title: Validator Params\n        batch_size:\n          type: integer\n          minimum: 1.0\n          title: Batch Size\n          description: Number of records to process in each batch\n          default: 10\n      additionalProperties: false\n      type: object\n      required:\n      - name\n      - target_columns\n      - validator_type\n      - validator_params\n      title: ValidationColumnConfig\n      description: \"Configuration for validation columns that validate existing columns.\\n\\\n        \\nValidation columns execute validation logic against specified target columns\\\n        \\ and return\\nstructured results indicating pass/fail status with validation\\\n        \\ details. Supports multiple\\nvalidation strategies: code execution (Python/SQL),\\\n        \\ local callable functions (library only),\\nand remote HTTP endpoints.\\n\\n\\\n        Attributes:\\n    target_columns: List of column names to validate. These columns\\\n        \\ are passed to the\\n        validator for validation. All target columns\\\n        \\ must exist in the dataset\\n        before validation runs.\\n    validator_type:\\\n        \\ The type of validator to use. Options:\\n        - \\\"code\\\": Execute code\\\n        \\ (Python or SQL) for validation. The code receives a\\n          DataFrame\\\n        \\ with target columns and must return a DataFrame with validation results.\\n\\\n        \\        - \\\"local_callable\\\": Call a local Python function with the data.\\\n        \\ Only supported\\n          when running DataDesigner locally.\\n        -\\\n        \\ \\\"remote\\\": Send data to a remote HTTP endpoint for validation. Useful for\\n\\\n        \\    validator_params: Parameters specific to the validator type. 
Type varies\\\n        \\ by validator:\\n        - CodeValidatorParams: Specifies code language (python\\\n        \\ or SQL dialect like\\n          \\\"sql:postgres\\\", \\\"sql:mysql\\\").\\n     \\\n        \\   - LocalCallableValidatorParams: Provides validation function (Callable[[pd.DataFrame],\\n\\\n        \\          pd.DataFrame]) and optional output schema for validation results.\\n\\\n        \\        - RemoteValidatorParams: Configures endpoint URL, HTTP timeout, retry\\\n        \\ behavior\\n          (max_retries, retry_backoff), and parallel request limits\\\n        \\ (max_parallel_requests).\\n    batch_size: Number of records to process in\\\n        \\ each validation batch. Defaults to 10.\\n        Larger batches are more\\\n        \\ efficient but use more memory. Adjust based on validator\\n        complexity\\\n        \\ and available resources.\\n    column_type: Discriminator field, always \\\"\\\n        validation\\\" for this configuration type.\"\n    ValidationError:\n      properties:\n        loc:\n          items:\n            anyOf:\n            - type: string\n            - type: integer\n          type: array\n          title: Location\n        msg:\n          type: string\n          title: Message\n        type:\n          type: string\n          title: Error Type\n      type: object\n      required:\n      - loc\n      - msg\n      - type\n      title: ValidationError\n    ValidatorType:\n      type: string\n      enum:\n      - code\n      - local_callable\n      - remote\n      title: ValidatorType\ntags:\n- name: Data Designer\n  description: Operations related to synthetic data generation.\n- name: Health Checks\n  description: Operations related to NeMo Microservices platform health.\n"
  },
  {
    "path": "studio/frontend/eslint.config.js",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport js from \"@eslint/js\";\nimport reactHooks from \"eslint-plugin-react-hooks\";\nimport reactRefresh from \"eslint-plugin-react-refresh\";\nimport { defineConfig, globalIgnores } from \"eslint/config\";\nimport globals from \"globals\";\nimport tseslint from \"typescript-eslint\";\n\nexport default defineConfig([\n  globalIgnores([\"dist\", \"**/._*\"]),\n  {\n    files: [\"**/*.{ts,tsx}\"],\n    extends: [\n      js.configs.recommended,\n      tseslint.configs.recommended,\n      reactHooks.configs.flat.recommended,\n      reactRefresh.configs.vite,\n    ],\n    languageOptions: {\n      ecmaVersion: 2020,\n      globals: globals.browser,\n    },\n    rules: {\n      // Allow shadcn ui components to export variants\n      \"react-refresh/only-export-components\": [\n        \"warn\",\n        { allowConstantExport: true },\n      ],\n      // Import restrictions for architecture enforcement\n      \"no-restricted-imports\": [\n        \"error\",\n        {\n          patterns: [\n            // Prevent cross-feature imports\n            {\n              group: [\"@/features/*/*\"],\n              message: \"Import from feature index only: @/features/[name]\",\n            },\n            // Prevent app layer from importing features internals\n            {\n              group: [\"../features/*/**\"],\n              message: \"Use absolute imports: @/features/[name]\",\n            },\n          ],\n        },\n      ],\n    },\n  },\n]);\n"
  },
  {
    "path": "studio/frontend/index.html",
    "content": "<!doctype html>\r\n<!-- SPDX-License-Identifier: AGPL-3.0-only -->\n<!-- Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 -->\n\n<html lang=\"en\">\r\n  <head>\r\n    <meta charset=\"UTF-8\" />\r\n    <link rel=\"icon\" type=\"image/png\" href=\"/favicon.png\" />\r\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\" />\r\n    <title>Unsloth Studio</title>\r\n  </head>\r\n  <body>\r\n    <div id=\"root\"></div>\r\n    <script type=\"module\" src=\"/src/main.tsx\"></script>\r\n  </body>\r\n</html>\r\n"
  },
  {
    "path": "studio/frontend/package.json",
    "content": "{\n  \"name\": \"unsloth-theme\",\n  \"private\": true,\n  \"version\": \"0.0.0\",\n  \"type\": \"module\",\n  \"scripts\": {\n    \"dev\": \"vite\",\n    \"build\": \"tsc -b && vite build\",\n    \"lint\": \"eslint .\",\n    \"preview\": \"vite preview\",\n    \"typecheck\": \"tsc -b --pretty false\",\n    \"biome:check\": \"biome check .\",\n    \"biome:fix\": \"biome check . --write\"\n  },\n  \"dependencies\": {\n    \"@assistant-ui/react\": \"^0.12.19\",\n    \"@assistant-ui/react-markdown\": \"^0.12.3\",\n    \"@assistant-ui/react-streamdown\": \"^0.1.2\",\n    \"@base-ui/react\": \"^1.2.0\",\n    \"@dagrejs/dagre\": \"^2.0.4\",\n    \"@dagrejs/graphlib\": \"^3.0.4\",\n    \"@fontsource-variable/figtree\": \"^5.2.10\",\n    \"@fontsource-variable/inter\": \"^5.2.8\",\n    \"@fontsource-variable/space-grotesk\": \"^5.2.10\",\n    \"@hugeicons/core-free-icons\": \"^3.1.1\",\n    \"@hugeicons/react\": \"^1.1.5\",\n    \"@huggingface/hub\": \"^2.9.0\",\n    \"@langchain/core\": \"^1.1.27\",\n    \"@radix-ui/react-checkbox\": \"^1.3.3\",\n    \"@radix-ui/react-label\": \"^2.1.8\",\n    \"@radix-ui/react-select\": \"^2.2.6\",\n    \"@radix-ui/react-separator\": \"^1.1.8\",\n    \"@radix-ui/react-slot\": \"^1.2.4\",\n    \"@streamdown/cjk\": \"1.0.2\",\n    \"@streamdown/code\": \"1.0.2\",\n    \"@streamdown/math\": \"1.0.2\",\n    \"@streamdown/mermaid\": \"1.0.2\",\n    \"@tailwindcss/vite\": \"^4.1.18\",\n    \"@tanstack/react-router\": \"^1.159.10\",\n    \"@tanstack/react-table\": \"^8.21.3\",\n    \"@toolwind/corner-shape\": \"^0.0.8-3\",\n    \"@types/canvas-confetti\": \"^1.9.0\",\n    \"@xyflow/react\": \"^12.10.0\",\n    \"assistant-stream\": \"^0.3.2\",\n    \"canvas-confetti\": \"^1.9.4\",\n    \"class-variance-authority\": \"^0.7.1\",\n    \"clsx\": \"^2.1.1\",\n    \"cmdk\": \"^1.1.1\",\n    \"date-fns\": \"^4.1.0\",\n    \"dexie\": \"^4.3.0\",\n    \"framer-motion\": \"^11.18.2\",\n    \"js-yaml\": \"^4.1.1\",\n    \"katex\": \"^0.16.28\",\n    \"lucide-react\": \"^0.577.0\",\n    \"mammoth\": \"^1.11.0\",\n    \"motion\": \"^12.34.0\",\n    \"next\": \"^16.1.6\",\n    \"next-themes\": \"^0.4.6\",\n    \"radix-ui\": \"^1.4.3\",\n    \"react\": \"^19.2.4\",\n    \"react-day-picker\": \"^9.13.2\",\n    \"react-dom\": \"^19.2.4\",\n    \"react-resizable-panels\": \"^4.6.4\",\n    \"recharts\": \"3.7.0\",\n    \"remark-gfm\": \"^4.0.1\",\n    \"shadcn\": \"^3.8.4\",\n    \"sonner\": \"^2.0.7\",\n    \"streamdown\": \"2.3.0\",\n    \"tailwind-merge\": \"^3.4.0\",\n    \"tailwindcss\": \"^4.1.18\",\n    \"tw-animate-css\": \"^1.4.0\",\n    \"tw-shimmer\": \"^0.4.6\",\n    \"unpdf\": \"^1.4.0\",\n    \"zustand\": \"^5.0.11\"\n  },\n  \"devDependencies\": {\n    \"@biomejs/biome\": \"^1.9.4\",\n    \"@eslint/js\": \"^9.39.1\",\n    \"@types/js-yaml\": \"^4.0.9\",\n    \"@types/node\": \"^24.10.1\",\n    \"@types/react\": \"^19.2.5\",\n    \"@types/react-dom\": \"^19.2.3\",\n    \"@vitejs/plugin-react\": \"^5.1.1\",\n    \"eslint\": \"^9.39.1\",\n    \"eslint-plugin-react-hooks\": \"^7.0.1\",\n    \"eslint-plugin-react-refresh\": \"^0.4.26\",\n    \"globals\": \"^16.5.0\",\n    \"typescript\": \"~5.9.3\",\n    \"typescript-eslint\": \"^8.55.0\",\n    \"vite\": \"^7.3.1\"\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/app/app.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { RouterProvider } from \"@tanstack/react-router\";\nimport { router } from \"./router\";\n\nexport function App() {\n  return <RouterProvider router={router} />;\n}\n"
  },
  {
    "path": "studio/frontend/src/app/auth-guards.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { redirect } from \"@tanstack/react-router\";\nimport {\n  getPostAuthRoute,\n  hasAuthToken,\n  hasRefreshToken,\n  mustChangePassword,\n  refreshSession,\n} from \"@/features/auth\";\n\nasync function hasActiveSession(): Promise<boolean> {\n  if (hasAuthToken()) return true;\n  if (!hasRefreshToken()) return false;\n  return refreshSession();\n}\n\nasync function checkAuthInitialized(): Promise<boolean> {\n  try {\n    const res = await fetch(\"/api/auth/status\");\n    if (!res.ok) return true; // fallback to login on error\n    const data = (await res.json()) as { initialized: boolean };\n    return data.initialized;\n  } catch {\n    return true; // fallback to login on error\n  }\n}\n\nasync function checkPasswordChangeRequired(): Promise<boolean> {\n  try {\n    const res = await fetch(\"/api/auth/status\");\n    if (!res.ok) return mustChangePassword();\n    const data = (await res.json()) as { requires_password_change: boolean };\n    return data.requires_password_change || mustChangePassword();\n  } catch {\n    return mustChangePassword();\n  }\n}\n\nexport async function requireAuth(): Promise<void> {\n  if (await hasActiveSession()) {\n    if (await checkPasswordChangeRequired()) {\n      throw redirect({ to: \"/change-password\" });\n    }\n    return;\n  }\n  const requiresPasswordChange = await checkPasswordChangeRequired();\n  if (requiresPasswordChange) throw redirect({ to: \"/change-password\" });\n  const initialized = await checkAuthInitialized();\n  throw redirect({ to: initialized ? \"/login\" : \"/change-password\" });\n}\n\nexport async function requireGuest(): Promise<void> {\n  if (!(await hasActiveSession())) return;\n  throw redirect({ to: getPostAuthRoute() });\n}\n\nexport async function requirePasswordChangeFlow(): Promise<void> {\n  const requiresPasswordChange = await checkPasswordChangeRequired();\n\n  if (requiresPasswordChange) return;\n\n  if (await hasActiveSession()) {\n    throw redirect({ to: getPostAuthRoute() });\n  }\n\n  const initialized = await checkAuthInitialized();\n  throw redirect({ to: initialized ? \"/login\" : \"/change-password\" });\n}\n"
  },
  {
    "path": "studio/frontend/src/app/provider.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Toaster } from \"@/components/ui/sonner\";\nimport { ThemeProvider } from \"next-themes\";\nimport type { ReactNode } from \"react\";\n\ninterface AppProviderProps {\n  children: ReactNode;\n}\n\nexport function AppProvider({ children }: AppProviderProps) {\n  return (\n    <ThemeProvider attribute=\"class\" defaultTheme=\"light\">\n      {children}\n      <Toaster position=\"top-right\" visibleToasts={2} expand={true} />\n    </ThemeProvider>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/app/router.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRouter } from \"@tanstack/react-router\";\nimport { Route as rootRoute } from \"./routes/__root\";\nimport { Route as dataRecipesRoute } from \"./routes/data-recipes\";\nimport { Route as dataRecipeRoute } from \"./routes/data-recipes.$recipeId\";\nimport { Route as chatRoute } from \"./routes/chat\";\nimport { Route as exportRoute } from \"./routes/export\";\nimport { Route as gridTestRoute } from \"./routes/grid-test\";\nimport { Route as indexRoute } from \"./routes/index\";\nimport { Route as loginRoute } from \"./routes/login\";\nimport { Route as onboardingRoute } from \"./routes/onboarding\";\nimport { Route as changePasswordRoute } from \"./routes/change-password\";\nimport { Route as studioRoute } from \"./routes/studio\";\n\nconst routeTree = rootRoute.addChildren([\n  indexRoute,\n  onboardingRoute,\n  loginRoute,\n  changePasswordRoute,\n  gridTestRoute,\n  studioRoute,\n  chatRoute,\n  exportRoute,\n  dataRecipesRoute,\n  dataRecipeRoute,\n]);\n\nexport const router = createRouter({ routeTree });\n\ndeclare module \"@tanstack/react-router\" {\n  interface Register {\n    router: typeof router;\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/app/routes/__root.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Navbar } from \"@/components/navbar\";\nimport { usePlatformStore } from \"@/config/env\";\nimport {\n  Outlet,\n  createRootRoute,\n  redirect,\n  useRouterState,\n} from \"@tanstack/react-router\";\nimport { AnimatePresence, motion } from \"motion/react\";\nimport { Suspense } from \"react\";\nimport { AppProvider } from \"../provider\";\n\nconst CHAT_ONLY_ALLOWED = new Set([\"/\", \"/chat\", \"/login\", \"/signup\", \"/change-password\"]);\n\nfunction isChatOnlyAllowed(pathname: string): boolean {\n  if (CHAT_ONLY_ALLOWED.has(pathname)) return true;\n  if (pathname === \"/data-recipes\" || pathname.startsWith(\"/data-recipes/\")) return true;\n  return false;\n}\n\nexport const Route = createRootRoute({\n  beforeLoad: ({ location }) => {\n    const chatOnly = usePlatformStore.getState().isChatOnly();\n    if (chatOnly && !isChatOnlyAllowed(location.pathname)) {\n      throw redirect({ to: \"/chat\" });\n    }\n  },\n  component: RootLayout,\n});\n\nconst HIDDEN_NAVBAR_ROUTES = [\"/onboarding\", \"/login\", \"/change-password\"];\n\nfunction RootLayout() {\n  const pathname = useRouterState({ select: (s) => s.location.pathname });\n  const hideNavbar = HIDDEN_NAVBAR_ROUTES.includes(pathname);\n\n  return (\n    <AppProvider>\n      {!hideNavbar && <Navbar />}\n      <AnimatePresence initial={false}>\n        <motion.div\n          key={pathname}\n          initial={{ opacity: 0 }}\n          animate={{ opacity: 1 }}\n          transition={{ duration: 0.15 }}\n          className=\"flex-1\"\n        >\n          <Suspense fallback={null}>\n            <Outlet />\n          </Suspense>\n        </motion.div>\n      </AnimatePresence>\n    </AppProvider>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/app/routes/change-password.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requirePasswordChangeFlow } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst ChangePasswordPage = lazy(() =>\n  import(\"@/features/auth\").then((m) => ({\n    default: m.ChangePasswordPage,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/change-password\",\n  beforeLoad: () => requirePasswordChangeFlow(),\n  component: ChangePasswordPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/chat.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst ChatPage = lazy(() =>\n  import(\"@/features/chat/chat-page\").then((m) => ({ default: m.ChatPage })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/chat\",\n  beforeLoad: () => requireAuth(),\n  component: ChatPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/data-recipes.$recipeId.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport type { ReactElement } from \"react\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst EditRecipePage = lazy(() =>\n  import(\"@/features/data-recipes\").then((m) => ({\n    default: m.EditRecipePage,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/data-recipes/$recipeId\",\n  beforeLoad: () => requireAuth(),\n  component: DataRecipeEditorRoute,\n});\n\nfunction DataRecipeEditorRoute(): ReactElement {\n  const { recipeId } = Route.useParams();\n  return <EditRecipePage recipeId={recipeId} />;\n}\n"
  },
  {
    "path": "studio/frontend/src/app/routes/data-recipes.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst DataRecipesPage = lazy(() =>\n  import(\"@/features/data-recipes\").then((m) => ({\n    default: m.DataRecipesPage,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/data-recipes\",\n  beforeLoad: () => requireAuth(),\n  component: DataRecipesPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/export.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst ExportPage = lazy(() =>\n  import(\"@/features/export/export-page\").then((m) => ({\n    default: m.ExportPage,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/export\",\n  beforeLoad: () => requireAuth(),\n  component: ExportPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/grid-test.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { DashboardGrid, DashboardLayout } from \"@/components/layout\";\nimport {\n  Card,\n  CardContent,\n  CardDescription,\n  CardHeader,\n  CardTitle,\n} from \"@/components/ui/card\";\nimport { createRoute } from \"@tanstack/react-router\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/grid-test\",\n  beforeLoad: () => requireAuth(),\n  component: GridTestPage,\n});\n\nfunction GridTestPage() {\n  return (\n    <DashboardLayout>\n      <div className=\"space-y-8\">\n        <div>\n          <h1 className=\"text-2xl font-semibold\">Grid Test - 3 Columns</h1>\n          <p className=\"text-muted-foreground\">\n            max-w-7xl, gap-6, responsive 1→2→3\n          </p>\n        </div>\n\n        <DashboardGrid cols={3}>\n          {[1, 2, 3].map((i) => (\n            <Card key={i}>\n              <CardHeader>\n                <CardTitle>Card {i}</CardTitle>\n                <CardDescription>~400px at 1280px viewport</CardDescription>\n              </CardHeader>\n              <CardContent>\n                <div className=\"h-24 rounded-lg bg-muted\" />\n              </CardContent>\n            </Card>\n          ))}\n        </DashboardGrid>\n\n        <div>\n          <h2 className=\"text-xl font-semibold\">4 Columns</h2>\n          <p className=\"text-muted-foreground\">~296px per card at 1280px</p>\n        </div>\n\n        <DashboardGrid cols={4}>\n          {[1, 2, 3, 4].map((i) => (\n            <Card key={i} size=\"sm\">\n              <CardHeader>\n                <CardTitle>Card {i}</CardTitle>\n                <CardDescription>Smaller cards</CardDescription>\n              </CardHeader>\n              <CardContent>\n                <div className=\"h-16 rounded-lg bg-muted\" />\n              </CardContent>\n            </Card>\n          ))}\n        </DashboardGrid>\n      </div>\n    </DashboardLayout>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/app/routes/index.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute, redirect } from \"@tanstack/react-router\";\nimport { getPostAuthRoute } from \"@/features/auth\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/\",\n  beforeLoad: async () => {\n    await requireAuth();\n    throw redirect({ to: getPostAuthRoute() });\n  },\n  component: () => null,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/login.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireGuest } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst LoginPage = lazy(() =>\n  import(\"@/features/auth\").then((m) => ({ default: m.LoginPage })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/login\",\n  beforeLoad: () => requireGuest(),\n  component: LoginPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/onboarding.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst WizardLayout = lazy(() =>\n  import(\"@/features/onboarding/components/wizard-layout\").then((m) => ({\n    default: m.WizardLayout,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/onboarding\",\n  beforeLoad: () => requireAuth(),\n  component: WizardLayout,\n});\n"
  },
  {
    "path": "studio/frontend/src/app/routes/studio.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createRoute } from \"@tanstack/react-router\";\nimport { lazy } from \"react\";\nimport { requireAuth } from \"../auth-guards\";\nimport { Route as rootRoute } from \"./__root\";\n\nconst StudioPage = lazy(() =>\n  import(\"@/features/studio/studio-page\").then((m) => ({\n    default: m.StudioPage,\n  })),\n);\n\nexport const Route = createRoute({\n  getParentRoute: () => rootRoute,\n  path: \"/studio\",\n  beforeLoad: () => requireAuth(),\n  component: StudioPage,\n});\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/attachment.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\n// Avatar removed — caused circular crop on image thumbnails\nimport { TooltipIconButton } from \"@/components/assistant-ui/tooltip-icon-button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogTitle,\n  DialogTrigger,\n} from \"@/components/ui/dialog\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  AttachmentPrimitive,\n  ComposerPrimitive,\n  MessagePrimitive,\n  useAui,\n  useAuiState,\n} from \"@assistant-ui/react\";\nimport { FileText, PlusIcon, XIcon } from \"lucide-react\";\nimport {\n  type FC,\n  type PropsWithChildren,\n  useEffect,\n  useState,\n} from \"react\";\nimport { useShallow } from \"zustand/shallow\";\n\nconst useFileSrc = (file: File | undefined): string | undefined => {\n  const [objectUrl, setObjectUrl] = useState<string | undefined>(undefined);\n\n  useEffect(() => {\n    if (!file) {\n      setObjectUrl(undefined);\n      return;\n    }\n    const url = URL.createObjectURL(file);\n    setObjectUrl(url);\n    return () => URL.revokeObjectURL(url);\n  }, [file]);\n\n  return objectUrl;\n};\n\nconst useAttachmentSrc = (): string | undefined => {\n  const { file, src } = useAuiState(\n    useShallow(({ attachment }): { file?: File; src?: string } => {\n      if (attachment.type !== \"image\") {\n        return {};\n      }\n      if (attachment.file) {\n        return { file: attachment.file };\n      }\n      const src = attachment.content?.filter((c) => c.type === \"image\")[0]\n        ?.image;\n      if (!src) {\n        return {};\n      }\n      return { src };\n    }),\n  );\n\n  return useFileSrc(file) ?? src;\n};\n\ntype AttachmentPreviewProps = {\n  src: string;\n};\n\nconst AttachmentPreview: FC<AttachmentPreviewProps> = ({ src }) => {\n  const [isLoaded, setIsLoaded] = useState(false);\n  return (\n    <img\n      src={src}\n      alt=\"Preview\"\n      className={cn(\n        \"block h-auto max-h-[80vh] w-auto max-w-full object-contain\",\n        isLoaded\n          ? \"aui-attachment-preview-image-loaded\"\n          : \"aui-attachment-preview-image-loading invisible\",\n      )}\n      onLoad={() => setIsLoaded(true)}\n    />\n  );\n};\n\nconst AttachmentPreviewDialog: FC<PropsWithChildren> = ({ children }) => {\n  const src = useAttachmentSrc();\n\n  if (!src) {\n    return children;\n  }\n\n  return (\n    <Dialog>\n      <DialogTrigger\n        className=\"aui-attachment-preview-trigger cursor-pointer transition-colors hover:bg-accent/50\"\n        asChild={true}\n      >\n        {children}\n      </DialogTrigger>\n      <DialogContent className=\"aui-attachment-preview-dialog-content p-2 sm:max-w-3xl [&>button]:rounded-full [&>button]:bg-foreground/60 [&>button]:p-1 [&>button]:opacity-100 [&>button]:ring-0! 
[&_svg]:text-background [&>button]:hover:[&_svg]:text-destructive\">\n        <DialogTitle className=\"aui-sr-only sr-only\">\n          Image Attachment Preview\n        </DialogTitle>\n        <div className=\"aui-attachment-preview relative mx-auto flex max-h-[80dvh] w-full items-center justify-center overflow-hidden bg-background\">\n          <AttachmentPreview src={src} />\n        </div>\n      </DialogContent>\n    </Dialog>\n  );\n};\n\nconst AttachmentThumb: FC = () => {\n  const src = useAttachmentSrc();\n\n  if (src) {\n    return (\n      <img\n        src={src}\n        alt=\"Attachment preview\"\n        className=\"h-full w-full object-cover\"\n      />\n    );\n  }\n\n  return (\n    <div className=\"flex h-full w-full items-center justify-center\">\n      <FileText className=\"size-6 text-muted-foreground\" />\n    </div>\n  );\n};\n\nconst AttachmentUI: FC = () => {\n  const aui = useAui();\n  const isComposer = aui.attachment.source === \"composer\";\n\n  const isImage = useAuiState(({ attachment }) => attachment.type === \"image\");\n  const typeLabel = useAuiState(({ attachment }) => {\n    const type = attachment.type;\n    switch (type) {\n      case \"image\":\n        return \"Image\";\n      case \"document\":\n        return \"Document\";\n      case \"file\":\n        return \"File\";\n      default:\n        throw new Error(`Unknown attachment type: ${type as string}`);\n    }\n  });\n\n  return (\n    <Tooltip>\n      <AttachmentPrimitive.Root\n        className={cn(\n          \"aui-attachment-root relative\",\n          isImage &&\n            \"aui-attachment-root-composer only:[&>#attachment-tile]:size-16\",\n        )}\n      >\n        <AttachmentPreviewDialog>\n          <TooltipTrigger asChild={true}>\n            <button\n              className={cn(\n                \"aui-attachment-tile size-14 cursor-pointer overflow-hidden rounded-[14px] border bg-muted transition-opacity hover:opacity-75\",\n                isComposer &&\n                  \"aui-attachment-tile-composer border-foreground/20\",\n              )}\n              id=\"attachment-tile\"\n              aria-label={`${typeLabel} attachment`}\n              type=\"button\"\n            >\n              <AttachmentThumb />\n            </button>\n          </TooltipTrigger>\n        </AttachmentPreviewDialog>\n        {isComposer && <AttachmentRemove />}\n      </AttachmentPrimitive.Root>\n      <TooltipContent side=\"top\">\n        <AttachmentPrimitive.Name />\n      </TooltipContent>\n    </Tooltip>\n  );\n};\n\nconst AttachmentRemove: FC = () => {\n  return (\n    <AttachmentPrimitive.Remove asChild={true}>\n      <TooltipIconButton\n        tooltip=\"Remove file\"\n        className=\"aui-attachment-tile-remove absolute top-1.5 right-1.5 size-3.5 rounded-full bg-white text-muted-foreground opacity-100 shadow-sm hover:bg-white! 
[&_svg]:text-black hover:[&_svg]:text-destructive\"\n        side=\"top\"\n      >\n        <XIcon className=\"aui-attachment-remove-icon size-3 dark:stroke-[2.5px]\" />\n      </TooltipIconButton>\n    </AttachmentPrimitive.Remove>\n  );\n};\n\nexport const UserMessageAttachments: FC = () => {\n  return (\n    <div className=\"aui-user-message-attachments-end col-span-full col-start-1 row-start-1 flex w-full flex-row justify-end gap-2\">\n      <MessagePrimitive.Attachments components={{ Attachment: AttachmentUI }} />\n    </div>\n  );\n};\n\nexport const ComposerAttachments: FC = () => {\n  return (\n    <div className=\"aui-composer-attachments mb-2 flex w-full flex-row items-center gap-2 overflow-x-auto px-1.5 pt-0.5 pb-1 empty:hidden\">\n      <ComposerPrimitive.Attachments\n        components={{ Attachment: AttachmentUI }}\n      />\n    </div>\n  );\n};\n\nexport const ComposerAddAttachment: FC = () => {\n  return (\n    <ComposerPrimitive.AddAttachment asChild={true}>\n      <TooltipIconButton\n        tooltip=\"Add Attachment\"\n        side=\"bottom\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className=\"aui-composer-add-attachment size-8.5 rounded-full p-1 font-semibold text-xs hover:bg-muted-foreground/15 dark:border-muted-foreground/15 dark:hover:bg-muted-foreground/30\"\n        aria-label=\"Add Attachment\"\n      >\n        <PlusIcon className=\"aui-attachment-add-icon size-5 stroke-[1.5px]\" />\n      </TooltipIconButton>\n    </ComposerPrimitive.AddAttachment>\n  );\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/audio-player.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { Button } from \"@/components/ui/button\";\nimport { DownloadIcon, PauseIcon, PlayIcon } from \"lucide-react\";\nimport { type FC, useRef, useState } from \"react\";\n\ninterface AudioPlayerProps {\n  src: string;\n}\n\nexport const AudioPlayer: FC<AudioPlayerProps> = ({ src }) => {\n  const audioRef = useRef<HTMLAudioElement>(null);\n  const [isPlaying, setIsPlaying] = useState(false);\n  const [progress, setProgress] = useState(0);\n  const [duration, setDuration] = useState(0);\n\n  const togglePlay = () => {\n    const audio = audioRef.current;\n    if (!audio) return;\n    if (isPlaying) {\n      audio.pause();\n    } else {\n      audio.play();\n    }\n    setIsPlaying(!isPlaying);\n  };\n\n  const handleTimeUpdate = () => {\n    const audio = audioRef.current;\n    if (!audio) return;\n    setProgress(audio.currentTime);\n  };\n\n  const handleLoadedMetadata = () => {\n    const audio = audioRef.current;\n    if (!audio) return;\n    setDuration(audio.duration);\n  };\n\n  const handleEnded = () => {\n    setIsPlaying(false);\n    setProgress(0);\n  };\n\n  const handleSeek = (e: React.ChangeEvent<HTMLInputElement>) => {\n    const audio = audioRef.current;\n    if (!audio) return;\n    const time = parseFloat(e.target.value);\n    audio.currentTime = time;\n    setProgress(time);\n  };\n\n  const handleDownload = () => {\n    const link = document.createElement(\"a\");\n    link.href = src;\n    link.download = \"generated-audio.wav\";\n    link.click();\n  };\n\n  const formatTime = (t: number) => {\n    const mins = Math.floor(t / 60);\n    const secs = Math.floor(t % 60);\n    return `${mins}:${secs.toString().padStart(2, \"0\")}`;\n  };\n\n  return (\n    <div className=\"my-2 flex max-w-md items-center gap-3 rounded-xl border bg-muted/50 px-4 py-3\">\n      <audio\n        ref={audioRef}\n        src={src}\n        onTimeUpdate={handleTimeUpdate}\n        onLoadedMetadata={handleLoadedMetadata}\n        onEnded={handleEnded}\n        preload=\"metadata\"\n      />\n      <Button\n        variant=\"ghost\"\n        size=\"icon\"\n        className=\"size-8 shrink-0 rounded-full\"\n        onClick={togglePlay}\n      >\n        {isPlaying ? (\n          <PauseIcon className=\"size-4\" />\n        ) : (\n          <PlayIcon className=\"size-4\" />\n        )}\n      </Button>\n      <div className=\"flex flex-1 flex-col gap-1\">\n        <input\n          type=\"range\"\n          min={0}\n          max={duration || 0}\n          step={0.01}\n          value={progress}\n          onChange={handleSeek}\n          className=\"h-1.5 w-full cursor-pointer accent-primary\"\n        />\n        <div className=\"flex justify-between text-[10px] text-muted-foreground\">\n          <span>{formatTime(progress)}</span>\n          <span>{formatTime(duration)}</span>\n        </div>\n      </div>\n      <Button\n        variant=\"ghost\"\n        size=\"icon\"\n        className=\"size-7 shrink-0 text-muted-foreground\"\n        onClick={handleDownload}\n        title=\"Download audio\"\n      >\n        <DownloadIcon className=\"size-3.5\" />\n      </Button>\n    </div>\n  );\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/badge.tsx",
    "content": "\"use client\";\n\nimport type { ComponentProps } from \"react\";\nimport { Slot } from \"radix-ui\";\nimport { cva, type VariantProps } from \"class-variance-authority\";\nimport { cn } from \"@/lib/utils\";\n\nconst badgeVariants = cva(\n  \"inline-flex items-center justify-center gap-1 rounded-md font-medium text-xs transition-colors [&_svg]:size-3 [&_svg]:shrink-0\",\n  {\n    variants: {\n      variant: {\n        outline:\n          \"border border-input bg-transparent text-muted-foreground hover:bg-accent hover:text-accent-foreground\",\n        secondary:\n          \"bg-secondary text-secondary-foreground hover:bg-secondary/80\",\n        muted:\n          \"bg-muted text-muted-foreground hover:bg-muted/80 hover:text-foreground\",\n        ghost:\n          \"bg-transparent text-muted-foreground hover:bg-accent hover:text-accent-foreground\",\n        info: \"bg-blue-100 text-blue-700 hover:bg-blue-100/80 dark:bg-blue-900/50 dark:text-blue-300\",\n        warning:\n          \"bg-amber-100 text-amber-700 hover:bg-amber-100/80 dark:bg-amber-900/50 dark:text-amber-300\",\n        success:\n          \"bg-emerald-100 text-emerald-700 hover:bg-emerald-100/80 dark:bg-emerald-900/50 dark:text-emerald-300\",\n        destructive:\n          \"bg-red-100 text-red-700 hover:bg-red-100/80 dark:bg-red-900/50 dark:text-red-300\",\n      },\n      size: {\n        sm: \"px-1.5 py-0.5\",\n        default: \"px-2 py-1\",\n        lg: \"px-2.5 py-1.5 text-sm\",\n      },\n    },\n    defaultVariants: {\n      variant: \"outline\",\n      size: \"default\",\n    },\n  },\n);\n\nexport type BadgeProps = ComponentProps<\"span\"> &\n  VariantProps<typeof badgeVariants> & {\n    asChild?: boolean;\n  };\n\nfunction Badge({\n  className,\n  variant,\n  size,\n  asChild = false,\n  ...props\n}: BadgeProps) {\n  const Comp = asChild ? Slot.Root : \"span\";\n\n  return (\n    <Comp\n      data-slot=\"badge\"\n      data-variant={variant}\n      data-size={size}\n      className={cn(badgeVariants({ variant, size }), className)}\n      {...props}\n    />\n  );\n}\n\nexport { Badge, badgeVariants };\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/markdown-text.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { copyToClipboard } from \"@/lib/copy-to-clipboard\";\nimport { INTERNAL, useMessagePartText } from \"@assistant-ui/react\";\nimport { Copy02Icon, Tick02Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { code } from \"@streamdown/code\";\nimport { math } from \"@streamdown/math\";\nimport { mermaid } from \"@streamdown/mermaid\";\nimport { DownloadIcon, Maximize2Icon, Minimize2Icon } from \"lucide-react\";\nimport { useEffect, useRef, useState } from \"react\";\nimport { Block, type BlockProps, Streamdown } from \"streamdown\";\nimport \"katex/dist/katex.min.css\";\nimport { AudioPlayer } from \"./audio-player\";\n\nconst { withSmoothContextProvider } = INTERNAL;\nconst COPY_RESET_MS = 2000;\nconst MERMAID_SOURCE_RE = /```mermaid\\s*([\\s\\S]*?)```/i;\nconst CODE_FENCE_RE = /^```([^\\r\\n`]*)\\r?\\n([\\s\\S]*?)\\r?\\n?```$/;\nconst ACTION_PANEL_CLASS =\n  \"pointer-events-auto flex shrink-0 items-center gap-2 rounded-md border border-sidebar bg-sidebar/80 px-1.5 py-1 supports-[backdrop-filter]:bg-sidebar/70 supports-[backdrop-filter]:backdrop-blur\";\nconst ACTION_BUTTON_CLASS =\n  \"cursor-pointer p-1 text-muted-foreground transition-all hover:text-foreground disabled:cursor-not-allowed disabled:opacity-50\";\n\ntype CodeFence = {\n  language: string | null;\n  source: string;\n};\n\nfunction getMermaidSource(blockContent: string): string | null {\n  const source = blockContent.match(MERMAID_SOURCE_RE)?.[1]?.trim();\n  return source && source.length > 0 ? source : null;\n}\n\nfunction getCodeFence(blockContent: string): CodeFence | null {\n  const match = blockContent.trimEnd().match(CODE_FENCE_RE);\n  if (!match) {\n    return null;\n  }\n\n  return {\n    language: match[1]?.trim() || null,\n    source: match[2],\n  };\n}\n\nfunction getCodeFilename(language: string | null) {\n  const extByLanguage: Record<string, string> = {\n    bash: \"sh\",\n    javascript: \"js\",\n    js: \"js\",\n    json: \"json\",\n    jsx: \"jsx\",\n    markdown: \"md\",\n    md: \"md\",\n    python: \"py\",\n    py: \"py\",\n    shell: \"sh\",\n    sh: \"sh\",\n    sql: \"sql\",\n    ts: \"ts\",\n    tsx: \"tsx\",\n    typescript: \"ts\",\n    svg: \"svg\",\n    yaml: \"yml\",\n    yml: \"yml\",\n  };\n\n  const normalized = language?.toLowerCase();\n  const fallbackExt = normalized?.replace(/[^a-z0-9]+/g, \"-\");\n  const ext = normalized\n    ? extByLanguage[normalized] || fallbackExt || \"txt\"\n    : \"txt\";\n  return `snippet.${ext}`;\n}\n\nfunction isSvgFence(codeFence: CodeFence): boolean {\n  const lang = codeFence.language?.toLowerCase() ?? \"\";\n  if (lang === \"svg\") return true;\n  if ((lang === \"xml\" || lang === \"html\") && codeFence.source.trimStart().startsWith(\"<svg\")) return true;\n  return false;\n}\n\nfunction isHtmlFence(codeFence: CodeFence): boolean {\n  const lang = codeFence.language?.toLowerCase() ?? 
\"\";\n  return lang === \"html\" && !codeFence.source.trimStart().startsWith(\"<svg\");\n}\n\nconst UNSAFE_SVG_RE = /<script[\\s>]|on\\w+\\s*=|javascript:|<foreignObject[\\s>]|<iframe[\\s>]|<embed[\\s>]|<object[\\s>]/i;\n\nfunction sanitizeSvg(source: string): string | null {\n  if (UNSAFE_SVG_RE.test(source)) return null;\n  return source;\n}\n\nfunction SvgPreview({ source }: { source: string }) {\n  const dataUri = `data:image/svg+xml;charset=utf-8,${encodeURIComponent(source)}`;\n  return (\n    <div className=\"mt-2 flex justify-center rounded-lg border border-border bg-white p-4 dark:bg-neutral-100\">\n      <img\n        src={dataUri}\n        alt=\"SVG preview\"\n        style={{ maxWidth: \"100%\", maxHeight: 512 }}\n      />\n    </div>\n  );\n}\n\nconst HTML_PREVIEW_DEFAULT_HEIGHT = 400;\nconst HTML_PREVIEW_MAX_HEIGHT = 800;\n\nfunction HtmlPreview({ source }: { source: string }) {\n  const iframeRef = useRef<HTMLIFrameElement>(null);\n  const [height, setHeight] = useState(HTML_PREVIEW_DEFAULT_HEIGHT);\n  const [enlarged, setEnlarged] = useState(false);\n\n  useEffect(() => {\n    const handler = (e: MessageEvent) => {\n      if (e.source !== iframeRef.current?.contentWindow) return;\n      if (typeof e.data?.htmlPreviewHeight === \"number\") {\n        setHeight(Math.min(Math.max(e.data.htmlPreviewHeight, 100), HTML_PREVIEW_MAX_HEIGHT));\n      }\n    };\n    window.addEventListener(\"message\", handler);\n    return () => window.removeEventListener(\"message\", handler);\n  }, []);\n\n  useEffect(() => {\n    if (!enlarged) return;\n    const handler = (e: KeyboardEvent) => {\n      if (e.key === \"Escape\") setEnlarged(false);\n    };\n    window.addEventListener(\"keydown\", handler);\n    return () => window.removeEventListener(\"keydown\", handler);\n  }, [enlarged]);\n\n  const resizeScript = `<script>new ResizeObserver(()=>{\nparent.postMessage({htmlPreviewHeight:document.documentElement.scrollHeight},\"*\");\n}).observe(document.documentElement);</script>`;\n\n  const srcDoc = source + resizeScript;\n\n  if (enlarged) {\n    return (\n      <>\n        <div className=\"mt-2 overflow-hidden rounded-lg border border-border\" style={{ height }}>\n          {/* Placeholder keeps layout stable while overlay is shown */}\n        </div>\n        <div\n          className=\"fixed inset-0 z-50 flex flex-col bg-background/80 backdrop-blur-sm\"\n          onClick={(e) => { if (e.target === e.currentTarget) setEnlarged(false); }}\n        >\n          <div className=\"flex items-center justify-end gap-2 px-4 py-2\">\n            <button\n              type=\"button\"\n              className=\"flex items-center gap-1.5 rounded-md border border-border bg-background px-3 py-1.5 text-sm text-muted-foreground transition-colors hover:bg-muted hover:text-foreground\"\n              onClick={() => setEnlarged(false)}\n              title=\"Exit fullscreen (Esc)\"\n            >\n              <Minimize2Icon className=\"size-4\" />\n              Exit fullscreen\n            </button>\n          </div>\n          <div className=\"mx-4 mb-4 flex-1 overflow-hidden rounded-lg border border-border bg-background\">\n            <iframe\n              ref={iframeRef}\n              srcDoc={srcDoc}\n              sandbox=\"allow-scripts\"\n              style={{ width: \"100%\", height: \"100%\", border: \"none\", display: \"block\" }}\n              title=\"HTML preview\"\n            />\n          </div>\n        </div>\n      </>\n    );\n  }\n\n  return (\n    <div 
className=\"group/html-preview relative mt-2 overflow-hidden rounded-lg border border-border\">\n      <button\n        type=\"button\"\n        className=\"absolute top-2 right-2 z-10 rounded-md border border-border bg-background/80 p-1.5 text-muted-foreground opacity-0 transition-all hover:bg-muted hover:text-foreground group-hover/html-preview:opacity-100 supports-[backdrop-filter]:backdrop-blur\"\n        onClick={() => setEnlarged(true)}\n        title=\"Enlarge preview\"\n      >\n        <Maximize2Icon className=\"size-4\" />\n      </button>\n      <iframe\n        ref={iframeRef}\n        srcDoc={srcDoc}\n        sandbox=\"allow-scripts\"\n        style={{ width: \"100%\", height, border: \"none\", display: \"block\" }}\n        title=\"HTML preview\"\n      />\n    </div>\n  );\n}\n\nfunction downloadTextFile(filename: string, text: string): void {\n  const blob = new Blob([text], { type: \"text/plain;charset=utf-8\" });\n  const url = URL.createObjectURL(blob);\n  const anchor = document.createElement(\"a\");\n  anchor.href = url;\n  anchor.download = filename;\n  document.body.appendChild(anchor);\n  anchor.click();\n  document.body.removeChild(anchor);\n  window.setTimeout(() => URL.revokeObjectURL(url), 0);\n}\n\nfunction useCopiedState() {\n  const [copied, setCopied] = useState(false);\n  const resetTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);\n\n  useEffect(() => {\n    return () => {\n      if (resetTimeoutRef.current) {\n        clearTimeout(resetTimeoutRef.current);\n      }\n    };\n  }, []);\n\n  const showCopied = () => {\n    setCopied(true);\n    if (resetTimeoutRef.current) {\n      clearTimeout(resetTimeoutRef.current);\n    }\n    resetTimeoutRef.current = setTimeout(() => {\n      setCopied(false);\n      resetTimeoutRef.current = null;\n    }, COPY_RESET_MS);\n  };\n\n  return { copied, showCopied };\n}\n\nfunction MermaidCopyButton({ source }: { source: string }) {\n  const { copied, showCopied } = useCopiedState();\n\n  return (\n    <button\n      type=\"button\"\n      className=\"absolute top-3.5 right-20 z-20 cursor-pointer text-muted-foreground transition-all hover:text-foreground\"\n      title=\"Copy Mermaid source\"\n      onClick={() => {\n        if (!copyToClipboard(source)) {\n          return;\n        }\n        showCopied();\n      }}\n    >\n      <HugeiconsIcon\n        icon={copied ? Tick02Icon : Copy02Icon}\n        className=\"size-5\"\n      />\n    </button>\n  );\n}\n\nfunction CodeBlockActions({\n  disabled,\n  language,\n  source,\n}: {\n  disabled: boolean;\n  language: string | null;\n  source: string;\n}) {\n  const { copied, showCopied } = useCopiedState();\n\n  return (\n    <div className=\"pointer-events-none absolute top-3.5 right-3 z-20 flex items-center justify-end\">\n      <div className={ACTION_PANEL_CLASS}>\n        <button\n          type=\"button\"\n          className={ACTION_BUTTON_CLASS}\n          title=\"Copy code\"\n          disabled={disabled}\n          onClick={() => {\n            if (!copyToClipboard(source)) {\n              return;\n            }\n            showCopied();\n          }}\n        >\n          <HugeiconsIcon\n            icon={copied ? 
Tick02Icon : Copy02Icon}\n            className=\"size-3.5\"\n          />\n        </button>\n        <button\n          type=\"button\"\n          className={ACTION_BUTTON_CLASS}\n          title=\"Download file\"\n          disabled={disabled}\n          onClick={() => {\n            downloadTextFile(getCodeFilename(language), source);\n          }}\n        >\n          <DownloadIcon className=\"size-3.5\" />\n        </button>\n      </div>\n    </div>\n  );\n}\n\nfunction StreamdownBlock(props: BlockProps) {\n  const hasMermaidFence = props.content.includes(\"```mermaid\");\n  const mermaidSource = getMermaidSource(props.content);\n  const codeFence = getCodeFence(props.content);\n\n  if (props.isIncomplete && hasMermaidFence) {\n    return (\n      <div className=\"my-4 flex h-48 items-center justify-center rounded-xl border border-border bg-muted/30 text-sm text-muted-foreground animate-pulse\">\n        Loading diagram...\n      </div>\n    );\n  }\n\n  if (props.isIncomplete && codeFence && isSvgFence(codeFence)) {\n    return (\n      <div className=\"relative isolate\">\n        <div className=\"my-4 rounded-xl border border-border bg-muted/30 p-4\">\n          <div className=\"mb-2 text-xs font-medium text-muted-foreground\">svg</div>\n          <pre className=\"overflow-x-auto text-xs text-muted-foreground whitespace-pre-wrap break-all\">\n            <code>{codeFence.source}</code>\n          </pre>\n        </div>\n      </div>\n    );\n  }\n\n  if (props.isIncomplete && codeFence && isHtmlFence(codeFence)) {\n    return (\n      <div className=\"my-4 flex h-48 items-center justify-center rounded-xl border border-border bg-muted/30 text-sm text-muted-foreground animate-pulse\">\n        Loading preview...\n      </div>\n    );\n  }\n\n  if (mermaidSource) {\n    return (\n      <div className=\"relative isolate\">\n        <Block {...props} />\n        <MermaidCopyButton source={mermaidSource} />\n      </div>\n    );\n  }\n\n  if (codeFence) {\n    const svgSource = !props.isIncomplete && isSvgFence(codeFence) ? sanitizeSvg(codeFence.source) : null;\n    const htmlSource = !props.isIncomplete && isHtmlFence(codeFence) ? 
codeFence.source : null;\n    return (\n      <>\n        <div className=\"relative isolate\">\n          <Block {...props} />\n          <CodeBlockActions\n            disabled={props.isIncomplete}\n            language={codeFence.language}\n            source={codeFence.source}\n          />\n        </div>\n        {svgSource && <SvgPreview source={svgSource} />}\n        {htmlSource && <HtmlPreview source={htmlSource} />}\n      </>\n    );\n  }\n\n  return <Block {...props} />;\n}\nconst AUDIO_PLAYER_RE = /<audio-player\\s+src=\"([^\"]+)\"\\s*\\/>/;\n\nconst MarkdownTextImpl = () => {\n  const { text, status } = useMessagePartText();\n\n  const audioMatch = text.match(AUDIO_PLAYER_RE);\n  if (audioMatch) {\n    return <AudioPlayer src={audioMatch[1]} />;\n  }\n\n  return (\n    <div data-status={status.type}>\n      <Streamdown\n        mode=\"streaming\"\n        isAnimating={status.type === \"running\"}\n        plugins={{ code, math, mermaid }}\n        controls={{\n          code: false,\n          mermaid: {\n            fullscreen: true,\n            download: true,\n            copy: false,\n            panZoom: true,\n          },\n        }}\n        shikiTheme={[\"github-light\", \"github-dark\"]}\n        BlockComponent={StreamdownBlock}\n      >\n        {text}\n      </Streamdown>\n    </div>\n  );\n};\n\nexport const MarkdownText = withSmoothContextProvider(MarkdownTextImpl);\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/message-timing.tsx",
    "content": "\"use client\";\n\nimport { useMessageTiming } from \"@assistant-ui/react\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\nimport type { FC } from \"react\";\n\nconst formatTimingMs = (ms: number | undefined): string => {\n  if (ms === undefined) return \"—\";\n  if (ms < 1000) return `${Math.round(ms)}ms`;\n  return `${(ms / 1000).toFixed(2)}s`;\n};\n\n/**\n * Shows streaming stats (TTFT, total time, chunks) as a badge with a\n * hover/focus tooltip. Renders nothing until the stream completes.\n *\n * Place it inside `ActionBarPrimitive.Root` in your `thread.tsx` so it\n * inherits the action bar's autohide behaviour:\n *\n * ```tsx\n * import { MessageTiming } from \"@/components/assistant-ui/message-timing\";\n *\n * <ActionBarPrimitive.Root >\n *   <ActionBarPrimitive.Copy />\n *   <ActionBarPrimitive.Reload />\n *   <MessageTiming />  // <-- add this\n * </ActionBarPrimitive.Root>\n * ```\n *\n * @param side - Side of the tooltip relative to the badge trigger. Defaults to `\"right\"`.\n */\nexport const MessageTiming: FC<{\n  className?: string;\n  side?: \"top\" | \"right\" | \"bottom\" | \"left\";\n}> = ({ className, side = \"right\" }) => {\n  const timing = useMessageTiming();\n  if (timing?.totalStreamTime === undefined) return null;\n\n  return (\n    <Tooltip>\n      <TooltipTrigger asChild>\n        <button\n          type=\"button\"\n          data-slot=\"message-timing-trigger\"\n          aria-label=\"Message timing\"\n          className={cn(\n            \"flex items-center rounded-md p-1 font-mono text-muted-foreground text-xs tabular-nums transition-colors hover:bg-accent hover:text-accent-foreground\",\n            className,\n          )}\n        >\n          {formatTimingMs(timing.totalStreamTime)}\n        </button>\n      </TooltipTrigger>\n      <TooltipContent\n        side={side}\n        sideOffset={8}\n        data-slot=\"message-timing-popover\"\n        className=\"[&_span>svg]:hidden! rounded-lg border bg-popover px-3 py-2 text-popover-foreground shadow-md\"\n      >\n        <div className=\"grid min-w-35 gap-1.5 text-xs\">\n          {timing.firstTokenTime !== undefined && (\n            <div className=\"flex items-center justify-between gap-4\">\n              <span className=\"text-muted-foreground\">First token</span>\n              <span className=\"font-mono tabular-nums\">\n                {formatTimingMs(timing.firstTokenTime)}\n              </span>\n            </div>\n          )}\n          <div className=\"flex items-center justify-between gap-4\">\n            <span className=\"text-muted-foreground\">Total</span>\n            <span className=\"font-mono tabular-nums\">\n              {formatTimingMs(timing.totalStreamTime)}\n            </span>\n          </div>\n          <div className=\"flex items-center justify-between gap-4\">\n            <span className=\"text-muted-foreground\">Chunks</span>\n            <span className=\"font-mono tabular-nums\">{timing.totalChunks}</span>\n          </div>\n        </div>\n      </TooltipContent>\n    </Tooltip>\n  );\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/model-selector/pickers.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  AlertDialog,\n  AlertDialogAction,\n  AlertDialogCancel,\n  AlertDialogContent,\n  AlertDialogDescription,\n  AlertDialogFooter,\n  AlertDialogHeader,\n  AlertDialogTitle,\n} from \"@/components/ui/alert-dialog\";\nimport { Input } from \"@/components/ui/input\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { deleteCachedModel, listCachedGguf, listCachedModels, listGgufVariants } from \"@/features/chat/api/chat-api\";\nimport type { CachedGgufRepo, CachedModelRepo } from \"@/features/chat/api/chat-api\";\nimport type { GgufVariantDetail } from \"@/features/chat/types/api\";\nimport { usePlatformStore } from \"@/config/env\";\nimport {\n  useDebouncedValue,\n  useGpuInfo,\n  useHfModelSearch,\n  useInfiniteScroll,\n  useRecommendedModelVram,\n} from \"@/hooks\";\nimport { cn, formatCompact } from \"@/lib/utils\";\nimport type { VramFitStatus } from \"@/lib/vram\";\nimport { checkVramFit, estimateLoadingVram } from \"@/lib/vram\";\nimport { Search01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { Trash2Icon } from \"lucide-react\";\nimport { useCallback, useEffect, useMemo, useRef, useState, type ReactNode } from \"react\";\nimport { toast } from \"sonner\";\nimport type {\n  LoraModelOption,\n  ModelOption,\n  ModelSelectorChangeMeta,\n} from \"./types\";\n\nfunction dedupe(values: string[]): string[] {\n  return [...new Set(values.filter(Boolean))];\n}\n\nfunction ListLabel({ children }: { children: ReactNode }) {\n  return (\n    <div className=\"px-2.5 py-1.5 text-[10px] font-semibold uppercase tracking-wider text-muted-foreground\">\n      {children}\n    </div>\n  );\n}\n\n/** Format bytes to a human-readable size string. */\nfunction formatBytes(bytes: number): string {\n  if (bytes === 0) return \"0 B\";\n  const units = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"];\n  const i = Math.floor(Math.log(bytes) / Math.log(1024));\n  const value = bytes / 1024 ** i;\n  return `${value.toFixed(value < 10 ? 1 : 0)} ${units[i]}`;\n}\n\nfunction ModelRow({\n  label,\n  meta,\n  selected,\n  onClick,\n  vramStatus,\n  vramEst,\n  gpuGb,\n  tooltipText,\n}: {\n  label: string;\n  meta?: string;\n  selected?: boolean;\n  onClick: () => void;\n  vramStatus?: VramFitStatus | null;\n  vramEst?: number;\n  gpuGb?: number;\n  tooltipText?: ReactNode;\n}) {\n  const exceeds = vramStatus === \"exceeds\";\n  const showVramTooltip =\n    vramEst != null && vramEst > 0 && gpuGb != null && gpuGb > 0;\n  const vramTooltipText =\n    showVramTooltip && vramStatus\n      ? exceeds\n        ? `Needs ~${vramEst}GB VRAM (GPU: ${gpuGb}GB)`\n        : vramStatus === \"tight\"\n          ? 
`~${vramEst}GB VRAM (tight fit on ${gpuGb}GB)`\n          : `~${vramEst}GB VRAM`\n      : null;\n\n  const content = (\n    <button\n      type=\"button\"\n      onClick={onClick}\n      className={cn(\n        \"flex w-full items-center gap-2 rounded-md px-2.5 py-1.5 text-left text-sm transition-colors hover:bg-accent\",\n        selected && \"bg-accent/60\",\n        exceeds && \"opacity-50\",\n      )}\n    >\n      <span\n        className={cn(\n          \"block min-w-0 flex-1 truncate\",\n          exceeds && \"line-through decoration-muted-foreground/50\",\n        )}\n      >\n        {label}\n      </span>\n      <span className=\"ml-auto flex items-center gap-1.5 shrink-0\">\n        {vramStatus === \"exceeds\" && (\n          <span className=\"text-[9px] font-medium text-red-400\">OOM</span>\n        )}\n        {vramStatus === \"tight\" && (\n          <span className=\"text-[9px] font-medium text-amber-400\">TIGHT</span>\n        )}\n        {meta ? (\n          <span className=\"text-[10px] text-muted-foreground\">{meta}</span>\n        ) : null}\n      </span>\n    </button>\n  );\n\n  if (vramTooltipText) {\n    return (\n      <Tooltip>\n        <TooltipTrigger asChild>{content}</TooltipTrigger>\n        <TooltipContent side=\"left\" className=\"max-w-xs break-all\">\n          {label}\n          <span className=\"block text-[10px] mt-1\">{vramTooltipText}</span>\n        </TooltipContent>\n      </Tooltip>\n    );\n  }\n\n  if (tooltipText) {\n    return (\n      <Tooltip>\n        <TooltipTrigger asChild>{content}</TooltipTrigger>\n        <TooltipContent side=\"left\" className=\"max-w-xs break-all\">\n          {tooltipText}\n        </TooltipContent>\n      </Tooltip>\n    );\n  }\n  return content;\n}\n\n// ── GGUF Variant Expander ────────────────────────────────────\n\nfunction GgufVariantExpander({\n  repoId,\n  onSelect,\n  gpuGb,\n  systemRamGb,\n  onDeleteVariant,\n}: {\n  repoId: string;\n  onSelect: (id: string, meta: ModelSelectorChangeMeta) => void;\n  gpuGb?: number;\n  systemRamGb?: number;\n  onDeleteVariant?: (quant: string) => void;\n}) {\n  const [variants, setVariants] = useState<GgufVariantDetail[] | null>(null);\n  const [defaultVariant, setDefaultVariant] = useState<string | null>(null);\n  const [hasVision, setHasVision] = useState(false);\n  const [loading, setLoading] = useState(true);\n  const [error, setError] = useState<string | null>(null);\n\n  useEffect(() => {\n    let canceled = false;\n    setLoading(true);\n    setError(null);\n\n    listGgufVariants(repoId)\n      .then((res) => {\n        if (canceled) return;\n        setVariants(res.variants);\n        setDefaultVariant(res.default_variant);\n        setHasVision(res.has_vision);\n      })\n      .catch((err) => {\n        if (canceled) return;\n        setError(err instanceof Error ? 
err.message : \"Failed to load variants\");\n      })\n      .finally(() => {\n        if (!canceled) setLoading(false);\n      });\n\n    return () => {\n      canceled = true;\n    };\n  }, [repoId]);\n\n  const handleVariantClick = useCallback(\n    (quant: string, downloaded?: boolean, sizeBytes?: number) => {\n      onSelect(repoId, {\n        source: \"hub\",\n        isLora: false,\n        ggufVariant: quant,\n        isDownloaded: downloaded,\n        expectedBytes: sizeBytes,\n      });\n    },\n    [repoId, onSelect],\n  );\n\n  // GGUF fit classification matching llama-server's _select_gpus logic:\n  //   fits  = model <= 0.7 * total GPU memory\n  //   tight = model > 0.7 * GPU but <= 0.7 * GPU + 0.7 * system RAM (--fit uses CPU offload)\n  //   oom   = model > 0.7 * GPU + 0.7 * system RAM\n  const gpuBudgetGb = (gpuGb ?? 0) * 0.70;\n  const totalBudgetGb = gpuBudgetGb + (systemRamGb ?? 0) * 0.70;\n\n  const getGgufFit = useCallback(\n    (sizeBytes: number): \"fits\" | \"tight\" | \"oom\" => {\n      if (!gpuGb || gpuGb <= 0) return \"fits\";\n      const gb = sizeBytes / (1024 ** 3);\n      if (gb <= 0 || gb <= gpuBudgetGb) return \"fits\";\n      if (gb <= totalBudgetGb) return \"tight\";\n      return \"oom\";\n    },\n    [gpuGb, gpuBudgetGb, totalBudgetGb],\n  );\n\n  // If the backend-recommended variant is OOM, pick the largest fitting\n  // variant instead; if all are OOM, recommend the smallest one.\n  const effectiveRecommended = useMemo(() => {\n    if (!variants || !gpuGb || gpuGb <= 0) return defaultVariant;\n    const defaultV = variants.find((v) => v.quant === defaultVariant);\n    if (defaultV && getGgufFit(defaultV.size_bytes) !== \"oom\") return defaultVariant;\n    // Default is OOM -- pick largest non-OOM variant (best quality that fits)\n    const fitting = variants.filter((v) => getGgufFit(v.size_bytes) !== \"oom\");\n    if (fitting.length > 0) {\n      fitting.sort((a, b) => b.size_bytes - a.size_bytes);\n      return fitting[0].quant;\n    }\n    // All OOM -- recommend smallest (most likely to partially run)\n    const sorted = [...variants].sort((a, b) => a.size_bytes - b.size_bytes);\n    return sorted[0].quant;\n  }, [variants, defaultVariant, gpuGb, getGgufFit]);\n\n  const sortedVariants = useMemo(() => {\n    if (!variants) return variants;\n    // Tier: 0 = downloaded+fits, 1 = downloaded+tight, 2 = fits, 3 = tight, 4 = OOM\n    const tierOf = (v: GgufVariantDetail) => {\n      const f = getGgufFit(v.size_bytes);\n      if (f === \"oom\") return 4;\n      const base = f === \"fits\" ? 0 : 1;\n      return v.downloaded ? base : base + 2;\n    };\n    return [...variants].sort((a, b) => {\n      const aTier = tierOf(a);\n      const bTier = tierOf(b);\n      if (aTier !== bTier) return aTier - bTier;\n\n      // Within the same tier, recommended goes first\n      const aIsRec = a.quant === effectiveRecommended;\n      const bIsRec = b.quant === effectiveRecommended;\n      if (aIsRec !== bIsRec) return aIsRec ? -1 : 1;\n\n      // fits: largest first (best quality that fits in GPU)\n      // tight/OOM: smallest first (closest to fitting, fastest to run)\n      const fitsInGpu = aTier === 0 || aTier === 2;\n      return fitsInGpu ? 
b.size_bytes - a.size_bytes : a.size_bytes - b.size_bytes;\n    });\n  }, [variants, effectiveRecommended, getGgufFit]);\n\n  if (loading) {\n    return (\n      <div className=\"flex items-center gap-2 px-5 py-2\">\n        <Spinner className=\"size-3 text-muted-foreground\" />\n        <span className=\"text-xs text-muted-foreground\">Loading variants…</span>\n      </div>\n    );\n  }\n\n  if (error) {\n    return (\n      <div className=\"px-5 py-2 text-xs text-destructive\">{error}</div>\n    );\n  }\n\n  if (!sortedVariants || sortedVariants.length === 0) {\n    return (\n      <div className=\"px-5 py-2 text-xs text-muted-foreground\">\n        No GGUF variants found.\n      </div>\n    );\n  }\n\n  return (\n    <div className=\"pl-4 border-l-2 border-accent/50 ml-3 my-1\">\n      <div className=\"px-2 py-1 flex items-center gap-1.5\">\n        <span className=\"text-[10px] font-semibold uppercase tracking-wider text-muted-foreground\">\n          Quantizations\n        </span>\n        {hasVision && (\n          <span className=\"text-[9px] font-medium text-blue-400\">Vision</span>\n        )}\n      </div>\n      {sortedVariants.map((v) => {\n        const fit = getGgufFit(v.size_bytes);\n        const oom = fit === \"oom\";\n        const tight = fit === \"tight\";\n        return (\n          <div key={v.filename} className=\"flex items-center gap-0.5\">\n            <button\n              type=\"button\"\n              onClick={() => handleVariantClick(v.quant, v.downloaded, v.size_bytes)}\n              className={cn(\n                \"flex min-w-0 flex-1 items-center justify-between gap-2 rounded-md px-2.5 py-1 text-left text-sm transition-colors hover:bg-accent\",\n              )}\n            >\n              <span className=\"min-w-0 flex-1 truncate font-mono text-xs\">\n                {v.quant}\n                {v.downloaded ? (\n                  <span className=\"ml-1.5 text-[9px] font-sans font-medium text-green-400\">\n                    downloaded\n                  </span>\n                ) : v.quant === effectiveRecommended ? 
(\n                  <span className=\"ml-1.5 text-[9px] font-sans font-medium text-primary/70\">\n                    recommended\n                  </span>\n                ) : null}\n              </span>\n              <span className=\"flex items-center gap-1.5 shrink-0\">\n                {oom && (\n                  <span className=\"text-[9px] font-medium text-red-400\">OOM</span>\n                )}\n                {tight && (\n                  <span className=\"text-[9px] font-medium text-amber-400\">TIGHT</span>\n                )}\n                <span className=\"text-[10px] text-muted-foreground\">\n                  {formatBytes(v.size_bytes)}\n                </span>\n              </span>\n            </button>\n            {v.downloaded && onDeleteVariant && (\n              <button\n                type=\"button\"\n                onClick={(e) => { e.stopPropagation(); onDeleteVariant(v.quant); }}\n                className=\"shrink-0 rounded-md p-1 text-muted-foreground/60 transition-colors hover:bg-destructive/10 hover:text-destructive\"\n              >\n                <Trash2Icon className=\"size-3\" />\n              </button>\n            )}\n          </div>\n        );\n      })}\n    </div>\n  );\n}\n\n// ── Detect GGUF repos by naming convention ────────────────────\n\nfunction isGgufRepo(id: string): boolean {\n  return id.toUpperCase().includes(\"-GGUF\");\n}\n\n/** Extract param count label from model name (e.g. \"Qwen3-0.6B\" -> \"0.6B\"). */\nfunction extractParamLabel(id: string): string | undefined {\n  // Match patterns like \"0.6B\", \"1B\", \"4B\", \"3.5B\", \"70B\", \"1.5B\" etc.\n  const name = id.split(\"/\").pop() ?? id;\n  const match = name.match(/(?:^|[-_])(\\d+(?:\\.\\d+)?)[Bb](?:[-_]|$)/);\n  return match ? 
`${match[1]}B` : undefined;\n}\n\n// Module-level caches so re-mounting the popover shows results instantly\nlet _cachedGgufCache: CachedGgufRepo[] = [];\nlet _cachedModelsCache: CachedModelRepo[] = [];\n\n// ── Hub Model Picker ──────────────────────────────────────────\n\nexport function HubModelPicker({\n  models,\n  value,\n  onSelect,\n}: {\n  models: ModelOption[];\n  value?: string;\n  onSelect: (id: string, meta: ModelSelectorChangeMeta) => void;\n}) {\n  const gpu = useGpuInfo();\n  const [query, setQuery] = useState(\"\");\n  const debouncedQuery = useDebouncedValue(query);\n  const { results, isLoading, isLoadingMore, fetchMore } = useHfModelSearch(\n    debouncedQuery,\n  );\n\n  // Track which GGUF repo is expanded for variant selection\n  const [expandedGguf, setExpandedGguf] = useState<string | null>(null);\n\n  // Delete confirmation dialog state\n  const [deleteTarget, setDeleteTarget] = useState<string | null>(null);\n  const [deleting, setDeleting] = useState(false);\n\n  // Cached (already downloaded) repos -- use module-level cache so\n  // re-mounting the popover does not flash an empty \"Downloaded\" section.\n  const [cachedGguf, setCachedGguf] = useState<CachedGgufRepo[]>(_cachedGgufCache);\n  const [cachedModels, setCachedModels] = useState<CachedModelRepo[]>(_cachedModelsCache);\n  const alreadyCached = _cachedGgufCache.length > 0 || _cachedModelsCache.length > 0;\n  const [cachedReady, setCachedReady] = useState(alreadyCached);\n\n  const refreshCachedLists = useCallback(() => {\n    listCachedGguf().then((v) => { _cachedGgufCache = v; setCachedGguf(v); }).catch(() => {});\n    listCachedModels().then((v) => { _cachedModelsCache = v; setCachedModels(v); }).catch(() => {});\n  }, []);\n\n  useEffect(() => {\n    if (alreadyCached) return;\n    let done = 0;\n    const check = () => { if (++done >= 2) setCachedReady(true); };\n    listCachedGguf().then((v) => { _cachedGgufCache = v; setCachedGguf(v); }).catch(() => {}).finally(check);\n    listCachedModels().then((v) => { _cachedModelsCache = v; setCachedModels(v); }).catch(() => {}).finally(check);\n  }, [alreadyCached]);\n\n  const handleDeleteConfirm = useCallback(async () => {\n    if (!deleteTarget) return;\n    setDeleting(true);\n    try {\n      // deleteTarget is \"repo_id\" or \"repo_id::variant\"\n      const sepIdx = deleteTarget.indexOf(\"::\");\n      const repoId = sepIdx >= 0 ? deleteTarget.slice(0, sepIdx) : deleteTarget;\n      const variant = sepIdx >= 0 ? deleteTarget.slice(sepIdx + 2) : undefined;\n      await deleteCachedModel(repoId, variant);\n      toast.success(`Deleted ${variant ? `${repoId} ${variant}` : repoId}`);\n      refreshCachedLists();\n    } catch (err) {\n      toast.error(err instanceof Error ? err.message : \"Failed to delete model\");\n    } finally {\n      setDeleting(false);\n      setDeleteTarget(null);\n    }\n  }, [deleteTarget, refreshCachedLists]);\n\n  // Deduplicate: don't show downloaded models in the recommended list.\n  // Compare case-insensitively since HF cache lowercases repo IDs.\n  const downloadedSet = useMemo(() => {\n    const s = new Set<string>();\n    for (const c of cachedGguf) s.add(c.repo_id.toLowerCase());\n    for (const c of cachedModels) s.add(c.repo_id.toLowerCase());\n    return s;\n  }, [cachedGguf, cachedModels]);\n\n  const chatOnly = usePlatformStore((s) => s.isChatOnly());\n\n  const recommendedIds = useMemo(() => {\n    const all = dedupe([...models.map((model) => model.id), value ?? 
\"\"])\n      .filter((id) => !downloadedSet.has(id.toLowerCase()))\n      .filter((id) => !chatOnly || isGgufRepo(id));\n    // Sort: GGUFs first, then hub models\n    const gguf: string[] = [];\n    const hub: string[] = [];\n    for (const id of all) {\n      if (isGgufRepo(id)) gguf.push(id);\n      else hub.push(id);\n    }\n    return [...gguf, ...hub];\n  }, [models, value, downloadedSet, chatOnly]);\n\n  // Infinite scroll paging for the recommended section\n  const [recommendedPage, setRecommendedPage] = useState(1);\n  // Reset page when the underlying list changes\n  useEffect(() => { setRecommendedPage(1); }, [models, chatOnly]);\n\n  const visibleRecommendedIds = useMemo(() => {\n    const hubStartIndex = recommendedIds.findIndex((id) => !isGgufRepo(id));\n    const allGguf = hubStartIndex === -1 ? recommendedIds : recommendedIds.slice(0, hubStartIndex);\n    const allHub = hubStartIndex === -1 ? [] : recommendedIds.slice(hubStartIndex);\n    // Interleave in chunks of 4: [4 gguf, 4 hub, 4 gguf, 4 hub, ...]\n    const result: string[] = [];\n    for (let p = 0; p < recommendedPage; p++) {\n      result.push(...allGguf.slice(p * 4, (p + 1) * 4));\n      result.push(...allHub.slice(p * 4, (p + 1) * 4));\n    }\n    return result;\n  }, [recommendedIds, recommendedPage]);\n\n  const hasMoreRecommended = visibleRecommendedIds.length < recommendedIds.length;\n\n  // Fetch VRAM info for the full pool once (recommendedIds is stable across\n  // page increments) so we don't re-fetch on every scroll.\n  const { paramCountById: recommendedParamCountById } =\n    useRecommendedModelVram(recommendedIds);\n\n  const showHfSection = debouncedQuery.trim().length > 0;\n  const recommendedSet = useMemo(() => new Set(visibleRecommendedIds), [visibleRecommendedIds]);\n\n  const hfIds = useMemo(() => {\n    if (!showHfSection) return [];\n    return results\n      .map((result) => result.id)\n      .filter((id) => !recommendedSet.has(id))\n      .filter((id) => !chatOnly || isGgufRepo(id));\n  }, [recommendedSet, results, showHfSection, chatOnly]);\n\n  const metricsById = useMemo(\n    () =>\n      new Map(\n        results\n          .filter((result) => result.totalParams || result.estimatedSizeBytes)\n          .map((result) => [\n            result.id,\n            result.estimatedSizeBytes\n              ? `~${formatBytes(result.estimatedSizeBytes)}`\n              : formatCompact(result.totalParams!),\n          ]),\n      ),\n    [results],\n  );\n\n  const vramMap = useMemo(() => {\n    const map = new Map<\n      string,\n      { est: number; status: VramFitStatus | null; detail: string | null }\n    >();\n    for (const r of results) {\n      const detail = r.totalParams ? formatCompact(r.totalParams) : null;\n      if (r.totalParams) {\n        const est = estimateLoadingVram(r.totalParams, \"qlora\");\n        const status = gpu.available\n          ? 
checkVramFit(est, gpu.memoryTotalGb)\n          : null;\n        map.set(r.id, { est, status, detail });\n      } else {\n        map.set(r.id, { est: 0, status: null, detail });\n      }\n    }\n    return map;\n  }, [results, gpu]);\n\n  const recommendedVramMap = useMemo(() => {\n    const map = new Map<\n      string,\n      { est: number; status: VramFitStatus | null; detail: string | null }\n    >();\n    for (const id of visibleRecommendedIds) {\n      const totalParams = recommendedParamCountById.get(id);\n      if (totalParams) {\n        const est = estimateLoadingVram(totalParams, \"qlora\");\n        const status = gpu.available\n          ? checkVramFit(est, gpu.memoryTotalGb)\n          : null;\n        const detail = formatCompact(totalParams);\n        map.set(id, { est, status, detail });\n      }\n    }\n    return map;\n  }, [visibleRecommendedIds, recommendedParamCountById, gpu]);\n\n  const { scrollRef, sentinelRef } = useInfiniteScroll(fetchMore, results.length);\n\n  // Sentinel + IntersectionObserver for recommended infinite scroll.\n  // We disconnect after each fire so the observer doesn't loop while\n  // React re-renders; the effect re-creates it on the next page.\n  // Uses a callback ref for the sentinel so we detect mount/unmount reliably.\n  const [recommendedSentinel, setRecommendedSentinel] = useState<HTMLDivElement | null>(null);\n  const recommendedSentinelRef = useCallback((node: HTMLDivElement | null) => {\n    setRecommendedSentinel(node);\n  }, []);\n  useEffect(() => {\n    if (!recommendedSentinel || !hasMoreRecommended) return;\n    const root = scrollRef.current;\n    if (!root) return;\n    const obs = new IntersectionObserver(\n      ([e]) => {\n        if (e.isIntersecting) {\n          obs.disconnect();\n          setRecommendedPage((p) => p + 1);\n        }\n      },\n      { threshold: 0, root },\n    );\n    // Small delay so the browser finishes layout after the previous page render\n    const timer = setTimeout(() => obs.observe(recommendedSentinel), 100);\n    return () => { clearTimeout(timer); obs.disconnect(); };\n  }, [recommendedSentinel, hasMoreRecommended, recommendedPage, scrollRef]);\n\n  /** Handle clicking a model row — GGUF repos expand, others load directly. */\n  const handleModelClick = useCallback(\n    (id: string) => {\n      if (isGgufRepo(id)) {\n        // Toggle GGUF variant expander\n        setExpandedGguf((prev) => (prev === id ? null : id));\n      } else {\n        onSelect(id, { source: \"hub\", isLora: false });\n      }\n    },\n    [onSelect],\n  );\n\n  return (\n    <div className=\"space-y-2\">\n      <div className=\"relative\">\n        <HugeiconsIcon\n          icon={Search01Icon}\n          className=\"pointer-events-none absolute left-2.5 top-2.5 size-4 text-muted-foreground\"\n        />\n        <Input\n          value={query}\n          onChange={(event) => setQuery(event.target.value)}\n          placeholder=\"Search Hugging Face models\"\n          className=\"h-9 pl-8 pr-8\"\n        />\n        {isLoading && (\n          <Spinner className=\"pointer-events-none absolute right-2.5 top-2.5 size-4 text-muted-foreground\" />\n        )}\n      </div>\n\n      <div ref={scrollRef} className=\"max-h-64 overflow-y-auto\">\n        <div className=\"p-1\">\n          {!cachedReady && !showHfSection ? 
(\n            <div className=\"flex items-center gap-2 px-5 py-3\">\n              <Spinner className=\"size-3 text-muted-foreground\" />\n              <span className=\"text-xs text-muted-foreground\">Loading models…</span>\n            </div>\n          ) : !showHfSection && (cachedGguf.length > 0 || (!chatOnly && cachedModels.length > 0)) ? (\n            <>\n              <ListLabel>{\"\\uD83E\\uDDA5\"} Downloaded</ListLabel>\n              {cachedGguf.map((c) => (\n                <div key={c.repo_id}>\n                  <ModelRow\n                    label={c.repo_id}\n                    meta={`GGUF · ${formatBytes(c.size_bytes)}`}\n                    selected={value === c.repo_id}\n                    onClick={() => handleModelClick(c.repo_id)}\n                    vramStatus={null}\n                  />\n                  {expandedGguf === c.repo_id && (\n                    <GgufVariantExpander\n                      repoId={c.repo_id}\n                      onSelect={onSelect}\n                      gpuGb={gpu.available ? gpu.memoryTotalGb : undefined}\n                      systemRamGb={gpu.available ? gpu.systemRamAvailableGb : undefined}\n                      onDeleteVariant={(quant) => setDeleteTarget(`${c.repo_id}::${quant}`)}\n                    />\n                  )}\n                </div>\n              ))}\n              {!chatOnly && cachedModels.map((c) => (\n                <div key={c.repo_id} className=\"flex items-center gap-0.5\">\n                  <div className=\"min-w-0 flex-1\">\n                    <ModelRow\n                      label={c.repo_id}\n                      meta={formatBytes(c.size_bytes)}\n                      selected={value === c.repo_id}\n                      onClick={() => onSelect(c.repo_id, { source: \"hub\", isLora: false, isDownloaded: true })}\n                      vramStatus={null}\n                    />\n                  </div>\n                  <button\n                    type=\"button\"\n                    onClick={(e) => { e.stopPropagation(); setDeleteTarget(c.repo_id); }}\n                    className=\"shrink-0 rounded-md p-1.5 text-muted-foreground/60 transition-colors hover:bg-destructive/10 hover:text-destructive\"\n                  >\n                    <Trash2Icon className=\"size-3.5\" />\n                  </button>\n                </div>\n              ))}\n            </>\n          ) : null}\n\n          {!showHfSection && cachedReady ? (\n            <>\n              <ListLabel>{\"\\uD83E\\uDDA5\"} Recommended</ListLabel>\n              {visibleRecommendedIds.length === 0 ? (\n                <div className=\"px-2.5 py-2 text-xs text-muted-foreground\">\n                  No default models.\n                </div>\n              ) : (\n                visibleRecommendedIds.map((id) => {\n                  const vram = recommendedVramMap.get(id);\n                  return (\n                    <div key={id}>\n                      <ModelRow\n                        label={id}\n                        meta={\n                          isGgufRepo(id)\n                            ? \"GGUF\"\n                            : vram?.detail ?? extractParamLabel(id)\n                        }\n                        selected={value === id}\n                        onClick={() => handleModelClick(id)}\n                        vramStatus={isGgufRepo(id) ? null : vram?.status ?? null}\n                        vramEst={isGgufRepo(id) ? undefined : vram?.est}\n                        gpuGb={gpu.available ? 
gpu.memoryTotalGb : undefined}\n                      />\n                      {expandedGguf === id && (\n                        <GgufVariantExpander repoId={id} onSelect={onSelect} gpuGb={gpu.available ? gpu.memoryTotalGb : undefined} systemRamGb={gpu.available ? gpu.systemRamAvailableGb : undefined} />\n                      )}\n                    </div>\n                  );\n                })\n              )}\n              {hasMoreRecommended && (\n                <>\n                  <div ref={recommendedSentinelRef} className=\"h-px\" />\n                  <div className=\"flex items-center justify-center py-2\">\n                    <Spinner className=\"size-3.5 text-muted-foreground\" />\n                  </div>\n                </>\n              )}\n            </>\n          ) : null}\n\n          {showHfSection ? (\n            <>\n              <ListLabel>Hugging Face</ListLabel>\n              {hfIds.length === 0 && !isLoading ? (\n                <div className=\"px-2.5 py-2 text-xs text-muted-foreground\">\n                  No matching models.\n                </div>\n              ) : (\n                hfIds.map((id) => {\n                  const vram = vramMap.get(id);\n                  return (\n                    <div key={id}>\n                      <ModelRow\n                        label={id}\n                        meta={\n                          isGgufRepo(id)\n                            ? \"GGUF\"\n                            : metricsById.get(id) ?? extractParamLabel(id)\n                        }\n                        selected={value === id}\n                        onClick={() => handleModelClick(id)}\n                        vramStatus={isGgufRepo(id) ? null : vram?.status ?? null}\n                        vramEst={isGgufRepo(id) ? undefined : vram?.est}\n                        gpuGb={gpu.available ? gpu.memoryTotalGb : undefined}\n                      />\n                      {expandedGguf === id && (\n                        <GgufVariantExpander repoId={id} onSelect={onSelect} gpuGb={gpu.available ? gpu.memoryTotalGb : undefined} systemRamGb={gpu.available ? gpu.systemRamAvailableGb : undefined} />\n                      )}\n                    </div>\n                  );\n                })\n              )}\n              <div ref={sentinelRef} className=\"h-px\" />\n              {isLoadingMore ? (\n                <div className=\"flex items-center justify-center py-2\">\n                  <Spinner className=\"size-3.5 text-muted-foreground\" />\n                </div>\n              ) : null}\n            </>\n          ) : null}\n        </div>\n      </div>\n\n      <AlertDialog open={deleteTarget !== null} onOpenChange={(open) => { if (!open && !deleting) setDeleteTarget(null); }}>\n        <AlertDialogContent size=\"sm\">\n          <AlertDialogHeader>\n            <AlertDialogTitle>Delete cached model?</AlertDialogTitle>\n            <AlertDialogDescription>\n              This will remove <span className=\"font-medium text-foreground\">{deleteTarget?.includes(\"::\") ? `${deleteTarget.split(\"::\")[0]} (${deleteTarget.split(\"::\")[1]})` : deleteTarget}</span> from disk. 
You can re-download it later.\n            </AlertDialogDescription>\n          </AlertDialogHeader>\n          <AlertDialogFooter>\n            <AlertDialogCancel disabled={deleting}>No</AlertDialogCancel>\n            <AlertDialogAction\n              variant=\"destructive\"\n              disabled={deleting}\n              onClick={(e) => { e.preventDefault(); handleDeleteConfirm(); }}\n            >\n              {deleting ? \"Deleting...\" : \"Yes\"}\n            </AlertDialogAction>\n          </AlertDialogFooter>\n        </AlertDialogContent>\n      </AlertDialog>\n    </div>\n  );\n}\n\nexport function LoraModelPicker({\n  loraModels,\n  value,\n  onSelect,\n}: {\n  loraModels: LoraModelOption[];\n  value?: string;\n  onSelect: (id: string, meta: ModelSelectorChangeMeta) => void;\n}) {\n  const [query, setQuery] = useState(\"\");\n\n  const normalized = useMemo(\n    () =>\n      loraModels\n        .map((model) => ({\n          ...model,\n          baseModel: model.baseModel || model.description || \"Unknown base model\",\n        }))\n        .sort((a, b) => {\n          const aTime = a.updatedAt ?? -1;\n          const bTime = b.updatedAt ?? -1;\n          if (aTime !== bTime) return bTime - aTime;\n          const baseCmp = a.baseModel.localeCompare(b.baseModel);\n          if (baseCmp !== 0) return baseCmp;\n          return a.name.localeCompare(b.name);\n        }),\n    [loraModels],\n  );\n\n  const grouped = useMemo(() => {\n    const needle = query.trim().toLowerCase();\n    const out = new Map<string, LoraModelOption[]>();\n\n    for (const model of normalized) {\n      const searchText = `${model.name} ${model.baseModel} ${model.id}`.toLowerCase();\n      if (needle && !searchText.includes(needle)) continue;\n\n      const key = model.baseModel || \"Unknown base model\";\n      const prev = out.get(key) ?? [];\n      prev.push(model);\n      out.set(key, prev);\n    }\n\n    return [...out.entries()].sort((a, b) => {\n      const aLatest = Math.max(...a[1].map((model) => model.updatedAt ?? -1));\n      const bLatest = Math.max(...b[1].map((model) => model.updatedAt ?? -1));\n      if (aLatest !== bLatest) return bLatest - aLatest;\n      return a[0].localeCompare(b[0]);\n    });\n  }, [normalized, query]);\n\n  return (\n    <div className=\"space-y-2\">\n      <div className=\"relative\">\n        <HugeiconsIcon\n          icon={Search01Icon}\n          className=\"pointer-events-none absolute left-2.5 top-2.5 size-4 text-muted-foreground\"\n        />\n        <Input\n          value={query}\n          onChange={(event) => setQuery(event.target.value)}\n          placeholder=\"Search local adapters\"\n          className=\"h-9 pl-8\"\n        />\n      </div>\n\n      <div className=\"max-h-64 overflow-y-auto\">\n        <div className=\"p-1\">\n          {grouped.length === 0 ? (\n            <div className=\"px-2.5 py-2 text-xs text-muted-foreground\">\n              No adapters found.\n            </div>\n          ) : (\n            grouped.map(([baseModel, adapters], index) => (\n              <div key={baseModel}>\n                {index > 0 ? <div className=\"my-1\" /> : null}\n                <ListLabel>{baseModel}</ListLabel>\n                {adapters.map((adapter) => {\n                  const isExported = adapter.source === \"exported\";\n                  const isMerged = adapter.exportType === \"merged\";\n                  const isGguf = adapter.exportType === \"gguf\";\n                  const tag = isGguf\n                    ? 
\"GGUF\"\n                    : isExported\n                      ? isMerged ? \"Merged\" : \"LoRA\"\n                      : \"LoRA\";\n                  const meta = isExported ? `${tag} · Exported` : tag;\n                  return (\n                    <ModelRow\n                      key={adapter.id}\n                      label={adapter.name}\n                      meta={meta}\n                      selected={value === adapter.id}\n                      onClick={() => onSelect(adapter.id, {\n                        source: isExported ? \"exported\" : \"lora\",\n                        isLora: !isMerged && !isGguf,\n                      })}\n                      tooltipText={\n                        <>\n                          <span className=\"block break-words\">{adapter.name}</span>\n                          <span className=\"block mt-1 text-[10px] text-muted-foreground break-all\">\n                            {adapter.id}\n                          </span>\n                        </>\n                      }\n                    />\n                  );\n                })}\n              </div>\n            ))\n          )}\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/model-selector/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactNode } from \"react\";\n\nexport interface ModelOption {\n  id: string;\n  name: string;\n  description?: string;\n  icon?: ReactNode;\n}\n\nexport interface LoraModelOption extends ModelOption {\n  baseModel?: string;\n  updatedAt?: number;\n  source?: \"training\" | \"exported\";\n  exportType?: \"lora\" | \"merged\" | \"gguf\";\n}\n\nexport interface ModelSelectorChangeMeta {\n  source: \"hub\" | \"lora\" | \"exported\";\n  isLora: boolean;\n  ggufVariant?: string;\n  isDownloaded?: boolean;\n  expectedBytes?: number;\n}\n\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/model-selector.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport {\n  Popover,\n  PopoverContent,\n  PopoverTrigger,\n} from \"@/components/ui/popover\";\nimport { Tabs, TabsContent, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport { usePlatformStore } from \"@/config/env\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  ArrowDown01Icon,\n  Logout01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useMemo, useState } from \"react\";\nimport type {\n  LoraModelOption,\n  ModelOption,\n  ModelSelectorChangeMeta,\n} from \"./model-selector/types\";\nimport { HubModelPicker, LoraModelPicker } from \"./model-selector/pickers\";\n\nexport type { LoraModelOption, ModelOption, ModelSelectorChangeMeta } from \"./model-selector/types\";\n\ninterface ModelSelectorProps {\n  models: ModelOption[];\n  loraModels?: LoraModelOption[];\n  value?: string;\n  defaultValue?: string;\n  activeGgufVariant?: string | null;\n  onValueChange?: (value: string, meta: ModelSelectorChangeMeta) => void;\n  onEject?: () => void;\n  variant?: \"outline\" | \"ghost\" | \"muted\";\n  size?: \"sm\" | \"default\" | \"lg\";\n  className?: string;\n  contentClassName?: string;\n  open?: boolean;\n  onOpenChange?: (open: boolean) => void;\n  triggerDataTour?: string;\n  contentDataTour?: string;\n}\n\nfunction ModelSelectorTrigger({\n  currentModel,\n  isLoaded,\n  variant = \"outline\",\n  size = \"default\",\n  className,\n  dataTour,\n}: {\n  currentModel?: ModelOption;\n  isLoaded: boolean;\n  variant?: \"outline\" | \"ghost\" | \"muted\";\n  size?: \"sm\" | \"default\" | \"lg\";\n  className?: string;\n  dataTour?: string;\n}) {\n  return (\n    <PopoverTrigger asChild={true}>\n      <button\n        type=\"button\"\n        data-tour={dataTour}\n        className={cn(\n          \"flex items-center gap-2 transition-colors\",\n          variant === \"outline\" &&\n          \"rounded-full border border-border/60 hover:bg-accent\",\n          variant === \"ghost\" && \"rounded-md hover:bg-accent\",\n          variant === \"muted\" && \"rounded-md bg-muted hover:bg-muted/80\",\n          size === \"sm\" && \"h-8 px-3 text-xs\",\n          size === \"default\" && \"h-9 px-3.5 text-sm\",\n          size === \"lg\" && \"h-10 px-4 text-sm\",\n          className,\n        )}\n      >\n        {isLoaded && (\n          <span className=\"size-2 shrink-0 rounded-full bg-emerald-500\" />\n        )}\n        <span className={isLoaded ? \"text-foreground\" : \"text-muted-foreground\"}>\n          {currentModel?.name ?? 
\"Select model...\"}\n        </span>\n        {currentModel?.description && (\n          <span className=\"text-muted-foreground text-xs\">{currentModel.description}</span>\n        )}\n        <HugeiconsIcon\n          icon={ArrowDown01Icon}\n          className=\"size-3 shrink-0 text-muted-foreground\"\n        />\n      </button>\n    </PopoverTrigger>\n  );\n}\n\nfunction ModelSelectorContent({\n  models,\n  loraModels,\n  value,\n  onSelect,\n  onEject,\n  className,\n  dataTour,\n}: {\n  models: ModelOption[];\n  loraModels: LoraModelOption[];\n  value?: string;\n  onSelect: (id: string, meta: ModelSelectorChangeMeta) => void;\n  onEject?: () => void;\n  className?: string;\n  dataTour?: string;\n}) {\n  const hasSelection = Boolean(value);\n  const chatOnly = usePlatformStore((s) => s.isChatOnly());\n\n  return (\n    <PopoverContent\n      align=\"start\"\n      data-tour={dataTour}\n      className={cn(\n        \"w-[min(440px,calc(100vw-1rem))] max-w-[calc(100vw-1rem)] min-w-0 gap-0 p-2\",\n        className,\n      )}\n    >\n      {chatOnly ? (\n        <HubModelPicker models={models} value={value} onSelect={onSelect} />\n      ) : (\n        <Tabs defaultValue=\"hub\" className=\"w-full\">\n          <TabsList className=\"mb-2 w-full\">\n            <TabsTrigger value=\"hub\">Hub models</TabsTrigger>\n            <TabsTrigger value=\"lora\">Fine-tuned</TabsTrigger>\n          </TabsList>\n\n          <TabsContent value=\"hub\" className=\"m-0\">\n            <HubModelPicker models={models} value={value} onSelect={onSelect} />\n          </TabsContent>\n\n          <TabsContent value=\"lora\" className=\"m-0\">\n            <LoraModelPicker\n              loraModels={loraModels}\n              value={value}\n              onSelect={onSelect}\n            />\n          </TabsContent>\n        </Tabs>\n      )}\n\n      {hasSelection && onEject ? (\n        <div className=\"mt-2 border-t border-border/70 pt-2\">\n          <button\n            type=\"button\"\n            onClick={onEject}\n            className=\"flex w-full items-center justify-center gap-1.5 rounded-md px-2 py-1.5 text-xs text-destructive transition-colors hover:bg-destructive/10\"\n            title=\"Eject model\"\n          >\n            <HugeiconsIcon icon={Logout01Icon} className=\"size-3.5\" />\n            Eject loaded model\n          </button>\n        </div>\n      ) : null}\n    </PopoverContent>\n  );\n}\n\nexport function ModelSelector({\n  models,\n  loraModels = [],\n  value,\n  defaultValue,\n  activeGgufVariant,\n  onValueChange,\n  onEject,\n  variant = \"outline\",\n  size = \"default\",\n  className,\n  contentClassName,\n  open: controlledOpen,\n  onOpenChange,\n  triggerDataTour,\n  contentDataTour,\n}: ModelSelectorProps) {\n  const [uncontrolledOpen, setUncontrolledOpen] = useState(false);\n  const open = controlledOpen ?? uncontrolledOpen;\n  const setOpen = onOpenChange ?? setUncontrolledOpen;\n  const [uncontrolled, setUncontrolled] = useState(defaultValue ?? \"\");\n\n  const selected = value ?? uncontrolled;\n  const isLoaded = selected !== \"\";\n\n  const optionById = useMemo(() => {\n    const all = new Map<string, ModelOption>();\n    for (const model of models) {\n      all.set(model.id, model);\n    }\n    for (const lora of loraModels) {\n      // Strip \"/ suffix\" from display name (e.g. \"foo_123/foo\" → \"foo_123\")\n      const displayName = lora.name.includes(\"/\")\n        ? 
lora.name.split(\"/\")[0].trim()\n        : lora.name;\n      // Show type tag instead of base model name\n      const isExported = lora.source === \"exported\";\n      const isMerged = lora.exportType === \"merged\";\n      const tag = isExported\n        ? isMerged ? \"Merged · Exported\" : \"LoRA\"\n        : \"LoRA\";\n      all.set(lora.id, {\n        ...lora,\n        name: displayName,\n        description: tag,\n      });\n    }\n    return all;\n  }, [loraModels, models]);\n\n  const currentModel = useMemo(() => {\n    if (!selected) return undefined;\n    const found = optionById.get(selected);\n    if (activeGgufVariant) {\n      const desc = `GGUF · ${activeGgufVariant}`;\n      return found ? { ...found, description: desc } : { id: selected, name: selected, description: desc };\n    }\n    return found ?? { id: selected, name: selected };\n  }, [selected, optionById, activeGgufVariant]);\n\n  function handleSelect(id: string, meta: ModelSelectorChangeMeta) {\n    if (onValueChange) {\n      onValueChange(id, meta);\n    } else {\n      setUncontrolled(id);\n    }\n    setOpen(false);\n  }\n\n  function handleEject() {\n    onEject?.();\n    setOpen(false);\n  }\n\n  return (\n    <Popover open={open} onOpenChange={setOpen}>\n      <ModelSelectorTrigger\n        currentModel={currentModel}\n        isLoaded={isLoaded}\n        variant={variant}\n        size={size}\n        className={className}\n        dataTour={triggerDataTour}\n      />\n      <ModelSelectorContent\n        models={models}\n        loraModels={loraModels}\n        value={selected}\n        onSelect={handleSelect}\n        onEject={onEject ? handleEject : undefined}\n        className={contentClassName}\n        dataTour={contentDataTour}\n      />\n    </Popover>\n  );\n}\n\nModelSelector.Trigger = ModelSelectorTrigger;\nModelSelector.Content = ModelSelectorContent;\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/reasoning.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\n/* eslint-disable react-refresh/only-export-components */\n\nimport { MarkdownText } from \"@/components/assistant-ui/markdown-text\";\nimport { AnimatedShinyText } from \"@/components/ui/animated-shiny-text\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  type ReasoningGroupComponent,\n  type ReasoningMessagePartComponent,\n  useAuiState,\n  useScrollLock,\n} from \"@assistant-ui/react\";\nimport { copyToClipboard } from \"@/lib/copy-to-clipboard\";\nimport { Idea01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type VariantProps, cva } from \"class-variance-authority\";\nimport { ChevronDownIcon, CopyIcon, CheckIcon } from \"lucide-react\";\nimport {\n  type CSSProperties,\n  type ComponentProps,\n  memo,\n  useCallback,\n  useEffect,\n  useRef,\n  useState,\n} from \"react\";\nconst ANIMATION_DURATION = 200;\n\nexport const reasoningVariants = cva(\"aui-reasoning-root mb-4 w-full\", {\n  variants: {\n    variant: {\n      outline: \"rounded-lg border px-3 py-2\",\n      ghost: \"\",\n      muted: \"rounded-lg bg-muted/50 px-3 py-2\",\n    },\n  },\n  defaultVariants: {\n    variant: \"outline\",\n  },\n});\n\nexport type ReasoningRootProps = Omit<\n  ComponentProps<typeof Collapsible>,\n  \"open\" | \"onOpenChange\"\n> &\n  VariantProps<typeof reasoningVariants> & {\n    open?: boolean;\n    onOpenChange?: (open: boolean) => void;\n    defaultOpen?: boolean;\n  };\n\nfunction ReasoningRoot({\n  className,\n  variant,\n  open: controlledOpen,\n  onOpenChange: controlledOnOpenChange,\n  defaultOpen = false,\n  children,\n  ...props\n}: ReasoningRootProps) {\n  const collapsibleRef = useRef<HTMLDivElement>(null);\n  const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen);\n  const lockScroll = useScrollLock(collapsibleRef, ANIMATION_DURATION);\n\n  const isControlled = controlledOpen !== undefined;\n  const isOpen = isControlled ? 
controlledOpen : uncontrolledOpen;\n\n  const handleOpenChange = useCallback(\n    (open: boolean) => {\n      if (!open) {\n        lockScroll();\n      }\n      if (!isControlled) {\n        setUncontrolledOpen(open);\n      }\n      controlledOnOpenChange?.(open);\n    },\n    [lockScroll, isControlled, controlledOnOpenChange],\n  );\n\n  return (\n    <Collapsible\n      ref={collapsibleRef}\n      data-slot=\"reasoning-root\"\n      data-variant={variant}\n      open={isOpen}\n      onOpenChange={handleOpenChange}\n      className={cn(\n        \"group/reasoning-root\",\n        reasoningVariants({ variant, className }),\n      )}\n      style={\n        {\n          \"--animation-duration\": `${ANIMATION_DURATION}ms`,\n        } as CSSProperties\n      }\n      {...props}\n    >\n      {children}\n    </Collapsible>\n  );\n}\n\nfunction ReasoningFade({ className, ...props }: ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"reasoning-fade\"\n      className={cn(\n        \"aui-reasoning-fade pointer-events-none absolute inset-x-0 bottom-0 z-10 h-8\",\n        \"bg-gradient-to-t from-background to-transparent\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nfunction ReasoningFadeTop({ className, ...props }: ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"reasoning-fade-top\"\n      className={cn(\n        \"aui-reasoning-fade-top pointer-events-none absolute inset-x-0 top-0 z-10 h-8\",\n        \"bg-gradient-to-b from-background to-transparent\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nfunction ReasoningTrigger({\n  active,\n  duration,\n  className,\n  ...props\n}: ComponentProps<typeof CollapsibleTrigger> & {\n  active?: boolean;\n  duration?: number;\n}) {\n  return (\n    <CollapsibleTrigger\n      data-slot=\"reasoning-trigger\"\n      className={cn(\n        \"aui-reasoning-trigger group/trigger flex max-w-[75%] items-center gap-2 py-1 text-muted-foreground text-sm transition-colors hover:text-foreground\",\n        className,\n      )}\n      {...props}\n    >\n      <HugeiconsIcon\n        icon={Idea01Icon}\n        className=\"aui-reasoning-trigger-icon size-4 shrink-0\"\n      />\n      <span\n        data-slot=\"reasoning-trigger-label\"\n        className=\"aui-reasoning-trigger-label-wrapper relative inline-block leading-none\"\n      >\n        {active ? (\n          <AnimatedShinyText className=\"text-sm\">Thinking...</AnimatedShinyText>\n        ) : (\n          <span>Thought for {duration ?? 
0} seconds</span>\n        )}\n      </span>\n      <ChevronDownIcon\n        data-slot=\"reasoning-trigger-chevron\"\n        className={cn(\n          \"aui-reasoning-trigger-chevron mt-0.5 size-4 shrink-0\",\n          \"transition-transform duration-(--animation-duration) ease-out\",\n          \"group-data-[state=closed]/trigger:-rotate-90\",\n          \"group-data-[state=open]/trigger:rotate-0\",\n        )}\n      />\n    </CollapsibleTrigger>\n  );\n}\n\nfunction ReasoningContent({\n  className,\n  children,\n  streaming,\n  ...props\n}: ComponentProps<typeof CollapsibleContent> & { streaming?: boolean }) {\n  return (\n    <CollapsibleContent\n      data-slot=\"reasoning-content\"\n      className={cn(\n        \"aui-reasoning-content relative overflow-hidden text-muted-foreground text-sm outline-none\",\n        \"group/collapsible-content ease-out\",\n        \"data-[state=closed]:animate-collapsible-up\",\n        \"data-[state=open]:animate-collapsible-down\",\n        \"data-[state=closed]:fill-mode-forwards\",\n        \"data-[state=closed]:pointer-events-none\",\n        \"data-[state=open]:duration-(--animation-duration)\",\n        \"data-[state=closed]:duration-(--animation-duration)\",\n        className,\n      )}\n      {...props}\n    >\n      {streaming && <ReasoningFadeTop />}\n      {children}\n      <ReasoningFade />\n    </CollapsibleContent>\n  );\n}\n\nfunction ReasoningText({\n  className,\n  streaming,\n  children,\n  ...props\n}: ComponentProps<\"div\"> & { streaming?: boolean }) {\n  const scrollRef = useRef<HTMLDivElement>(null);\n\n  useEffect(() => {\n    if (!(streaming && scrollRef.current)) {\n      return;\n    }\n    const el = scrollRef.current;\n    const observer = new MutationObserver(() => {\n      el.scrollTop = el.scrollHeight;\n    });\n    observer.observe(el, {\n      childList: true,\n      subtree: true,\n      characterData: true,\n    });\n    el.scrollTop = el.scrollHeight;\n    return () => observer.disconnect();\n  }, [streaming]);\n\n  return (\n    <div\n      ref={scrollRef}\n      data-slot=\"reasoning-text\"\n      className={cn(\n        \"aui-reasoning-text relative z-0 overflow-y-auto pt-2 pb-2 pl-0 leading-relaxed\",\n        streaming ? 
\"max-h-32\" : \"max-h-64\",\n        \"transform-gpu transition-[transform,opacity]\",\n        \"group-data-[state=open]/collapsible-content:animate-in\",\n        \"group-data-[state=closed]/collapsible-content:animate-out\",\n        \"group-data-[state=open]/collapsible-content:fade-in-0\",\n        \"group-data-[state=closed]/collapsible-content:fade-out-0\",\n        \"group-data-[state=open]/collapsible-content:slide-in-from-top-4\",\n        \"group-data-[state=closed]/collapsible-content:slide-out-to-top-4\",\n        \"group-data-[state=open]/collapsible-content:duration-(--animation-duration)\",\n        \"group-data-[state=closed]/collapsible-content:duration-(--animation-duration)\",\n        className,\n      )}\n      {...props}\n    >\n      {children}\n    </div>\n  );\n}\n\nconst ReasoningImpl: ReasoningMessagePartComponent = () => <MarkdownText />;\n\nconst COPY_RESET_MS = 2000;\n\nfunction ReasoningCopyButton({ startIndex, endIndex }: { startIndex: number; endIndex: number }) {\n  const [copied, setCopied] = useState(false);\n  const resetRef = useRef<ReturnType<typeof setTimeout> | null>(null);\n\n  const reasoningText = useAuiState(({ message }) => {\n    return message.parts\n      .slice(startIndex, endIndex + 1)\n      .filter((p) => p.type === \"reasoning\")\n      .map((p) => (\"text\" in p ? (p as { text: string }).text : \"\"))\n      .join(\"\\n\");\n  });\n\n  const handleCopy = useCallback(() => {\n    if (copyToClipboard(reasoningText)) {\n      setCopied(true);\n      if (resetRef.current) clearTimeout(resetRef.current);\n      resetRef.current = setTimeout(() => setCopied(false), COPY_RESET_MS);\n    }\n  }, [reasoningText]);\n\n  return (\n    <button\n      type=\"button\"\n      onClick={handleCopy}\n      className=\"inline-flex items-center gap-1 rounded px-1.5 py-0.5 text-xs text-muted-foreground transition-colors hover:text-foreground hover:bg-muted\"\n      aria-label=\"Copy reasoning\"\n    >\n      {copied ? (\n        <CheckIcon className=\"size-3\" />\n      ) : (\n        <CopyIcon className=\"size-3\" />\n      )}\n      {copied ? \"Copied\" : \"Copy\"}\n    </button>\n  );\n}\n\nconst ReasoningGroupImpl: ReasoningGroupComponent = ({\n  children,\n  startIndex,\n  endIndex,\n}) => {\n  const isReasoningStreaming = useAuiState(({ message }) => {\n    if (message.status?.type !== \"running\") {\n      return false;\n    }\n    const lastIndex = message.parts.length - 1;\n    if (lastIndex < 0) {\n      return false;\n    }\n    const lastType = message.parts[lastIndex]?.type;\n    if (lastType !== \"reasoning\") {\n      return false;\n    }\n    return lastIndex >= startIndex && lastIndex <= endIndex;\n  });\n\n  const persistedDuration = useAuiState(({ message }) => {\n    const d = (message.metadata?.custom as Record<string, unknown>)\n      ?.reasoningDuration;\n    return typeof d === \"number\" ? 
d : 0;\n  });\n\n  const [manualOpen, setManualOpen] = useState(false);\n  const [duration, setDuration] = useState<number>(0);\n  const startTimeRef = useRef<number | null>(null);\n\n  useEffect(() => {\n    if (isReasoningStreaming) {\n      if (startTimeRef.current === null) {\n        startTimeRef.current = Date.now();\n      }\n    } else if (startTimeRef.current !== null) {\n      const elapsed = Math.round((Date.now() - startTimeRef.current) / 1000);\n      setDuration(elapsed);\n      startTimeRef.current = null;\n    }\n  }, [isReasoningStreaming]);\n\n  const isOpen = isReasoningStreaming || manualOpen;\n\n  const variant = isReasoningStreaming\n    ? \"outline\"\n    : manualOpen\n      ? \"outline\"\n      : \"ghost\";\n\n  const handleOpenChange = useCallback(\n    (open: boolean) => {\n      if (!isReasoningStreaming) {\n        setManualOpen(open);\n      }\n    },\n    [isReasoningStreaming],\n  );\n\n  return (\n    <ReasoningRoot\n      open={isOpen}\n      onOpenChange={handleOpenChange}\n      variant={variant}\n    >\n      <div className=\"flex items-center justify-between\">\n        <ReasoningTrigger\n          active={isReasoningStreaming}\n          duration={duration || persistedDuration}\n        />\n        {isOpen && !isReasoningStreaming && (\n          <ReasoningCopyButton startIndex={startIndex} endIndex={endIndex} />\n        )}\n      </div>\n      <ReasoningContent\n        aria-busy={isReasoningStreaming}\n        streaming={isReasoningStreaming}\n      >\n        <ReasoningText streaming={isReasoningStreaming}>\n          {children}\n        </ReasoningText>\n      </ReasoningContent>\n    </ReasoningRoot>\n  );\n};\n\nconst Reasoning = memo(\n  ReasoningImpl,\n) as unknown as ReasoningMessagePartComponent & {\n  Root: typeof ReasoningRoot;\n  Trigger: typeof ReasoningTrigger;\n  Content: typeof ReasoningContent;\n  Text: typeof ReasoningText;\n  Fade: typeof ReasoningFade;\n  FadeTop: typeof ReasoningFadeTop;\n};\n\nReasoning.displayName = \"Reasoning\";\nReasoning.Root = ReasoningRoot;\nReasoning.Trigger = ReasoningTrigger;\nReasoning.Content = ReasoningContent;\nReasoning.Text = ReasoningText;\nReasoning.Fade = ReasoningFade;\nReasoning.FadeTop = ReasoningFadeTop;\n\nconst ReasoningGroup = memo(ReasoningGroupImpl);\nReasoningGroup.displayName = \"ReasoningGroup\";\n\nexport {\n  Reasoning,\n  ReasoningGroup,\n  ReasoningRoot,\n  ReasoningTrigger,\n  ReasoningContent,\n  ReasoningText,\n  ReasoningFade,\n  ReasoningFadeTop,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/sources.tsx",
    "content": "\"use client\";\n\nimport { memo, useState, type ComponentProps } from \"react\";\nimport type { SourceMessagePartComponent } from \"@assistant-ui/react\";\nimport { cn } from \"@/lib/utils\";\nimport { Badge, badgeVariants, type BadgeProps } from \"./badge\";\n\nconst extractDomain = (url: string): string => {\n  try {\n    return new URL(url).hostname.replace(/^www\\./, \"\");\n  } catch {\n    return url;\n  }\n};\n\nconst getDomainInitial = (url: string): string => {\n  const domain = extractDomain(url);\n  return domain.charAt(0).toUpperCase();\n};\n\nfunction SourceIcon({\n  url,\n  className,\n  ...props\n}: ComponentProps<\"span\"> & { url: string }) {\n  const [hasError, setHasError] = useState(false);\n  const domain = extractDomain(url);\n\n  if (hasError) {\n    return (\n      <span\n        data-slot=\"source-icon-fallback\"\n        className={cn(\n          \"flex size-3 shrink-0 items-center justify-center rounded-sm bg-muted font-medium text-[10px]\",\n          className,\n        )}\n        {...props}\n      >\n        {getDomainInitial(url)}\n      </span>\n    );\n  }\n\n  return (\n    <img\n      data-slot=\"source-icon\"\n      src={`https://www.google.com/s2/favicons?domain=${domain}&sz=32`}\n      alt=\"\"\n      className={cn(\"size-3 shrink-0 rounded-sm\", className)}\n      onError={() => setHasError(true)}\n      {...(props as ComponentProps<\"img\">)}\n    />\n  );\n}\n\nfunction SourceTitle({ className, ...props }: ComponentProps<\"span\">) {\n  return (\n    <span\n      data-slot=\"source-title\"\n      className={cn(\"max-w-37.5 truncate\", className)}\n      {...props}\n    />\n  );\n}\n\nexport type SourceProps = Omit<BadgeProps, \"asChild\"> &\n  ComponentProps<\"a\"> & {\n    asChild?: boolean;\n  };\n\nfunction Source({\n  className,\n  variant,\n  size,\n  asChild = false,\n  target = \"_blank\",\n  rel = \"noopener noreferrer\",\n  ...props\n}: SourceProps) {\n  return (\n    <Badge\n      asChild\n      variant={variant}\n      size={size}\n      className={cn(\n        \"cursor-pointer outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50\",\n        className,\n      )}\n    >\n      <a\n        data-slot=\"source\"\n        target={target}\n        rel={rel}\n        {...(props as ComponentProps<\"a\">)}\n      />\n    </Badge>\n  );\n}\n\nconst SourcesImpl: SourceMessagePartComponent = ({\n  url,\n  title,\n  sourceType,\n}) => {\n  if (sourceType !== \"url\" || !url) return null;\n\n  const domain = extractDomain(url);\n  const displayTitle = title || domain;\n\n  return (\n    <span className=\"mr-1 mt-1 inline-block first:mt-2\">\n      <Source href={url}>\n        <SourceIcon url={url} />\n        <SourceTitle>{displayTitle}</SourceTitle>\n      </Source>\n    </span>\n  );\n};\n\nconst Sources = memo(SourcesImpl) as unknown as SourceMessagePartComponent & {\n  Root: typeof Source;\n  Icon: typeof SourceIcon;\n  Title: typeof SourceTitle;\n};\n\nSources.displayName = \"Sources\";\nSources.Root = Source;\nSources.Icon = SourceIcon;\nSources.Title = SourceTitle;\n\nexport {\n  Sources,\n  Source,\n  SourceIcon,\n  SourceTitle,\n  badgeVariants as sourceVariants,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/thread.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  ComposerAddAttachment,\n  ComposerAttachments,\n  UserMessageAttachments,\n} from \"@/components/assistant-ui/attachment\";\nimport { MessageTiming } from \"@/components/assistant-ui/message-timing\";\nimport { MarkdownText } from \"@/components/assistant-ui/markdown-text\";\nimport { Reasoning, ReasoningGroup } from \"@/components/assistant-ui/reasoning\";\nimport { Sources } from \"@/components/assistant-ui/sources\";\nimport { ToolFallback } from \"@/components/assistant-ui/tool-fallback\";\nimport { ToolGroup } from \"@/components/assistant-ui/tool-group\";\nimport { WebSearchToolUI } from \"@/components/assistant-ui/tool-ui-web-search\";\nimport { PythonToolUI } from \"@/components/assistant-ui/tool-ui-python\";\nimport { TerminalToolUI } from \"@/components/assistant-ui/tool-ui-terminal\";\nimport { TooltipIconButton } from \"@/components/assistant-ui/tooltip-icon-button\";\nimport { Button } from \"@/components/ui/button\";\nimport { sentAudioNames } from \"@/features/chat/api/chat-adapter\";\nimport { AUDIO_ACCEPT, MAX_AUDIO_SIZE, fileToBase64 } from \"@/lib/audio-utils\";\nimport { copyToClipboard } from \"@/lib/copy-to-clipboard\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  ActionBarMorePrimitive,\n  ActionBarPrimitive,\n  AuiIf,\n  BranchPickerPrimitive,\n  ComposerPrimitive,\n  ErrorPrimitive,\n  MessagePrimitive,\n  SuggestionPrimitive,\n  ThreadPrimitive,\n  useAui,\n  useAuiEvent,\n  useAuiState,\n} from \"@assistant-ui/react\";\nimport { motion } from \"framer-motion\";\nimport {\n  ArrowDownIcon,\n  ArrowUpIcon,\n  CheckIcon,\n  ChevronLeftIcon,\n  ChevronRightIcon,\n  CopyIcon,\n  DownloadIcon,\n  GlobeIcon,\n  HeadphonesIcon,\n  LightbulbIcon,\n  LightbulbOffIcon,\n  MicIcon,\n  MoreHorizontalIcon,\n  LoaderIcon,\n  PencilIcon,\n  RefreshCwIcon,\n  SquareIcon,\n  TerminalIcon,\n  XIcon,\n} from \"lucide-react\";\nimport { type FC, useCallback, useEffect, useRef, useState } from \"react\";\nimport { useChatRuntimeStore } from \"@/features/chat/stores/chat-runtime-store\";\n\nexport const Thread: FC<{ hideComposer?: boolean; hideWelcome?: boolean }> = ({\n  hideComposer,\n  hideWelcome,\n}) => {\n  return (\n    <ThreadPrimitive.Root\n      className=\"aui-root aui-thread-root @container flex h-full flex-col \"\n      style={{\n        [\"--thread-max-width\" as string]: \"44rem\",\n      }}\n    >\n      <ThreadPrimitive.Viewport\n        className=\"aui-thread-viewport relative flex flex-1 flex-col overflow-x-auto overflow-y-scroll scroll-smooth px-4 pt-4\"\n      >\n        {!hideWelcome && (\n          <AuiIf condition={({ thread }) => thread.isEmpty}>\n            <ThreadWelcome hideComposer={hideComposer} />\n          </AuiIf>\n        )}\n\n        <ThreadPrimitive.Messages\n          components={{\n            UserMessage,\n            EditComposer,\n            AssistantMessage,\n          }}\n        />\n\n        <ThreadPrimitive.ViewportFooter className=\"aui-thread-viewport-footer sticky bottom-0 mt-auto flex w-full flex-col gap-4 overflow-visible bg-background pb-4 md:pb-4\">\n          <ThreadScrollToBottom />\n          <GeneratingSpinner />\n          <AuiIf condition={({ thread }) => !thread.isEmpty}>\n            {!hideComposer && <ComposerAnimated />}\n          </AuiIf>\n        </ThreadPrimitive.ViewportFooter>\n      </ThreadPrimitive.Viewport>\n    
</ThreadPrimitive.Root>\n  );\n};\n\nconst ThreadScrollToBottom: FC = () => {\n  return (\n    <ThreadPrimitive.ScrollToBottom asChild={true}>\n      <TooltipIconButton\n        tooltip=\"Scroll to bottom\"\n        variant=\"outline\"\n        className=\"aui-thread-scroll-to-bottom absolute -top-12 z-10 self-center rounded-full p-4 disabled:invisible dark:bg-background dark:hover:bg-accent\"\n      >\n        <ArrowDownIcon />\n      </TooltipIconButton>\n    </ThreadPrimitive.ScrollToBottom>\n  );\n};\n\nconst SuggestionItem: FC = () => {\n  const aui = useAui();\n  const prompt = useAuiState(({ suggestion }) => suggestion.prompt);\n  const isDisabled = useAuiState(({ thread }) => thread.isDisabled);\n  const isRunning = useAuiState(({ thread }) => thread.isRunning);\n\n  return (\n    <button\n      type=\"button\"\n      onClick={() => {\n        if (!isDisabled && !isRunning) {\n          aui.thread().append(prompt);\n          aui.composer().setText(\"\");\n          return;\n        }\n        aui.composer().setText(prompt);\n      }}\n      className=\"fade-in slide-in-from-bottom-1 animate-in cursor-pointer corner-squircle rounded-xl border bg-background px-4 py-2.5 text-left text-sm text-foreground shadow-sm transition-colors duration-150 hover:bg-accent\"\n    >\n      <SuggestionPrimitive.Title />\n    </button>\n  );\n};\n\nconst ThreadWelcome: FC<{ hideComposer?: boolean }> = ({ hideComposer }) => {\n  return (\n    <div className=\"aui-thread-welcome-root mx-auto my-auto flex w-full max-w-(--thread-max-width) grow flex-col\">\n      <div className=\"aui-thread-welcome-center flex w-full grow flex-col items-center justify-center\">\n        <div className=\"aui-thread-welcome-message flex w-full flex-col justify-center gap-6 px-4\">\n          <div className=\"flex flex-col items-center gap-2 text-center\">\n            <img\n              src=\"/Sloth emojis/sloth pc square.png\"\n              alt=\"Sloth mascot\"\n              className=\"size-20\"\n            />\n            <h1 className=\"aui-thread-welcome-message-inner fade-in slide-in-from-bottom-1 animate-in font-semibold text-2xl duration-200\">\n              Chat with your model\n            </h1>\n            <p className=\"aui-thread-welcome-message-inner fade-in slide-in-from-bottom-1 animate-in text-muted-foreground text-base delay-75 duration-200\">\n              Run GGUFs, safetensors, vision and audio models!\n            </p>\n          </div>\n          <div className=\"grid grid-cols-2 gap-2\">\n            <ThreadPrimitive.Suggestions\n              components={{ Suggestion: SuggestionItem }}\n            />\n          </div>\n          <GeneratingSpinner />\n          {!hideComposer && <ComposerAnimated />}\n        </div>\n      </div>\n    </div>\n  );\n};\n\nconst GeneratingSpinner: FC = () => {\n  const status = useChatRuntimeStore((s) => s.generatingStatus);\n  if (!status) return null;\n  return (\n    <div className=\"mx-auto flex w-full max-w-(--thread-max-width) items-center justify-center py-2\">\n      <div className=\"flex items-center gap-2 text-xs text-muted-foreground\">\n        <LoaderIcon className=\"size-3.5 animate-spin\" />\n        <span>Generating</span>\n      </div>\n    </div>\n  );\n};\n\nconst ComposerAnimated: FC = () => {\n  return (\n    <motion.div\n      layout={true}\n      layoutId=\"composer\"\n      transition={{ type: \"spring\", bounce: 0.15, duration: 0.5 }}\n      className=\"mx-auto w-full max-w-(--thread-max-width)\"\n    >\n      <Composer />\n    
</motion.div>\n  );\n};\n\nconst PendingAudioChip: FC = () => {\n  const audioName = useChatRuntimeStore((s) => s.pendingAudioName);\n  const clearPendingAudio = useChatRuntimeStore((s) => s.clearPendingAudio);\n  if (!audioName) return null;\n  return (\n    <div className=\"mb-2 flex w-full flex-row items-center gap-2 px-1.5 pt-0.5 pb-1\">\n      <div className=\"flex items-center gap-2 rounded-lg border border-foreground/20 bg-muted px-3 py-1.5 text-xs\">\n        <HeadphonesIcon className=\"size-3.5 text-muted-foreground\" />\n        <span className=\"max-w-48 truncate\">{audioName}</span>\n        <button\n          type=\"button\"\n          onClick={clearPendingAudio}\n          className=\"flex size-4 items-center justify-center rounded-full hover:bg-destructive hover:text-destructive-foreground\"\n          aria-label=\"Remove audio\"\n        >\n          <XIcon className=\"size-3\" />\n        </button>\n      </div>\n    </div>\n  );\n};\n\nconst Composer: FC = () => {\n  return (\n    <ComposerPrimitive.Root className=\"aui-composer-root relative flex w-full flex-col\">\n      <ComposerPrimitive.AttachmentDropzone className=\"aui-composer-attachment-dropzone shadow-border ring-1 ring-border flex w-full flex-col rounded-2xl bg-background px-1 pt-2 outline-none transition-shadow data-[dragging=true]:ring-ring data-[dragging=true]:bg-accent/50\">\n        <ComposerAttachments />\n        <PendingAudioChip />\n        <ToolStatusDisplay />\n        <ComposerPrimitive.Input\n          placeholder=\"Send a message...\"\n          className=\"aui-composer-input mb-1 max-h-32 min-h-12 w-full resize-none bg-transparent px-4 pt-2 pb-3 text-sm outline-none placeholder:text-muted-foreground focus-visible:ring-0\"\n          rows={1}\n          autoFocus={true}\n          aria-label=\"Message input\"\n        />\n        <ComposerAction />\n      </ComposerPrimitive.AttachmentDropzone>\n    </ComposerPrimitive.Root>\n  );\n};\n\nconst ComposerAudioUpload: FC = () => {\n  const audioInputRef = useRef<HTMLInputElement>(null);\n  const setPendingAudio = useChatRuntimeStore((s) => s.setPendingAudio);\n  const activeModel = useChatRuntimeStore((s) => {\n    const checkpoint = s.params.checkpoint;\n    return s.models.find((m) => m.id === checkpoint);\n  });\n\n  const handleAudioFile = useCallback(\n    async (file: File) => {\n      if (file.size > MAX_AUDIO_SIZE) return;\n      try {\n        const base64 = await fileToBase64(file);\n        setPendingAudio(base64, file.name);\n      } catch {\n        // skip\n      }\n    },\n    [setPendingAudio],\n  );\n\n  if (!activeModel?.hasAudioInput) return null;\n\n  return (\n    <>\n      <input\n        ref={audioInputRef}\n        type=\"file\"\n        accept={AUDIO_ACCEPT}\n        className=\"hidden\"\n        onChange={(e) => {\n          const file = e.target.files?.[0];\n          if (file) handleAudioFile(file);\n          e.target.value = \"\";\n        }}\n      />\n      <TooltipIconButton\n        tooltip=\"Upload audio\"\n        side=\"bottom\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className=\"size-8.5 rounded-full p-1 text-muted-foreground hover:bg-muted-foreground/15\"\n        onClick={() => audioInputRef.current?.click()}\n        aria-label=\"Upload audio\"\n      >\n        <HeadphonesIcon className=\"size-4.5 stroke-[1.5px]\" />\n      </TooltipIconButton>\n    </>\n  );\n};\n\n/** Qwen3/3.5 recommended params differ between thinking on/off. 
*/\nfunction applyQwenThinkingParams(thinkingOn: boolean): void {\n  const store = useChatRuntimeStore.getState();\n  const checkpoint = store.params.checkpoint?.toLowerCase() ?? \"\";\n  if (!checkpoint.includes(\"qwen3\")) return;\n  // Qwen3 & Qwen3.5 share the same recommended settings:\n  // Thinking ON (general): temp=0.6, top_p=0.95, top_k=20\n  // Thinking OFF (general): temp=0.7, top_p=0.8, top_k=20\n  const params = thinkingOn\n    ? { temperature: 0.6, topP: 0.95, topK: 20, minP: 0.0 }\n    : { temperature: 0.7, topP: 0.8, topK: 20, minP: 0.0 };\n  store.setParams({ ...store.params, ...params });\n}\n\nconst ReasoningToggle: FC = () => {\n  const supportsReasoning = useChatRuntimeStore((s) => s.supportsReasoning);\n  const reasoningEnabled = useChatRuntimeStore((s) => s.reasoningEnabled);\n  const setReasoningEnabled = useChatRuntimeStore((s) => s.setReasoningEnabled);\n\n  if (!supportsReasoning) return null;\n\n  return (\n    <button\n      type=\"button\"\n      onClick={() => {\n        const next = !reasoningEnabled;\n        setReasoningEnabled(next);\n        applyQwenThinkingParams(next);\n      }}\n      className={cn(\n        \"flex items-center gap-1.5 rounded-full px-2.5 py-1 text-xs font-medium transition-colors\",\n        reasoningEnabled\n          ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n          : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n      )}\n      aria-label={reasoningEnabled ? \"Disable thinking\" : \"Enable thinking\"}\n    >\n      {reasoningEnabled ? (\n        <LightbulbIcon className=\"size-3.5\" />\n      ) : (\n        <LightbulbOffIcon className=\"size-3.5\" />\n      )}\n      <span>Think</span>\n    </button>\n  );\n};\n\nconst WebSearchToggle: FC = () => {\n  const supportsTools = useChatRuntimeStore((s) => s.supportsTools);\n  const toolsEnabled = useChatRuntimeStore((s) => s.toolsEnabled);\n  const setToolsEnabled = useChatRuntimeStore((s) => s.setToolsEnabled);\n\n  if (!supportsTools) return null;\n\n  return (\n    <button\n      type=\"button\"\n      onClick={() => setToolsEnabled(!toolsEnabled)}\n      className={cn(\n        \"flex items-center gap-1.5 rounded-full px-2.5 py-1 text-xs font-medium transition-colors\",\n        toolsEnabled\n          ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n          : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n      )}\n      aria-label={toolsEnabled ? \"Disable web search\" : \"Enable web search\"}\n    >\n      <GlobeIcon className=\"size-3.5\" />\n      <span>Search</span>\n    </button>\n  );\n};\n\nconst CodeToolsToggle: FC = () => {\n  const supportsTools = useChatRuntimeStore((s) => s.supportsTools);\n  const codeToolsEnabled = useChatRuntimeStore((s) => s.codeToolsEnabled);\n  const setCodeToolsEnabled = useChatRuntimeStore(\n    (s) => s.setCodeToolsEnabled,\n  );\n\n  if (!supportsTools) return null;\n\n  return (\n    <button\n      type=\"button\"\n      onClick={() => setCodeToolsEnabled(!codeToolsEnabled)}\n      className={cn(\n        \"flex items-center gap-1.5 rounded-full px-2.5 py-1 text-xs font-medium transition-colors\",\n        codeToolsEnabled\n          ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n          : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n      )}\n      aria-label={codeToolsEnabled ? 
\"Disable code execution\" : \"Enable code execution\"}\n    >\n      <TerminalIcon className=\"size-3.5\" />\n      <span>Code</span>\n    </button>\n  );\n};\n\nconst ToolStatusDisplay: FC = () => {\n  const toolStatus = useChatRuntimeStore((s) => s.toolStatus);\n  const [elapsed, setElapsed] = useState(0);\n\n  useEffect(() => {\n    if (!toolStatus) {\n      setElapsed(0);\n      return;\n    }\n    setElapsed(0);\n    const interval = setInterval(() => {\n      setElapsed((prev) => prev + 1);\n    }, 1000);\n    return () => clearInterval(interval);\n  }, [toolStatus]);\n\n  if (!toolStatus) return null;\n  const isRunning = toolStatus.startsWith(\"Running\");\n  const StatusIcon = isRunning ? TerminalIcon : GlobeIcon;\n  return (\n    <div className=\"mb-2 flex w-full flex-row items-center gap-2 px-1.5 pt-0.5 pb-1\">\n      <div className=\"flex animate-pulse items-center gap-2 rounded-full border border-primary/20 bg-primary/5 px-3 py-1.5 text-xs text-primary\">\n        <StatusIcon className=\"size-3.5\" />\n        <span>{toolStatus}</span>\n        <span className=\"tabular-nums opacity-60\">{elapsed}s</span>\n      </div>\n    </div>\n  );\n};\n\nconst ComposerAction: FC = () => {\n  return (\n    <div className=\"aui-composer-action-wrapper relative mx-2 mb-2 flex items-center justify-between\">\n      <div className=\"flex items-center gap-1\">\n        <ComposerAddAttachment />\n        <ComposerAudioUpload />\n        <ReasoningToggle />\n        <WebSearchToggle />\n        <CodeToolsToggle />\n      </div>\n      <div className=\"flex items-center gap-1\">\n        <ComposerPrimitive.If dictation={false}>\n          <ComposerPrimitive.Dictate asChild={true}>\n            <TooltipIconButton\n              tooltip=\"Dictate\"\n              variant=\"ghost\"\n              className=\"size-8 rounded-full text-muted-foreground\"\n            >\n              <MicIcon className=\"size-4\" />\n            </TooltipIconButton>\n          </ComposerPrimitive.Dictate>\n        </ComposerPrimitive.If>\n        <ComposerPrimitive.If dictation={true}>\n          <ComposerPrimitive.StopDictation asChild={true}>\n            <TooltipIconButton\n              tooltip=\"Stop dictation\"\n              variant=\"ghost\"\n              className=\"size-8 rounded-full text-destructive\"\n            >\n              <SquareIcon className=\"size-3 animate-pulse fill-current\" />\n            </TooltipIconButton>\n          </ComposerPrimitive.StopDictation>\n        </ComposerPrimitive.If>\n        <AuiIf condition={({ thread }) => !thread.isRunning}>\n          <ComposerPrimitive.Send asChild={true}>\n            <TooltipIconButton\n              tooltip=\"Send message\"\n              side=\"bottom\"\n              type=\"submit\"\n              variant=\"default\"\n              size=\"icon\"\n              className=\"aui-composer-send size-8 rounded-full\"\n              aria-label=\"Send message\"\n            >\n              <ArrowUpIcon className=\"aui-composer-send-icon size-4\" />\n            </TooltipIconButton>\n          </ComposerPrimitive.Send>\n        </AuiIf>\n        <AuiIf condition={({ thread }) => thread.isRunning}>\n          <ComposerPrimitive.Cancel asChild={true}>\n            <Button\n              type=\"button\"\n              variant=\"default\"\n              size=\"icon\"\n              className=\"aui-composer-cancel size-8 rounded-full\"\n              aria-label=\"Stop generating\"\n            >\n              <SquareIcon 
className=\"aui-composer-cancel-icon size-3 fill-current\" />\n            </Button>\n          </ComposerPrimitive.Cancel>\n        </AuiIf>\n      </div>\n    </div>\n  );\n};\n\nconst MessageError: FC = () => {\n  return (\n    <MessagePrimitive.Error>\n      <ErrorPrimitive.Root className=\"aui-message-error-root mt-2 rounded-md border border-destructive bg-destructive/10 p-3 text-destructive text-sm dark:bg-destructive/5 dark:text-red-200\">\n        <ErrorPrimitive.Message className=\"aui-message-error-message line-clamp-2\" />\n      </ErrorPrimitive.Root>\n    </MessagePrimitive.Error>\n  );\n};\n\nconst AssistantMessage: FC = () => {\n  return (\n    <MessagePrimitive.Root\n      className=\"aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150\"\n      data-role=\"assistant\"\n    >\n      <div className=\"aui-assistant-message-content wrap-break-word px-2 text-foreground leading-relaxed\">\n        <MessagePrimitive.Parts\n          components={{\n            Text: MarkdownText,\n            Reasoning: Reasoning,\n            ReasoningGroup: ReasoningGroup,\n            Source: Sources,\n            ToolGroup: ToolGroup,\n            tools: {\n              by_name: {\n                web_search: WebSearchToolUI,\n                python: PythonToolUI,\n                terminal: TerminalToolUI,\n              },\n              Fallback: ToolFallback,\n            },\n          }}\n        />\n        <MessageError />\n      </div>\n\n      <div className=\"aui-assistant-message-footer mt-1 ml-2 flex\">\n        <BranchPicker />\n        <AssistantActionBar />\n      </div>\n    </MessagePrimitive.Root>\n  );\n};\n\nconst COPY_RESET_MS = 2000;\n\nconst CopyButton: FC = () => {\n  const aui = useAui();\n  const [copied, setCopied] = useState(false);\n  const resetTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);\n\n  const handleCopy = () => {\n    const text = aui.message().getCopyText();\n    if (copyToClipboard(text)) {\n      setCopied(true);\n      if (resetTimeoutRef.current) clearTimeout(resetTimeoutRef.current);\n      resetTimeoutRef.current = setTimeout(() => {\n        setCopied(false);\n        resetTimeoutRef.current = null;\n      }, COPY_RESET_MS);\n    }\n  };\n\n  return (\n    <TooltipIconButton tooltip=\"Copy\" onClick={handleCopy}>\n      {copied ? 
<CheckIcon /> : <CopyIcon />}\n    </TooltipIconButton>\n  );\n};\n\nconst AssistantActionBar: FC = () => {\n  return (\n    <ActionBarPrimitive.Root\n      hideWhenRunning={true}\n      autohide=\"not-last\"\n      autohideFloat=\"single-branch\"\n      className=\"aui-assistant-action-bar-root col-start-3 row-start-2 -ml-1 flex gap-1 text-muted-foreground data-floating:absolute data-floating:rounded-md data-floating:border data-floating:bg-background data-floating:p-1 data-floating:shadow-sm\"\n    >\n      <CopyButton />\n      <ActionBarPrimitive.Reload asChild={true}>\n        <TooltipIconButton tooltip=\"Refresh\">\n          <RefreshCwIcon />\n        </TooltipIconButton>\n      </ActionBarPrimitive.Reload>\n      <MessageTiming side=\"top\" />\n      <ActionBarMorePrimitive.Root>\n        <ActionBarMorePrimitive.Trigger asChild={true}>\n          <TooltipIconButton\n            tooltip=\"More\"\n            className=\"data-[state=open]:bg-accent\"\n          >\n            <MoreHorizontalIcon />\n          </TooltipIconButton>\n        </ActionBarMorePrimitive.Trigger>\n        <ActionBarMorePrimitive.Content\n          side=\"bottom\"\n          align=\"start\"\n          className=\"aui-action-bar-more-content z-50 min-w-32 overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-md\"\n        >\n          <ActionBarPrimitive.ExportMarkdown asChild={true}>\n            <ActionBarMorePrimitive.Item className=\"aui-action-bar-more-item flex cursor-pointer select-none items-center gap-2 rounded-sm px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground\">\n              <DownloadIcon className=\"size-4\" />\n              Export as Markdown\n            </ActionBarMorePrimitive.Item>\n          </ActionBarPrimitive.ExportMarkdown>\n        </ActionBarMorePrimitive.Content>\n      </ActionBarMorePrimitive.Root>\n    </ActionBarPrimitive.Root>\n  );\n};\n\nconst UserMessageAudio: FC = () => {\n  const audioName = useAuiState(({ message }) => sentAudioNames.get(message.id));\n  if (!audioName) return null;\n  return (\n    <div className=\"col-start-2 flex justify-end\">\n      <div className=\"flex items-center gap-2 rounded-lg border border-foreground/20 bg-muted px-3 py-1.5 text-xs\">\n        <HeadphonesIcon className=\"size-3.5 text-muted-foreground\" />\n        <span className=\"max-w-48 truncate\">{audioName}</span>\n      </div>\n    </div>\n  );\n};\n\nconst UserMessage: FC = () => {\n  return (\n    <MessagePrimitive.Root\n      className=\"aui-user-message-root  fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2\"\n      data-role=\"user\"\n    >\n      <UserMessageAttachments />\n      <UserMessageAudio />\n\n      <div className=\"aui-user-message-content-wrapper relative col-start-2 min-w-0\">\n        <div className=\"aui-user-message-content wrap-break-word rounded-2xl bg-muted  px-4 py-2.5 text-foreground\">\n          <MessagePrimitive.Parts />\n        </div>\n        <div className=\"aui-user-action-bar-wrapper absolute top-1/2 left-0 -translate-x-full -translate-y-1/2 pr-2\">\n          <UserActionBar />\n        </div>\n      </div>\n\n      <BranchPicker className=\"aui-user-branch-picker col-span-full col-start-1 row-start-3 -mr-1 justify-end\" />\n    </MessagePrimitive.Root>\n  );\n};\n\nconst UserActionBar: FC = 
() => {\n  return (\n    <ActionBarPrimitive.Root\n      autohide=\"not-last\"\n      className=\"aui-user-action-bar-root flex items-center\"\n    >\n      <CopyButton />\n      <ActionBarPrimitive.Edit asChild={true}>\n        <TooltipIconButton tooltip=\"Edit\" className=\"aui-user-action-edit\">\n          <PencilIcon />\n        </TooltipIconButton>\n      </ActionBarPrimitive.Edit>\n    </ActionBarPrimitive.Root>\n  );\n};\n\nconst EditComposer: FC = () => {\n  const aui = useAui();\n  const resendAfterCancelRef = useRef(false);\n\n  useAuiEvent(\"thread.runEnd\", () => {\n    if (!resendAfterCancelRef.current) {\n      return;\n    }\n    resendAfterCancelRef.current = false;\n    aui.composer().send();\n  });\n\n  return (\n    <MessagePrimitive.Root className=\"aui-edit-composer-wrapper mx-auto flex w-full max-w-(--thread-max-width) flex-col px-2 py-3\">\n      <ComposerPrimitive.Root className=\"aui-edit-composer-root ml-auto flex w-full max-w-[85%] flex-col rounded-2xl bg-muted\">\n        <ComposerPrimitive.Input\n          className=\"aui-edit-composer-input min-h-14 w-full resize-none bg-transparent p-4 text-foreground text-sm outline-none\"\n          autoFocus={true}\n        />\n        <div className=\"aui-edit-composer-footer mx-3 mb-3 flex items-center gap-2 self-end\">\n          <ComposerPrimitive.Cancel asChild={true}>\n            <Button variant=\"ghost\" size=\"sm\">\n              Cancel\n            </Button>\n          </ComposerPrimitive.Cancel>\n          <Button\n            size=\"sm\"\n            onClick={() => {\n              const newText = aui.composer().getState().text;\n              const originalText = aui.message().getCopyText();\n\n              if (newText === originalText) {\n                aui.composer().cancel();\n                return;\n              }\n\n              if (aui.thread().getState().isRunning) {\n                resendAfterCancelRef.current = true;\n                aui.thread().cancelRun();\n                return;\n              }\n              aui.composer().send();\n            }}\n          >\n            Update\n          </Button>\n        </div>\n      </ComposerPrimitive.Root>\n    </MessagePrimitive.Root>\n  );\n};\n\nconst BranchPicker: FC<BranchPickerPrimitive.Root.Props> = ({\n  className,\n  ...rest\n}) => {\n  return (\n    <BranchPickerPrimitive.Root\n      hideWhenSingleBranch={true}\n      className={cn(\n        \"aui-branch-picker-root mr-2 -ml-2 inline-flex items-center text-muted-foreground text-xs\",\n        className,\n      )}\n      {...rest}\n    >\n      <BranchPickerPrimitive.Previous asChild={true}>\n        <TooltipIconButton tooltip=\"Previous\">\n          <ChevronLeftIcon />\n        </TooltipIconButton>\n      </BranchPickerPrimitive.Previous>\n      <span className=\"aui-branch-picker-state font-medium\">\n        <BranchPickerPrimitive.Number /> / <BranchPickerPrimitive.Count />\n      </span>\n      <BranchPickerPrimitive.Next asChild={true}>\n        <TooltipIconButton tooltip=\"Next\">\n          <ChevronRightIcon />\n        </TooltipIconButton>\n      </BranchPickerPrimitive.Next>\n    </BranchPickerPrimitive.Root>\n  );\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tool-fallback.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  type ToolCallMessagePartComponent,\n  type ToolCallMessagePartStatus,\n  useScrollLock,\n} from \"@assistant-ui/react\";\nimport {\n  AlertCircleIcon,\n  CheckIcon,\n  ChevronDownIcon,\n  LoaderIcon,\n  XCircleIcon,\n} from \"lucide-react\";\nimport {\n  type CSSProperties,\n  type ComponentProps,\n  type ElementType,\n  memo,\n  useCallback,\n  useRef,\n  useState,\n} from \"react\";\n\nconst ANIMATION_DURATION = 200;\n\nexport type ToolFallbackRootProps = Omit<\n  ComponentProps<typeof Collapsible>,\n  \"open\" | \"onOpenChange\"\n> & {\n  open?: boolean;\n  onOpenChange?: (open: boolean) => void;\n  defaultOpen?: boolean;\n};\n\nfunction ToolFallbackRoot({\n  className,\n  open: controlledOpen,\n  onOpenChange: controlledOnOpenChange,\n  defaultOpen = false,\n  children,\n  ...props\n}: ToolFallbackRootProps) {\n  const collapsibleRef = useRef<HTMLDivElement>(null);\n  const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen);\n  const lockScroll = useScrollLock(collapsibleRef, ANIMATION_DURATION);\n\n  const isControlled = controlledOpen !== undefined;\n  const isOpen = isControlled ? controlledOpen : uncontrolledOpen;\n\n  const handleOpenChange = useCallback(\n    (open: boolean) => {\n      if (!open) {\n        lockScroll();\n      }\n      if (!isControlled) {\n        setUncontrolledOpen(open);\n      }\n      controlledOnOpenChange?.(open);\n    },\n    [lockScroll, isControlled, controlledOnOpenChange],\n  );\n\n  return (\n    <Collapsible\n      ref={collapsibleRef}\n      data-slot=\"tool-fallback-root\"\n      open={isOpen}\n      onOpenChange={handleOpenChange}\n      className={cn(\n        \"aui-tool-fallback-root group/tool-fallback-root w-full corner-squircle rounded-lg border py-3\",\n        className,\n      )}\n      style={\n        {\n          \"--animation-duration\": `${ANIMATION_DURATION}ms`,\n        } as CSSProperties\n      }\n      {...props}\n    >\n      {children}\n    </Collapsible>\n  );\n}\n\ntype ToolStatus = ToolCallMessagePartStatus[\"type\"];\n\nconst statusIconMap: Record<ToolStatus, ElementType> = {\n  running: LoaderIcon,\n  complete: CheckIcon,\n  incomplete: XCircleIcon,\n  \"requires-action\": AlertCircleIcon,\n};\n\nfunction ToolFallbackTrigger({\n  toolName,\n  status,\n  icon: ToolIcon,\n  className,\n  ...props\n}: ComponentProps<typeof CollapsibleTrigger> & {\n  toolName: string;\n  status?: ToolCallMessagePartStatus;\n  icon?: ElementType;\n}) {\n  const statusType = status?.type ?? \"complete\";\n  const isRunning = statusType === \"running\";\n  const isCancelled =\n    status?.type === \"incomplete\" && status.reason === \"cancelled\";\n\n  const StatusIcon = statusIconMap[statusType];\n  const label = isCancelled ? \"Cancelled tool\" : \"Used tool\";\n\n  return (\n    <CollapsibleTrigger\n      data-slot=\"tool-fallback-trigger\"\n      className={cn(\n        \"aui-tool-fallback-trigger group/trigger flex w-full items-center gap-2 px-4 text-sm transition-colors\",\n        className,\n      )}\n      {...props}\n    >\n      {isRunning ? 
(\n        <StatusIcon\n          data-slot=\"tool-fallback-trigger-icon\"\n          className=\"aui-tool-fallback-trigger-icon size-4 shrink-0 animate-spin\"\n        />\n      ) : (\n        ToolIcon ? (\n          <ToolIcon\n            data-slot=\"tool-fallback-trigger-icon\"\n            className={cn(\n              \"aui-tool-fallback-trigger-icon size-4 shrink-0\",\n              isCancelled && \"text-muted-foreground\",\n            )}\n          />\n        ) : (\n          <StatusIcon\n            data-slot=\"tool-fallback-trigger-icon\"\n            className={cn(\n              \"aui-tool-fallback-trigger-icon size-4 shrink-0\",\n              isCancelled && \"text-muted-foreground\",\n            )}\n          />\n        )\n      )}\n      <span\n        data-slot=\"tool-fallback-trigger-label\"\n        className={cn(\n          \"aui-tool-fallback-trigger-label-wrapper relative inline-block grow text-left leading-none\",\n          isCancelled && \"text-muted-foreground line-through\",\n        )}\n      >\n        <span>\n          {label}: <b>{toolName}</b>\n        </span>\n        {isRunning && (\n          <span\n            aria-hidden={true}\n            data-slot=\"tool-fallback-trigger-shimmer\"\n            className=\"aui-tool-fallback-trigger-shimmer shimmer pointer-events-none absolute inset-0 motion-reduce:animate-none\"\n          >\n            {label}: <b>{toolName}</b>\n          </span>\n        )}\n      </span>\n      <ChevronDownIcon\n        data-slot=\"tool-fallback-trigger-chevron\"\n        className={cn(\n          \"aui-tool-fallback-trigger-chevron size-4 shrink-0\",\n          \"transition-transform duration-(--animation-duration) ease-out\",\n          \"group-data-[state=closed]/trigger:-rotate-90\",\n          \"group-data-[state=open]/trigger:rotate-0\",\n        )}\n      />\n    </CollapsibleTrigger>\n  );\n}\n\nfunction ToolFallbackContent({\n  className,\n  children,\n  ...props\n}: ComponentProps<typeof CollapsibleContent>) {\n  return (\n    <CollapsibleContent\n      data-slot=\"tool-fallback-content\"\n      className={cn(\n        \"aui-tool-fallback-content relative overflow-hidden text-sm outline-none\",\n        \"group/collapsible-content ease-out\",\n        \"data-[state=closed]:animate-collapsible-up\",\n        \"data-[state=open]:animate-collapsible-down\",\n        \"data-[state=closed]:fill-mode-forwards\",\n        \"data-[state=closed]:pointer-events-none\",\n        \"data-[state=open]:duration-(--animation-duration)\",\n        \"data-[state=closed]:duration-(--animation-duration)\",\n        className,\n      )}\n      {...props}\n    >\n      <div className=\"mt-3 flex flex-col gap-2 border-t pt-2\">{children}</div>\n    </CollapsibleContent>\n  );\n}\n\nfunction ToolFallbackArgs({\n  argsText,\n  className,\n  ...props\n}: ComponentProps<\"div\"> & {\n  argsText?: string;\n}) {\n  if (!argsText) {\n    return null;\n  }\n\n  return (\n    <div\n      data-slot=\"tool-fallback-args\"\n      className={cn(\"aui-tool-fallback-args px-4\", className)}\n      {...props}\n    >\n      <pre className=\"aui-tool-fallback-args-value whitespace-pre-wrap\">\n        {argsText}\n      </pre>\n    </div>\n  );\n}\n\nfunction ToolFallbackResult({\n  result,\n  className,\n  ...props\n}: ComponentProps<\"div\"> & {\n  result?: unknown;\n}) {\n  if (result === undefined) {\n    return null;\n  }\n\n  return (\n    <div\n      data-slot=\"tool-fallback-result\"\n      className={cn(\n        \"aui-tool-fallback-result border-t 
border-dashed px-4 pt-2\",\n        className,\n      )}\n      {...props}\n    >\n      <p className=\"aui-tool-fallback-result-header font-semibold\">Result:</p>\n      <pre className=\"aui-tool-fallback-result-content whitespace-pre-wrap\">\n        {typeof result === \"string\" ? result : JSON.stringify(result, null, 2)}\n      </pre>\n    </div>\n  );\n}\n\nfunction ToolFallbackError({\n  status,\n  className,\n  ...props\n}: ComponentProps<\"div\"> & {\n  status?: ToolCallMessagePartStatus;\n}) {\n  if (status?.type !== \"incomplete\") {\n    return null;\n  }\n\n  const error = status.error;\n  const errorText = error\n    ? typeof error === \"string\"\n      ? error\n      : JSON.stringify(error)\n    : null;\n\n  if (!errorText) {\n    return null;\n  }\n\n  const isCancelled = status.reason === \"cancelled\";\n  const headerText = isCancelled ? \"Cancelled reason:\" : \"Error:\";\n\n  return (\n    <div\n      data-slot=\"tool-fallback-error\"\n      className={cn(\"aui-tool-fallback-error px-4\", className)}\n      {...props}\n    >\n      <p className=\"aui-tool-fallback-error-header font-semibold text-muted-foreground\">\n        {headerText}\n      </p>\n      <p className=\"aui-tool-fallback-error-reason text-muted-foreground\">\n        {errorText}\n      </p>\n    </div>\n  );\n}\n\nconst ToolFallbackImpl: ToolCallMessagePartComponent = ({\n  toolName,\n  argsText,\n  result,\n  status,\n}) => {\n  const isCancelled =\n    status?.type === \"incomplete\" && status.reason === \"cancelled\";\n\n  return (\n    <ToolFallbackRoot\n      className={cn(isCancelled && \"border-muted-foreground/30 bg-muted/30\")}\n    >\n      <ToolFallbackTrigger toolName={toolName} status={status} />\n      <ToolFallbackContent>\n        <ToolFallbackError status={status} />\n        <ToolFallbackArgs\n          argsText={argsText}\n          className={cn(isCancelled && \"opacity-60\")}\n        />\n        {!isCancelled && <ToolFallbackResult result={result} />}\n      </ToolFallbackContent>\n    </ToolFallbackRoot>\n  );\n};\n\nconst ToolFallback = memo(\n  ToolFallbackImpl,\n) as unknown as ToolCallMessagePartComponent & {\n  Root: typeof ToolFallbackRoot;\n  Trigger: typeof ToolFallbackTrigger;\n  Content: typeof ToolFallbackContent;\n  Args: typeof ToolFallbackArgs;\n  Result: typeof ToolFallbackResult;\n  Error: typeof ToolFallbackError;\n};\n\nToolFallback.displayName = \"ToolFallback\";\nToolFallback.Root = ToolFallbackRoot;\nToolFallback.Trigger = ToolFallbackTrigger;\nToolFallback.Content = ToolFallbackContent;\nToolFallback.Args = ToolFallbackArgs;\nToolFallback.Result = ToolFallbackResult;\nToolFallback.Error = ToolFallbackError;\n\nexport {\n  ToolFallback,\n  ToolFallbackRoot,\n  ToolFallbackTrigger,\n  ToolFallbackContent,\n  ToolFallbackArgs,\n  ToolFallbackResult,\n  ToolFallbackError,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tool-group.tsx",
    "content": "\"use client\";\n\nimport {\n  memo,\n  useCallback,\n  useRef,\n  useState,\n  type FC,\n  type PropsWithChildren,\n} from \"react\";\nimport { ChevronDownIcon, LoaderIcon } from \"lucide-react\";\nimport { cva, type VariantProps } from \"class-variance-authority\";\nimport { useScrollLock } from \"@assistant-ui/react\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { cn } from \"@/lib/utils\";\n\nconst ANIMATION_DURATION = 200;\n\nconst toolGroupVariants = cva(\"aui-tool-group-root group/tool-group w-full\", {\n  variants: {\n    variant: {\n      outline: \"corner-squircle rounded-lg border py-3\",\n      ghost: \"\",\n      muted: \"corner-squircle rounded-lg border border-muted-foreground/30 bg-muted/30 py-3\",\n    },\n  },\n  defaultVariants: { variant: \"outline\" },\n});\n\nexport type ToolGroupRootProps = Omit<\n  React.ComponentProps<typeof Collapsible>,\n  \"open\" | \"onOpenChange\"\n> &\n  VariantProps<typeof toolGroupVariants> & {\n    open?: boolean;\n    onOpenChange?: (open: boolean) => void;\n    defaultOpen?: boolean;\n  };\n\nfunction ToolGroupRoot({\n  className,\n  variant,\n  open: controlledOpen,\n  onOpenChange: controlledOnOpenChange,\n  defaultOpen = false,\n  children,\n  ...props\n}: ToolGroupRootProps) {\n  const collapsibleRef = useRef<HTMLDivElement>(null);\n  const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen);\n  const lockScroll = useScrollLock(collapsibleRef, ANIMATION_DURATION);\n\n  const isControlled = controlledOpen !== undefined;\n  const isOpen = isControlled ? controlledOpen : uncontrolledOpen;\n\n  const handleOpenChange = useCallback(\n    (open: boolean) => {\n      if (!open) {\n        lockScroll();\n      }\n      if (!isControlled) {\n        setUncontrolledOpen(open);\n      }\n      controlledOnOpenChange?.(open);\n    },\n    [lockScroll, isControlled, controlledOnOpenChange],\n  );\n\n  return (\n    <Collapsible\n      ref={collapsibleRef}\n      data-slot=\"tool-group-root\"\n      data-variant={variant ?? \"outline\"}\n      open={isOpen}\n      onOpenChange={handleOpenChange}\n      className={cn(\n        toolGroupVariants({ variant }),\n        \"group/tool-group-root\",\n        className,\n      )}\n      style={\n        {\n          \"--animation-duration\": `${ANIMATION_DURATION}ms`,\n        } as React.CSSProperties\n      }\n      {...props}\n    >\n      {children}\n    </Collapsible>\n  );\n}\n\nfunction ToolGroupTrigger({\n  count,\n  active = false,\n  className,\n  ...props\n}: React.ComponentProps<typeof CollapsibleTrigger> & {\n  count: number;\n  active?: boolean;\n}) {\n  const label = `${count} tool ${count === 1 ? 
\"call\" : \"calls\"}`;\n\n  return (\n    <CollapsibleTrigger\n      data-slot=\"tool-group-trigger\"\n      className={cn(\n        \"aui-tool-group-trigger group/trigger flex items-center gap-2 text-sm transition-colors\",\n        \"group-data-[variant=outline]/tool-group-root:w-full group-data-[variant=outline]/tool-group-root:px-4\",\n        \"group-data-[variant=muted]/tool-group-root:w-full group-data-[variant=muted]/tool-group-root:px-4\",\n        className,\n      )}\n      {...props}\n    >\n      {active && (\n        <LoaderIcon\n          data-slot=\"tool-group-trigger-loader\"\n          className=\"aui-tool-group-trigger-loader size-4 shrink-0 animate-spin\"\n        />\n      )}\n      <span\n        data-slot=\"tool-group-trigger-label\"\n        className={cn(\n          \"aui-tool-group-trigger-label-wrapper relative inline-block text-left font-medium leading-none\",\n          \"group-data-[variant=outline]/tool-group-root:grow\",\n          \"group-data-[variant=muted]/tool-group-root:grow\",\n        )}\n      >\n        <span>{label}</span>\n        {active && (\n          <span\n            aria-hidden\n            data-slot=\"tool-group-trigger-shimmer\"\n            className=\"aui-tool-group-trigger-shimmer shimmer pointer-events-none absolute inset-0 motion-reduce:animate-none\"\n          >\n            {label}\n          </span>\n        )}\n      </span>\n      <ChevronDownIcon\n        data-slot=\"tool-group-trigger-chevron\"\n        className={cn(\n          \"aui-tool-group-trigger-chevron size-4 shrink-0\",\n          \"transition-transform duration-(--animation-duration) ease-out\",\n          \"group-data-[state=closed]/trigger:-rotate-90\",\n          \"group-data-[state=open]/trigger:rotate-0\",\n        )}\n      />\n    </CollapsibleTrigger>\n  );\n}\n\nfunction ToolGroupContent({\n  className,\n  children,\n  ...props\n}: React.ComponentProps<typeof CollapsibleContent>) {\n  return (\n    <CollapsibleContent\n      data-slot=\"tool-group-content\"\n      className={cn(\n        \"aui-tool-group-content relative overflow-hidden text-sm outline-none\",\n        \"group/collapsible-content ease-out\",\n        \"data-[state=closed]:animate-collapsible-up\",\n        \"data-[state=open]:animate-collapsible-down\",\n        \"data-[state=closed]:fill-mode-forwards\",\n        \"data-[state=closed]:pointer-events-none\",\n        \"data-[state=open]:duration-(--animation-duration)\",\n        \"data-[state=closed]:duration-(--animation-duration)\",\n        className,\n      )}\n      {...props}\n    >\n      <div\n        className={cn(\n          \"mt-2 flex flex-col gap-2\",\n          \"group-data-[variant=outline]/tool-group-root:mt-3 group-data-[variant=outline]/tool-group-root:border-t group-data-[variant=outline]/tool-group-root:px-4 group-data-[variant=outline]/tool-group-root:pt-3\",\n          \"group-data-[variant=muted]/tool-group-root:mt-3 group-data-[variant=muted]/tool-group-root:border-t group-data-[variant=muted]/tool-group-root:px-4 group-data-[variant=muted]/tool-group-root:pt-3\",\n        )}\n      >\n        {children}\n      </div>\n    </CollapsibleContent>\n  );\n}\n\ntype ToolGroupComponent = FC<\n  PropsWithChildren<{ startIndex: number; endIndex: number }>\n> & {\n  Root: typeof ToolGroupRoot;\n  Trigger: typeof ToolGroupTrigger;\n  Content: typeof ToolGroupContent;\n};\n\nconst ToolGroupImpl: FC<\n  PropsWithChildren<{ startIndex: number; endIndex: number }>\n> = ({ children, startIndex, endIndex }) => {\n  const 
toolCount = endIndex - startIndex + 1;\n\n  // Single tool call — render directly without wrapper\n  if (toolCount <= 1) {\n    return <>{children}</>;\n  }\n\n  return (\n    <ToolGroupRoot>\n      <ToolGroupTrigger count={toolCount} />\n      <ToolGroupContent>{children}</ToolGroupContent>\n    </ToolGroupRoot>\n  );\n};\n\nconst ToolGroup = memo(ToolGroupImpl) as unknown as ToolGroupComponent;\n\nToolGroup.displayName = \"ToolGroup\";\nToolGroup.Root = ToolGroupRoot;\nToolGroup.Trigger = ToolGroupTrigger;\nToolGroup.Content = ToolGroupContent;\n\nexport {\n  ToolGroup,\n  ToolGroupRoot,\n  ToolGroupTrigger,\n  ToolGroupContent,\n  toolGroupVariants,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tool-ui-python.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { copyToClipboard } from \"@/lib/copy-to-clipboard\";\nimport type { ToolCallMessagePartComponent } from \"@assistant-ui/react\";\nimport { code as codePlugin } from \"@streamdown/code\";\nimport { CheckIcon, CodeIcon, CopyIcon, LoaderIcon } from \"lucide-react\";\nimport { memo, useCallback, useMemo, useRef, useState } from \"react\";\nimport { Streamdown } from \"streamdown\";\nimport {\n  ToolFallbackContent,\n  ToolFallbackRoot,\n  ToolFallbackTrigger,\n} from \"./tool-fallback\";\n\nconst MAX_DISPLAY = 10_000;\nconst COPY_RESET_MS = 2000;\nconst SHIKI_THEME = [\"github-light\", \"github-dark\"] as [\"github-light\", \"github-dark\"];\n\nfunction truncate(text: string): string {\n  return text.length <= MAX_DISPLAY\n    ? text\n    : `${text.slice(0, MAX_DISPLAY)}\\n... (truncated)`;\n}\n\nfunction CopyBtn({ text }: { text: string }) {\n  const [copied, setCopied] = useState(false);\n  const timer = useRef<ReturnType<typeof setTimeout> | null>(null);\n  const copy = useCallback(() => {\n    if (copyToClipboard(text)) {\n      setCopied(true);\n      if (timer.current) {\n        clearTimeout(timer.current);\n      }\n      timer.current = setTimeout(() => setCopied(false), COPY_RESET_MS);\n    }\n  }, [text]);\n\n  return (\n    <button\n      type=\"button\"\n      onClick={copy}\n      className=\"inline-flex items-center gap-1 rounded px-1.5 py-0.5 text-xs text-muted-foreground transition-colors hover:bg-muted hover:text-foreground\"\n      aria-label=\"Copy to clipboard\"\n    >\n      {copied ? (\n        <CheckIcon className=\"size-3\" />\n      ) : (\n        <CopyIcon className=\"size-3\" />\n      )}\n      {copied ? \"Copied\" : \"Copy\"}\n    </button>\n  );\n}\n\n/** Render code with syntax highlighting via Streamdown + shiki. No extra borders — inherits parent container. */\nfunction HighlightedCode({ code: source, language }: { code: string; language: string }) {\n  const markdown = useMemo(\n    () => `\\`\\`\\`${language}\\n${truncate(source)}\\n\\`\\`\\``,\n    [source, language],\n  );\n  return (\n    <div className=\"max-h-48 overflow-auto text-xs [&_pre]:!m-0 [&_pre]:!bg-transparent [&_pre]:!p-0 [&_pre]:!text-xs [&_[data-streamdown=code-block]]:!my-0 [&_[data-streamdown=code-block]]:!p-0 [&_[data-streamdown=code-block]]:!border-0\">\n      <Streamdown\n        mode=\"static\"\n        plugins={{ code: codePlugin }}\n        controls={{ code: false }}\n        shikiTheme={SHIKI_THEME}\n      >\n        {markdown}\n      </Streamdown>\n    </div>\n  );\n}\n\nconst PythonToolUIImpl: ToolCallMessagePartComponent = ({\n  args,\n  result,\n  status,\n}) => {\n  const code = (args as { code?: string })?.code ?? \"\";\n  const firstLine = code.split(\"\\n\")[0]?.slice(0, 60) ?? \"\";\n  const isRunning = status?.type === \"running\";\n  const output =\n    typeof result === \"string\"\n      ? result\n      : result\n        ? JSON.stringify(result, null, 2)\n        : \"\";\n\n  return (\n    <ToolFallbackRoot>\n      <ToolFallbackTrigger\n        toolName={firstLine ? 
`Python: ${firstLine}` : \"Python\"}\n        status={status}\n        icon={CodeIcon}\n      />\n      <ToolFallbackContent>\n        <div className=\"flex flex-col px-4\">\n          {/* Code + copy */}\n          {code && (\n            <div className=\"flex justify-end\">\n              <CopyBtn text={code} />\n            </div>\n          )}\n          <HighlightedCode code={code} language=\"python\" />\n\n          {/* Output */}\n          {isRunning ? (\n            <div className=\"mt-2 flex items-center gap-2 text-sm text-muted-foreground\">\n              <LoaderIcon className=\"size-3.5 animate-spin\" />\n              <span>Running&hellip;</span>\n            </div>\n          ) : output ? (\n            <div className=\"mt-2 border-t border-dashed pt-2\">\n              <div className=\"flex items-center justify-between\">\n                <span className=\"text-xs font-medium text-muted-foreground\">output</span>\n                <CopyBtn text={output} />\n              </div>\n              <pre className=\"mt-1 max-h-60 overflow-auto whitespace-pre-wrap break-words font-mono text-xs\">\n                {truncate(output)}\n              </pre>\n            </div>\n          ) : null}\n        </div>\n      </ToolFallbackContent>\n    </ToolFallbackRoot>\n  );\n};\n\nexport const PythonToolUI = memo(\n  PythonToolUIImpl,\n) as unknown as ToolCallMessagePartComponent;\nPythonToolUI.displayName = \"PythonToolUI\";\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tool-ui-terminal.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { copyToClipboard } from \"@/lib/copy-to-clipboard\";\nimport type { ToolCallMessagePartComponent } from \"@assistant-ui/react\";\nimport { CheckIcon, CopyIcon, LoaderIcon, TerminalIcon } from \"lucide-react\";\nimport { memo, useCallback, useRef, useState } from \"react\";\nimport {\n  ToolFallbackContent,\n  ToolFallbackRoot,\n  ToolFallbackTrigger,\n} from \"./tool-fallback\";\n\nconst MAX_DISPLAY = 10_000;\nconst COPY_RESET_MS = 2000;\n\nfunction truncate(text: string): string {\n  return text.length <= MAX_DISPLAY\n    ? text\n    : `${text.slice(0, MAX_DISPLAY)}\\n... (truncated)`;\n}\n\nfunction CopyBtn({ text }: { text: string }) {\n  const [copied, setCopied] = useState(false);\n  const timer = useRef<ReturnType<typeof setTimeout> | null>(null);\n  const copy = useCallback(() => {\n    if (copyToClipboard(text)) {\n      setCopied(true);\n      if (timer.current) {\n        clearTimeout(timer.current);\n      }\n      timer.current = setTimeout(() => setCopied(false), COPY_RESET_MS);\n    }\n  }, [text]);\n\n  return (\n    <button\n      type=\"button\"\n      onClick={copy}\n      className=\"inline-flex items-center gap-1 rounded px-1.5 py-0.5 text-xs text-muted-foreground transition-colors hover:bg-muted hover:text-foreground\"\n      aria-label=\"Copy to clipboard\"\n    >\n      {copied ? (\n        <CheckIcon className=\"size-3\" />\n      ) : (\n        <CopyIcon className=\"size-3\" />\n      )}\n      {copied ? \"Copied\" : \"Copy\"}\n    </button>\n  );\n}\n\nconst TerminalToolUIImpl: ToolCallMessagePartComponent = ({\n  args,\n  result,\n  status,\n}) => {\n  const command = (args as { command?: string })?.command ?? \"\";\n  const isRunning = status?.type === \"running\";\n  const output =\n    typeof result === \"string\"\n      ? result\n      : result\n        ? JSON.stringify(result, null, 2)\n        : \"\";\n\n  return (\n    <ToolFallbackRoot>\n      <ToolFallbackTrigger\n        toolName={command ? `$ ${command.slice(0, 60)}` : \"Terminal\"}\n        status={status}\n        icon={TerminalIcon}\n      />\n      <ToolFallbackContent>\n        <div className=\"flex flex-col px-4\">\n          {isRunning ? (\n            <div className=\"flex items-center gap-2 text-sm text-muted-foreground\">\n              <LoaderIcon className=\"size-3.5 animate-spin\" />\n              <span>Running&hellip;</span>\n            </div>\n          ) : output ? (\n            <div>\n              <div className=\"flex items-center justify-between\">\n                <span className=\"text-xs font-medium text-muted-foreground\">output</span>\n                <CopyBtn text={output} />\n              </div>\n              <pre className=\"mt-1 max-h-60 overflow-auto whitespace-pre-wrap break-words font-mono text-xs\">\n                {truncate(output)}\n              </pre>\n            </div>\n          ) : null}\n        </div>\n      </ToolFallbackContent>\n    </ToolFallbackRoot>\n  );\n};\n\nexport const TerminalToolUI = memo(\n  TerminalToolUIImpl,\n) as unknown as ToolCallMessagePartComponent;\nTerminalToolUI.displayName = \"TerminalToolUI\";\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tool-ui-web-search.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { type ToolCallMessagePartComponent, useAuiState } from \"@assistant-ui/react\";\nimport { GlobeIcon, LoaderIcon } from \"lucide-react\";\nimport { memo, useEffect, useState } from \"react\";\nimport { Source, SourceIcon, SourceTitle } from \"./sources\";\nimport {\n  ToolFallbackContent,\n  ToolFallbackRoot,\n  ToolFallbackTrigger,\n} from \"./tool-fallback\";\n\ninterface ParsedSource {\n  title: string;\n  url: string;\n  snippet: string;\n}\n\nconst RE_BLOCK_SEP = /\\n---\\n/;\nconst RE_TITLE = /Title:\\s*(.+)/;\nconst RE_URL = /URL:\\s*(.+)/;\nconst RE_SNIPPET = /Snippet:\\s*(.+)/s;\n\n/** Parse the backend's \"Title: ...\\nURL: ...\\nSnippet: ...\\n---\" format into structured sources. */\nfunction parseSearchResults(raw: string): ParsedSource[] {\n  if (!raw) {\n    return [];\n  }\n  const blocks = raw.split(RE_BLOCK_SEP).filter(Boolean);\n  const sources: ParsedSource[] = [];\n  for (const block of blocks) {\n    const titleMatch = block.match(RE_TITLE);\n    const urlMatch = block.match(RE_URL);\n    const snippetMatch = block.match(RE_SNIPPET);\n    if (titleMatch && urlMatch) {\n      sources.push({\n        title: titleMatch[1].trim(),\n        url: urlMatch[1].trim(),\n        snippet: snippetMatch?.[1]?.trim() ?? \"\",\n      });\n    }\n  }\n  return sources;\n}\n\nconst WebSearchToolUIImpl: ToolCallMessagePartComponent = ({\n  args,\n  result,\n  status,\n}) => {\n  const query = (args as { query?: string })?.query ?? \"\";\n  const isRunning = status?.type === \"running\";\n  const sources = result\n    ? parseSearchResults(\n        typeof result === \"string\" ? result : JSON.stringify(result),\n      )\n    : [];\n\n  // Collapse when LLM starts generating text after the tool call\n  const hasText = useAuiState(({ message }) =>\n    message.content.some((p) => p.type === \"text\" && \"text\" in p && (p as { text: string }).text.length > 0),\n  );\n  const [open, setOpen] = useState(isRunning);\n  useEffect(() => {\n    if (isRunning) {\n      setOpen(true);\n    } else if (hasText) {\n      setOpen(false);\n    }\n  }, [isRunning, hasText]);\n\n  return (\n    <ToolFallbackRoot open={open} onOpenChange={setOpen}>\n      <ToolFallbackTrigger\n        toolName={query ? `Searched \"${query}\"` : \"Web Search\"}\n        status={status}\n        icon={GlobeIcon}\n      />\n      <ToolFallbackContent>\n        {isRunning ? (\n          <div className=\"flex items-center gap-2 px-4 text-sm text-muted-foreground\">\n            <LoaderIcon className=\"size-3.5 animate-spin\" />\n            <span>Searching for &ldquo;{query}&rdquo;&hellip;</span>\n          </div>\n        ) : sources.length > 0 ? (\n          <div className=\"flex flex-col gap-1.5 px-4\">\n            {sources.map((source) => (\n              <Source\n                key={source.url}\n                href={source.url}\n                variant=\"outline\"\n                size=\"default\"\n                className=\"flex w-full max-w-full items-center gap-2 py-1.5\"\n              >\n                <SourceIcon url={source.url} className=\"size-3.5\" />\n                <SourceTitle className=\"max-w-none flex-1 truncate\">\n                  {source.title}\n                </SourceTitle>\n              </Source>\n            ))}\n          </div>\n        ) : result ? 
(\n          <div className=\"px-4\">\n            <pre className=\"max-h-40 overflow-auto whitespace-pre-wrap break-words rounded bg-muted/50 p-2 text-xs\">\n              {typeof result === \"string\"\n                ? result\n                : JSON.stringify(result, null, 2)}\n            </pre>\n          </div>\n        ) : null}\n      </ToolFallbackContent>\n    </ToolFallbackRoot>\n  );\n};\n\nexport const WebSearchToolUI = memo(\n  WebSearchToolUIImpl,\n) as unknown as ToolCallMessagePartComponent;\nWebSearchToolUI.displayName = \"WebSearchToolUI\";\n"
  },
  {
    "path": "studio/frontend/src/components/assistant-ui/tooltip-icon-button.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { Slottable } from \"@radix-ui/react-slot\";\nimport { type ComponentPropsWithRef, forwardRef } from \"react\";\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\n\nexport type TooltipIconButtonProps = ComponentPropsWithRef<typeof Button> & {\n  tooltip: string;\n  side?: \"top\" | \"bottom\" | \"left\" | \"right\";\n};\n\nexport const TooltipIconButton = forwardRef<\n  HTMLButtonElement,\n  TooltipIconButtonProps\n>(({ children, tooltip, side = \"bottom\", className, ...rest }, ref) => {\n  return (\n    <Tooltip>\n      <TooltipTrigger asChild={true}>\n        <Button\n          variant=\"ghost\"\n          size=\"icon\"\n          {...rest}\n          className={cn(\"aui-button-icon size-6 p-1\", className)}\n          ref={ref}\n        >\n          <Slottable>{children}</Slottable>\n          <span className=\"aui-sr-only sr-only\">{tooltip}</span>\n        </Button>\n      </TooltipTrigger>\n      <TooltipContent side={side}>{tooltip}</TooltipContent>\n    </Tooltip>\n  );\n});\n\nTooltipIconButton.displayName = \"TooltipIconButton\";\n"
  },
  {
    "path": "studio/frontend/src/components/example.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport type { ComponentProps } from \"react\";\n\nfunction ExampleWrapper({ className, ...props }: ComponentProps<\"div\">) {\n  return (\n    <div className=\"bg-background w-full\">\n      <div\n        data-slot=\"example-wrapper\"\n        className={cn(\n          \"mx-auto grid min-h-screen w-full max-w-5xl min-w-0 content-center items-start gap-8 p-4 pt-2 sm:gap-12 sm:p-6 md:grid-cols-2 md:gap-8 lg:p-12 2xl:max-w-6xl\",\n          className,\n        )}\n        {...props}\n      />\n    </div>\n  );\n}\n\nfunction Example({\n  title,\n  children,\n  className,\n  containerClassName,\n  ...props\n}: ComponentProps<\"div\"> & {\n  title?: string;\n  containerClassName?: string;\n}) {\n  return (\n    <div\n      data-slot=\"example\"\n      className={cn(\n        \"mx-auto flex w-full max-w-lg min-w-0 flex-col gap-1 self-stretch lg:max-w-none\",\n        containerClassName,\n      )}\n      {...props}\n    >\n      {title && (\n        <div className=\"text-muted-foreground px-1.5 py-2 text-xs font-medium\">\n          {title}\n        </div>\n      )}\n      <div\n        data-slot=\"example-content\"\n        className={cn(\n          \"bg-background text-foreground flex min-w-0 flex-1 flex-col items-start gap-6 border border-dashed p-4 sm:p-6 *:[div:not([class*='w-'])]:w-full\",\n          className,\n        )}\n      >\n        {children}\n      </div>\n    </div>\n  );\n}\n\nexport { ExampleWrapper, Example };\n"
  },
  {
    "path": "studio/frontend/src/components/layout/dashboard-grid.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\n\nimport { cn } from \"@/lib/utils\";\n\nconst colsVariants = {\n  3: \"lg:grid-cols-3\",\n  4: \"lg:grid-cols-4\",\n} as const;\n\nfunction DashboardGrid({\n  className,\n  cols = 3,\n  ...props\n}: React.ComponentProps<\"div\"> & { cols?: 3 | 4 }) {\n  return (\n    <div\n      data-slot=\"dashboard-grid\"\n      className={cn(\n        \"grid grid-cols-1 gap-6 md:grid-cols-2\",\n        colsVariants[cols],\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nexport { DashboardGrid };\n"
  },
  {
    "path": "studio/frontend/src/components/layout/dashboard-layout.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\n\nimport { cn } from \"@/lib/utils\";\n\nfunction DashboardLayout({\n  className,\n  children,\n  ...props\n}: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"dashboard-layout\"\n      className={cn(\n        \"min-h-screen w-full bg-background\",\n        \"flex justify-center\",\n        className,\n      )}\n      {...props}\n    >\n      <div className=\"w-full max-w-7xl px-6 py-8 lg:px-8\">{children}</div>\n    </div>\n  );\n}\n\nexport { DashboardLayout };\n"
  },
  {
    "path": "studio/frontend/src/components/layout/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { DashboardLayout } from \"./dashboard-layout\";\nexport { DashboardGrid } from \"./dashboard-grid\";\n"
  },
  {
    "path": "studio/frontend/src/components/markdown/markdown-preview.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport { code } from \"@streamdown/code\";\nimport { math } from \"@streamdown/math\";\nimport { mermaid } from \"@streamdown/mermaid\";\nimport { memo, type ReactElement } from \"react\";\nimport { Streamdown } from \"streamdown\";\nimport \"katex/dist/katex.min.css\";\n\nconst MARKDOWN_PLUGINS = { code, math, mermaid } as const;\n\ntype MarkdownPreviewProps = {\n  markdown: string;\n  className?: string;\n  plain?: boolean;\n};\n\nfunction MarkdownPreviewImpl({\n  markdown,\n  className,\n  plain = false,\n}: MarkdownPreviewProps): ReactElement {\n  const markdownClassName =\n    \"w-full max-w-none min-w-0 space-y-2 [overflow-wrap:anywhere] [&_*]:max-w-none [&_p]:w-full [&_ul]:w-full [&_ol]:w-full [&_li]:w-full [&_h1]:w-full [&_h2]:w-full [&_h3]:w-full [&_h4]:w-full [&_h5]:w-full [&_h6]:w-full [&_pre]:w-full [&_table]:w-full [&_p]:break-words [&_li]:break-words [&_code]:break-words [&_pre]:whitespace-pre-wrap [&_pre]:break-words\";\n\n  return (\n    <div\n      className={cn(\n        plain\n          ? \"h-full w-full min-w-0 overflow-auto p-2 text-xs leading-relaxed pointer-events-none select-none\"\n          : \"nodrag max-h-56 w-full min-w-0 overflow-auto rounded-md border border-border/60 bg-muted/20 p-2 text-xs leading-relaxed\",\n        className,\n      )}\n    >\n      <Streamdown\n        mode=\"static\"\n        plugins={MARKDOWN_PLUGINS}\n        controls={false}\n        className={markdownClassName}\n      >\n        {markdown.trim() ? markdown : \"_Empty note_\"}\n      </Streamdown>\n    </div>\n  );\n}\n\nexport const MarkdownPreview = memo(MarkdownPreviewImpl);\n"
  },
  {
    "path": "studio/frontend/src/components/markdown/mermaid-error.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { MermaidErrorComponentProps } from \"streamdown\";\n\nfunction hasSlashComment(chart: string): boolean {\n  return /(^|[^:])\\/\\/.*/m.test(chart);\n}\n\nexport function MermaidError({\n  error,\n  chart,\n  retry,\n}: MermaidErrorComponentProps) {\n  return (\n    <div className=\"my-4 rounded-lg border border-red-300 bg-red-50 p-3 text-red-800\">\n      <p className=\"text-sm font-semibold\">Mermaid render failed</p>\n      <p className=\"mt-1 break-words font-mono text-xs\">{error}</p>\n      {hasSlashComment(chart) ? (\n        <p className=\"mt-1 text-xs\">Hint: Mermaid comments use `%%`, not `//`.</p>\n      ) : null}\n      <button\n        type=\"button\"\n        onClick={retry}\n        className=\"mt-2 rounded border border-red-300 px-2 py-1 text-xs hover:bg-red-100\"\n      >\n        Retry\n      </button>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/components/navbar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  HoverCard,\n  HoverCardContent,\n  HoverCardTrigger,\n} from \"@/components/ui/hover-card\";\nimport { AnimatedThemeToggler } from \"@/components/ui/animated-theme-toggler\";\nimport {\n  Sheet,\n  SheetContent,\n  SheetHeader,\n  SheetTitle,\n  SheetTrigger,\n} from \"@/components/ui/sheet\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  ArrowRight01Icon,\n  Book03Icon,\n  BubbleChatIcon,\n  ChefHatIcon,\n  CursorInfo02Icon,\n  PackageIcon,\n  ZapIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useTrainingRuntimeStore } from \"@/features/training\";\nimport { usePlatformStore } from \"@/config/env\";\nimport { Link, useRouterState } from \"@tanstack/react-router\";\nimport { motion } from \"motion/react\";\nimport { useState } from \"react\";\nimport { TOUR_OPEN_EVENT } from \"@/features/tour\";\n\nconst NAV_ITEMS = [\n  { label: \"Studio\", href: \"/studio\", icon: ZapIcon, enabled: true },\n  { label: \"Recipes\", href: \"/data-recipes\", icon: ChefHatIcon, enabled: true },\n  { label: \"Export\", href: \"/export\", icon: PackageIcon, enabled: true },\n  { label: \"Chat\", href: \"/chat\", icon: BubbleChatIcon, enabled: true },\n];\n\nfunction getTourId(pathname: string): \"studio\" | \"chat\" | \"export\" | null {\n  if (pathname === \"/studio\") return \"studio\";\n  if (pathname === \"/chat\") return \"chat\";\n  if (pathname === \"/export\") return \"export\";\n  return null;\n}\n\nexport function Navbar() {\n  const pathname = useRouterState({ select: (s) => s.location.pathname });\n  const isTrainingRunning = useTrainingRuntimeStore((s) => s.isTrainingRunning);\n  const [mobileOpen, setMobileOpen] = useState(false);\n\n  const chatOnly = usePlatformStore((s) => s.isChatOnly());\n\n  const tourId = getTourId(pathname);\n\n  const openTour = () => {\n    if (!tourId) return;\n    window.dispatchEvent(\n      new CustomEvent(TOUR_OPEN_EVENT, { detail: { id: tourId } }),\n    );\n  };\n\n  return (\n    <header className=\"relative top-0 z-40 h-16 w-full\">\n      <div className=\"mx-auto grid h-full max-w-7xl grid-cols-[1fr_auto_1fr] items-center px-4 sm:px-6\">\n        {/* Left: logo */}\n        <Link to={chatOnly ? 
\"/chat\" : \"/studio\"} className=\"flex items-center gap-1.5 justify-self-start select-none\">\n          <img\n            src=\"/blacklogo.png\"\n            alt=\"Unsloth\"\n            className=\"h-9 w-auto dark:hidden\"\n          />\n          <img\n            src=\"/whitelogo.png\"\n            alt=\"Unsloth\"\n            className=\"hidden h-9 w-auto dark:block\"\n          />\n          <span className=\"relative -top-[1px] inline-flex items-center text-[10px] font-extrabold leading-none tracking-[0.12em] text-primary\">\n            BETA\n          </span>\n        </Link>\n\n        {/* Center: pill nav */}\n        <nav\n          data-tour=\"navbar\"\n          className=\"hidden items-center rounded-full border border-border bg-card p-1 ring-1 ring-foreground/5 md:flex\"\n        >\n          {NAV_ITEMS.map((item) => {\n            const active =\n              pathname === item.href || pathname.startsWith(`${item.href}/`);\n            const disabledByTraining =\n              isTrainingRunning && item.href !== \"/studio\";\n            const disabledByDevice =\n              chatOnly && item.href !== \"/chat\" && item.href !== \"/data-recipes\";\n            if (!item.enabled || disabledByTraining || disabledByDevice) {\n              return (\n                <span\n                  key={item.href}\n                  className=\"relative rounded-full px-3 py-1.5 text-sm font-medium text-muted-foreground/40 cursor-not-allowed\"\n                >\n                  {item.label}\n                </span>\n              );\n            }\n            return (\n              <Link\n                key={item.href}\n                to={item.href}\n                className={cn(\n                  \"relative rounded-full px-3 py-1.5 text-sm font-medium transition-colors\",\n                  active\n                    ? \"text-background\"\n                    : \"text-muted-foreground hover:text-foreground\",\n                )}\n              >\n                {active && (\n                  <motion.span\n                    layoutId=\"nav-pill\"\n                    className=\"absolute inset-0 rounded-full bg-foreground\"\n                    transition={{\n                      type: \"spring\",\n                      stiffness: 500,\n                      damping: 35,\n                      mass: 0.5,\n                    }}\n                  />\n                )}\n                <span className=\"relative z-10 flex items-center\">\n                  <motion.span\n                    initial={false}\n                    animate={{\n                      width: active ? 14 : 0,\n                      marginLeft: active ? -4 : 0,\n                      marginRight: active ? 4 : 0,\n                      opacity: active ? 
1 : 0,\n                    }}\n                    transition={{ duration: 0.2, ease: [0.165, 0.84, 0.44, 1] }}\n                    className=\"inline-flex shrink-0 items-center justify-center overflow-hidden\"\n                  >\n                    <HugeiconsIcon\n                      icon={item.icon}\n                      className=\"size-3.5 -mt-px shrink-0\"\n                    />\n                  </motion.span>\n                  {item.label}\n                </span>\n              </Link>\n            );\n          })}\n        </nav>\n\n        {/* Right: docs/tour desktop */}\n        <div className=\"hidden items-center justify-self-end gap-2 md:flex\">\n          <AnimatedThemeToggler\n            className=\"flex h-9 w-9 items-center justify-center rounded-md text-muted-foreground transition-colors hover:bg-accent hover:text-foreground [&_svg]:size-4\"\n            title=\"Toggle theme\"\n            aria-label=\"Toggle theme\"\n          />\n          <HoverCard openDelay={200} closeDelay={100}>\n            <HoverCardTrigger asChild={true}>\n              <a\n                href=\"https://unsloth.ai/docs\"\n                target=\"_blank\"\n                rel=\"noopener noreferrer\"\n                className=\"flex items-center gap-1.5 text-sm font-medium text-emerald-600 hover:text-emerald-700 transition-colors\"\n              >\n                <HugeiconsIcon icon={Book03Icon} className=\"size-4\" />\n                Learn more\n              </a>\n            </HoverCardTrigger>\n            <HoverCardContent align=\"end\" className=\"w-80 p-0\">\n              <a\n                href=\"https://unsloth.ai/docs\"\n                target=\"_blank\"\n                rel=\"noopener noreferrer\"\n                className=\"group/card flex flex-col gap-1 p-4 no-underline\"\n              >\n                <p className=\"text-sm font-semibold font-heading\">\n                  Unsloth Documentation\n                </p>\n                <p className=\"text-xs text-muted-foreground leading-relaxed\">\n                  Guides on fine-tuning LLMs 2x faster with 70% less memory.\n                  Covers LoRA, QLoRA, data formatting, and deployment.\n                </p>\n                <span className=\"mt-1 flex items-center gap-1 text-xs font-medium text-emerald-600 group-hover/card:underline\">\n                  Visit docs\n                  <HugeiconsIcon icon={ArrowRight01Icon} className=\"size-3\" />\n                </span>\n              </a>\n            </HoverCardContent>\n          </HoverCard>\n\n          <button\n            type=\"button\"\n            onClick={tourId ? openTour : undefined}\n            className={cn(\n              \"flex h-9 items-center gap-1.5 rounded-md px-3 text-muted-foreground transition-colors hover:bg-accent hover:text-foreground\",\n              !tourId && \"invisible pointer-events-none\",\n            )}\n            title=\"Tour\"\n            aria-hidden={!tourId}\n            tabIndex={tourId ? 0 : -1}\n          >\n            <HugeiconsIcon icon={CursorInfo02Icon} className=\"size-4\" />\n            <span className=\"text-sm font-medium\">Tour</span>\n          </button>\n        </div>\n\n        {/* Right: mobile */}\n        <div className=\"col-start-3 flex items-center gap-2 justify-self-end md:hidden\">\n          {tourId ? 
(\n            <button\n              type=\"button\"\n              onClick={openTour}\n              className=\"flex h-9 w-9 items-center justify-center rounded-md text-muted-foreground transition-colors hover:bg-accent hover:text-foreground\"\n              title=\"Tour\"\n            >\n              <HugeiconsIcon icon={CursorInfo02Icon} className=\"size-4\" />\n            </button>\n          ) : null}\n          <Sheet open={mobileOpen} onOpenChange={setMobileOpen}>\n            <SheetTrigger asChild={true}>\n              <button\n                type=\"button\"\n                className=\"rounded-md border border-border px-3 py-1.5 text-sm font-medium text-foreground\"\n                aria-label=\"Open navigation menu\"\n              >\n                Menu\n              </button>\n            </SheetTrigger>\n            <SheetContent side=\"right\" className=\"w-[300px] p-4\">\n              <SheetHeader>\n                <SheetTitle>Navigate</SheetTitle>\n              </SheetHeader>\n              <div className=\"mt-6 flex flex-col gap-2\">\n                {NAV_ITEMS.filter((item) => item.enabled).map((item) => {\n                  const active = pathname === item.href;\n                  const disabledByTraining =\n                    isTrainingRunning && item.href !== \"/studio\";\n                  const disabledByDevice =\n                    chatOnly && item.href !== \"/chat\" && item.href !== \"/data-recipes\";\n                  if (disabledByTraining || disabledByDevice) {\n                    return (\n                      <span\n                        key={item.href}\n                        className=\"flex items-center gap-2 rounded-md border border-border px-3 py-2 text-sm font-medium text-muted-foreground/40 cursor-not-allowed\"\n                      >\n                        <HugeiconsIcon icon={item.icon} className=\"size-4\" />\n                        {item.label}\n                      </span>\n                    );\n                  }\n                  return (\n                    <Link\n                      key={item.href}\n                      to={item.href}\n                      onClick={() => setMobileOpen(false)}\n                      className={cn(\n                        \"flex items-center gap-2 rounded-md border px-3 py-2 text-sm font-medium\",\n                        active\n                          ? \"border-foreground bg-foreground text-background\"\n                          : \"border-border text-foreground hover:bg-accent\",\n                      )}\n                    >\n                      <HugeiconsIcon icon={item.icon} className=\"size-4\" />\n                      {item.label}\n                    </Link>\n                  );\n                })}\n                <a\n                  href=\"https://unsloth.ai/docs\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"mt-2 flex items-center gap-2 rounded-md border border-border px-3 py-2 text-sm font-medium text-foreground hover:bg-accent\"\n                  onClick={() => setMobileOpen(false)}\n                >\n                  <HugeiconsIcon icon={Book03Icon} className=\"size-4\" />\n                  Learn more (Docs)\n                </a>\n                {tourId ? 
(\n                  <button\n                    type=\"button\"\n                    className=\"flex items-center gap-2 rounded-md border border-border px-3 py-2 text-left text-sm font-medium text-foreground hover:bg-accent\"\n                    onClick={() => {\n                      openTour();\n                      setMobileOpen(false);\n                    }}\n                  >\n                    <HugeiconsIcon icon={CursorInfo02Icon} className=\"size-4\" />\n                    Start tour\n                  </button>\n                ) : null}\n                <div className=\"mt-2 flex items-center justify-between rounded-md border border-border px-3 py-2\">\n                  <span className=\"text-sm font-medium text-foreground\">Theme</span>\n                  <AnimatedThemeToggler\n                    className=\"flex h-8 w-8 items-center justify-center rounded-md text-muted-foreground transition-colors hover:bg-accent hover:text-foreground [&_svg]:size-4\"\n                    title=\"Toggle theme\"\n                    aria-label=\"Toggle theme\"\n                  />\n                </div>\n              </div>\n            </SheetContent>\n          </Sheet>\n        </div>\n      </div>\n    </header>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/components/section-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport type { ReactNode } from \"react\";\n\ninterface SectionCardProps {\n  icon: ReactNode;\n  title: string;\n  description: string;\n  accent?: \"emerald\" | \"indigo\" | \"orange\" | \"blue\";\n  featured?: boolean;\n  className?: string;\n  badge?: string;\n  headerAction?: ReactNode;\n  children: ReactNode;\n}\n\nconst accentStyles = {\n  emerald: {\n    border: \"ring-emerald-500/20\",\n    iconBox:\n      \"ring-emerald-200 bg-emerald-50 text-emerald-600 dark:ring-emerald-800 dark:bg-emerald-950 dark:text-emerald-400\",\n  },\n  indigo: {\n    border: \"ring-indigo-500/20\",\n    iconBox:\n      \"ring-indigo-200 bg-indigo-50 text-indigo-600 dark:ring-indigo-800 dark:bg-indigo-950 dark:text-indigo-400\",\n  },\n  orange: {\n    border: \"ring-orange-500/20\",\n    iconBox:\n      \"ring-orange-200 bg-orange-50 text-orange-600 dark:ring-orange-800 dark:bg-orange-950 dark:text-orange-400\",\n  },\n  blue: {\n    border: \"ring-blue-500/20\",\n    iconBox:\n      \"ring-blue-200 bg-blue-50 text-blue-600 dark:ring-blue-800 dark:bg-blue-950 dark:text-blue-400\",\n  },\n};\n\nexport function SectionCard({\n  icon,\n  title,\n  description,\n  accent = \"emerald\",\n  featured,\n  className,\n  badge,\n  headerAction,\n  children,\n}: SectionCardProps) {\n  const styles = accentStyles[accent];\n\n  return (\n    <div\n      className={cn(\n        \"bg-card corner-squircle rounded-3xl ring-1 ring-foreground/10 flex flex-col gap-5 p-5 relative overflow-clip transition-all duration-300 ease-in-out\",\n        featured && styles.border,\n        className,\n      )}\n    >\n      {featured && (\n        <div className=\"pointer-events-none absolute inset-x-0 top-0 h-24 bg-gradient-to-b from-emerald-500/[0.04] to-transparent\" />\n      )}\n      {/* Header */}\n      <div className=\"flex items-center gap-3\">\n        <div\n          className={cn(\n            \"rounded-xl corner-squircle p-2 ring-1 shrink-0\",\n            styles.iconBox,\n          )}\n        >\n          {icon}\n        </div>\n        <div className=\"min-w-0 flex-1\">\n          <div className=\"flex items-center gap-2 pb-1\">\n            <h3 className=\"text-sm font-semibold\">{title}</h3>\n            {badge && (\n              <span className=\"rounded-full bg-emerald-100 px-2 py-0.5 text-[10px] font-semibold text-emerald-700 dark:bg-emerald-900 dark:text-emerald-300\">\n                {badge}\n              </span>\n            )}\n          </div>\n          <p className=\"text-xs text-muted-foreground\">{description}</p>\n        </div>\n        {headerAction && <div className=\"shrink-0\">{headerAction}</div>}\n      </div>\n      {/* Content */}\n      {children}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/accordion.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Accordion as AccordionPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { ArrowDown01Icon, ArrowUp01Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Accordion({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AccordionPrimitive.Root>) {\r\n  return (\r\n    <AccordionPrimitive.Root\r\n      data-slot=\"accordion\"\r\n      className={cn(\r\n        \"overflow-hidden rounded-2xl border flex w-full flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AccordionItem({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AccordionPrimitive.Item>) {\r\n  return (\r\n    <AccordionPrimitive.Item\r\n      data-slot=\"accordion-item\"\r\n      className={cn(\"data-open:bg-muted/50 not-last:border-b\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AccordionTrigger({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof AccordionPrimitive.Trigger>) {\r\n  return (\r\n    <AccordionPrimitive.Header className=\"flex\">\r\n      <AccordionPrimitive.Trigger\r\n        data-slot=\"accordion-trigger\"\r\n        className={cn(\r\n          \"**:data-[slot=accordion-trigger-icon]:text-muted-foreground gap-6 p-4 text-left text-sm font-medium hover:underline **:data-[slot=accordion-trigger-icon]:ml-auto **:data-[slot=accordion-trigger-icon]:size-4 group/accordion-trigger relative flex flex-1 items-start justify-between border border-transparent transition-all outline-none disabled:pointer-events-none disabled:opacity-50\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      >\r\n        {children}\r\n        <HugeiconsIcon\r\n          icon={ArrowDown01Icon}\r\n          strokeWidth={2}\r\n          data-slot=\"accordion-trigger-icon\"\r\n          className=\"pointer-events-none shrink-0 group-aria-expanded/accordion-trigger:hidden\"\r\n        />\r\n        <HugeiconsIcon\r\n          icon={ArrowUp01Icon}\r\n          strokeWidth={2}\r\n          data-slot=\"accordion-trigger-icon\"\r\n          className=\"pointer-events-none hidden shrink-0 group-aria-expanded/accordion-trigger:inline\"\r\n        />\r\n      </AccordionPrimitive.Trigger>\r\n    </AccordionPrimitive.Header>\r\n  );\r\n}\r\n\r\nfunction AccordionContent({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof AccordionPrimitive.Content>) {\r\n  return (\r\n    <AccordionPrimitive.Content\r\n      data-slot=\"accordion-content\"\r\n      className=\"data-open:animate-accordion-down data-closed:animate-accordion-up px-4 text-sm overflow-hidden\"\r\n      {...props}\r\n    >\r\n      <div\r\n        className={cn(\r\n          \"pt-0 pb-4 [&_a]:hover:text-foreground h-(--radix-accordion-content-height) [&_a]:underline [&_a]:underline-offset-3 [&_p:not(:last-child)]:mb-4\",\r\n          className,\r\n        )}\r\n      >\r\n        {children}\r\n      </div>\r\n    </AccordionPrimitive.Content>\r\n  );\r\n}\r\n\r\nexport { Accordion, AccordionItem, AccordionTrigger, AccordionContent };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/alert-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { AlertDialog as AlertDialogPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { Button } from \"@/components/ui/button\";\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction AlertDialog({\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Root>) {\r\n  return <AlertDialogPrimitive.Root data-slot=\"alert-dialog\" {...props} />;\r\n}\r\n\r\nfunction AlertDialogTrigger({\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Trigger>) {\r\n  return (\r\n    <AlertDialogPrimitive.Trigger data-slot=\"alert-dialog-trigger\" {...props} />\r\n  );\r\n}\r\n\r\nfunction AlertDialogPortal({\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Portal>) {\r\n  return (\r\n    <AlertDialogPrimitive.Portal data-slot=\"alert-dialog-portal\" {...props} />\r\n  );\r\n}\r\n\r\nfunction AlertDialogOverlay({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Overlay>) {\r\n  return (\r\n    <AlertDialogPrimitive.Overlay\r\n      data-slot=\"alert-dialog-overlay\"\r\n      className={cn(\r\n        \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 bg-black/80 duration-100 supports-backdrop-filter:backdrop-blur-xs fixed inset-0 z-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogContent({\n  className,\n  size = \"default\",\n  overlayClassName,\n  ...props\n}: React.ComponentProps<typeof AlertDialogPrimitive.Content> & {\n  size?: \"default\" | \"sm\";\n  overlayClassName?: string;\n}) {\n  return (\n    <AlertDialogPortal>\n      <AlertDialogOverlay className={overlayClassName} />\n      <AlertDialogPrimitive.Content\n        data-slot=\"alert-dialog-content\"\n        data-size={size}\n        className={cn(\r\n          \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 bg-background ring-foreground/5 gap-6 rounded-4xl p-6 ring-1 duration-100 data-[size=default]:max-w-xs data-[size=sm]:max-w-xs data-[size=default]:sm:max-w-md group/alert-dialog-content fixed top-1/2 left-1/2 z-50 grid w-full -translate-x-1/2 -translate-y-1/2 outline-none\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </AlertDialogPortal>\r\n  );\r\n}\r\n\r\nfunction AlertDialogHeader({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-dialog-header\"\r\n      className={cn(\r\n        \"grid grid-rows-[auto_1fr] place-items-center gap-1.5 text-center has-data-[slot=alert-dialog-media]:grid-rows-[auto_auto_1fr] has-data-[slot=alert-dialog-media]:gap-x-6 sm:group-data-[size=default]/alert-dialog-content:place-items-start sm:group-data-[size=default]/alert-dialog-content:text-left sm:group-data-[size=default]/alert-dialog-content:has-data-[slot=alert-dialog-media]:grid-rows-[auto_1fr]\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogFooter({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-dialog-footer\"\r\n      className={cn(\r\n        \"flex flex-col-reverse gap-2 group-data-[size=sm]/alert-dialog-content:grid 
group-data-[size=sm]/alert-dialog-content:grid-cols-2 sm:flex-row sm:justify-end\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogMedia({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-dialog-media\"\r\n      className={cn(\r\n        \"bg-muted mb-2 inline-flex size-16 items-center justify-center rounded-full sm:group-data-[size=default]/alert-dialog-content:row-span-2 *:[svg:not([class*='size-'])]:size-8\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogTitle({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Title>) {\r\n  return (\r\n    <AlertDialogPrimitive.Title\r\n      data-slot=\"alert-dialog-title\"\r\n      className={cn(\r\n        \"text-lg font-medium sm:group-data-[size=default]/alert-dialog-content:group-has-data-[slot=alert-dialog-media]/alert-dialog-content:col-start-2\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogDescription({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Description>) {\r\n  return (\r\n    <AlertDialogPrimitive.Description\r\n      data-slot=\"alert-dialog-description\"\r\n      className={cn(\r\n        \"text-muted-foreground *:[a]:hover:text-foreground text-sm text-balance md:text-pretty *:[a]:underline *:[a]:underline-offset-3\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDialogAction({\r\n  className,\r\n  variant = \"default\",\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Action> &\r\n  Pick<React.ComponentProps<typeof Button>, \"variant\" | \"size\">) {\r\n  return (\r\n    <Button variant={variant} size={size} asChild>\r\n      <AlertDialogPrimitive.Action\r\n        data-slot=\"alert-dialog-action\"\r\n        className={cn(className)}\r\n        {...props}\r\n      />\r\n    </Button>\r\n  );\r\n}\r\n\r\nfunction AlertDialogCancel({\r\n  className,\r\n  variant = \"outline\",\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof AlertDialogPrimitive.Cancel> &\r\n  Pick<React.ComponentProps<typeof Button>, \"variant\" | \"size\">) {\r\n  return (\r\n    <Button variant={variant} size={size} asChild>\r\n      <AlertDialogPrimitive.Cancel\r\n        data-slot=\"alert-dialog-cancel\"\r\n        className={cn(className)}\r\n        {...props}\r\n      />\r\n    </Button>\r\n  );\r\n}\r\n\r\nexport {\r\n  AlertDialog,\r\n  AlertDialogAction,\r\n  AlertDialogCancel,\r\n  AlertDialogContent,\r\n  AlertDialogDescription,\r\n  AlertDialogFooter,\r\n  AlertDialogHeader,\r\n  AlertDialogMedia,\r\n  AlertDialogOverlay,\r\n  AlertDialogPortal,\r\n  AlertDialogTitle,\r\n  AlertDialogTrigger,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/alert.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nconst alertVariants = cva(\r\n  \"grid gap-0.5 rounded-lg border px-4 py-3 text-left text-sm has-data-[slot=alert-action]:relative has-data-[slot=alert-action]:pr-18 has-[>svg]:grid-cols-[auto_1fr] has-[>svg]:gap-x-2.5 *:[svg]:row-span-2 *:[svg]:translate-y-0.5 *:[svg]:text-current *:[svg:not([class*='size-'])]:size-4 w-full relative group/alert\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"bg-card text-card-foreground\",\r\n        destructive:\r\n          \"text-destructive bg-card *:data-[slot=alert-description]:text-destructive/90 *:[svg]:text-current\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n    },\r\n  },\r\n);\r\n\r\nfunction Alert({\r\n  className,\r\n  variant,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & VariantProps<typeof alertVariants>) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert\"\r\n      role=\"alert\"\r\n      className={cn(alertVariants({ variant }), className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertTitle({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-title\"\r\n      className={cn(\r\n        \"font-medium group-has-[>svg]/alert:col-start-2 [&_a]:hover:text-foreground [&_a]:underline [&_a]:underline-offset-3\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertDescription({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-description\"\r\n      className={cn(\r\n        \"text-muted-foreground text-sm text-balance md:text-pretty [&_p:not(:last-child)]:mb-4 [&_a]:hover:text-foreground [&_a]:underline [&_a]:underline-offset-3\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AlertAction({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"alert-action\"\r\n      className={cn(\"absolute top-2.5 right-3\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Alert, AlertTitle, AlertDescription, AlertAction };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/animated-shiny-text.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ComponentPropsWithoutRef, CSSProperties, FC } from \"react\"\r\n\r\nimport { cn } from \"@/lib/utils\"\r\n\r\nexport interface AnimatedShinyTextProps extends ComponentPropsWithoutRef<\"span\"> {\r\n  shimmerWidth?: number\r\n}\r\n\r\nexport const AnimatedShinyText: FC<AnimatedShinyTextProps> = ({\r\n  children,\r\n  className,\r\n  shimmerWidth = 100,\r\n  ...props\r\n}) => {\r\n  return (\r\n    <span\r\n      style={\r\n        {\r\n          \"--shiny-width\": `${shimmerWidth}px`,\r\n        } as CSSProperties\r\n      }\r\n      className={cn(\r\n        \"mx-auto max-w-md text-neutral-600/70 dark:text-neutral-400/70\",\r\n\r\n        // Shine effect\r\n        \"animate-shiny-text [background-size:var(--shiny-width)_100%] bg-clip-text [background-position:0_0] bg-no-repeat [transition:background-position_1s_cubic-bezier(.6,.6,0,1)_infinite]\",\r\n\r\n        // Shine gradient\r\n        \"bg-gradient-to-r from-transparent via-black/80 via-50% to-transparent dark:via-white/80\",\r\n\r\n        className\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n    </span>\r\n  )\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/animated-theme-toggler.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useRef, useState } from \"react\"\nimport { Moon, Sun } from \"lucide-react\"\nimport { flushSync } from \"react-dom\"\n\nimport { cn } from \"@/lib/utils\"\n\ninterface AnimatedThemeTogglerProps extends React.ComponentPropsWithoutRef<\"button\"> {\n  duration?: number\n}\n\nexport const AnimatedThemeToggler = ({\n  className,\n  duration = 400,\n  ...props\n}: AnimatedThemeTogglerProps) => {\n  const [isDark, setIsDark] = useState(false)\n  const buttonRef = useRef<HTMLButtonElement>(null)\n\n  useEffect(() => {\n    const updateTheme = () => {\n      setIsDark(document.documentElement.classList.contains(\"dark\"))\n    }\n\n    updateTheme()\n\n    const observer = new MutationObserver(updateTheme)\n    observer.observe(document.documentElement, {\n      attributes: true,\n      attributeFilter: [\"class\"],\n    })\n\n    return () => observer.disconnect()\n  }, [])\n\n  const toggleTheme = useCallback(async () => {\n    if (!buttonRef.current) return\n\n    await document.startViewTransition(() => {\n      flushSync(() => {\n        const newTheme = !isDark\n        setIsDark(newTheme)\n        document.documentElement.classList.toggle(\"dark\")\n        localStorage.setItem(\"theme\", newTheme ? \"dark\" : \"light\")\n      })\n    }).ready\n\n    const { top, left, width, height } =\n      buttonRef.current.getBoundingClientRect()\n    const x = left + width / 2\n    const y = top + height / 2\n    const maxRadius = Math.hypot(\n      Math.max(left, window.innerWidth - left),\n      Math.max(top, window.innerHeight - top)\n    )\n\n    document.documentElement.animate(\n      {\n        clipPath: [\n          `circle(0px at ${x}px ${y}px)`,\n          `circle(${maxRadius}px at ${x}px ${y}px)`,\n        ],\n      },\n      {\n        duration,\n        easing: \"ease-in-out\",\n        pseudoElement: \"::view-transition-new(root)\",\n      }\n    )\n  }, [isDark, duration])\n\n  return (\n    <button\n      ref={buttonRef}\n      onClick={toggleTheme}\n      className={cn(className)}\n      {...props}\n    >\n      {isDark ? <Sun /> : <Moon />}\n      <span className=\"sr-only\">Toggle theme</span>\n    </button>\n  )\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/aspect-ratio.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { AspectRatio as AspectRatioPrimitive } from \"radix-ui\";\r\n\r\nfunction AspectRatio({\r\n  ...props\r\n}: React.ComponentProps<typeof AspectRatioPrimitive.Root>) {\r\n  return <AspectRatioPrimitive.Root data-slot=\"aspect-ratio\" {...props} />;\r\n}\r\n\r\nexport { AspectRatio };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/avatar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Avatar as AvatarPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Avatar({\r\n  className,\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof AvatarPrimitive.Root> & {\r\n  size?: \"default\" | \"sm\" | \"lg\";\r\n}) {\r\n  return (\r\n    <AvatarPrimitive.Root\r\n      data-slot=\"avatar\"\r\n      data-size={size}\r\n      className={cn(\r\n        \"size-8 rounded-full after:rounded-full data-[size=lg]:size-10 data-[size=sm]:size-6 after:border-border group/avatar relative flex shrink-0 select-none after:absolute after:inset-0 after:border after:mix-blend-darken dark:after:mix-blend-lighten\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AvatarImage({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AvatarPrimitive.Image>) {\r\n  return (\r\n    <AvatarPrimitive.Image\r\n      data-slot=\"avatar-image\"\r\n      className={cn(\r\n        \"rounded-full aspect-square size-full object-cover\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AvatarFallback({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof AvatarPrimitive.Fallback>) {\r\n  return (\r\n    <AvatarPrimitive.Fallback\r\n      data-slot=\"avatar-fallback\"\r\n      className={cn(\r\n        \"bg-muted text-muted-foreground rounded-full flex size-full items-center justify-center text-sm group-data-[size=sm]/avatar:text-xs\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AvatarBadge({ className, ...props }: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"avatar-badge\"\r\n      className={cn(\r\n        \"bg-primary text-primary-foreground ring-background absolute right-0 bottom-0 z-10 inline-flex items-center justify-center rounded-full bg-blend-color ring-2 select-none\",\r\n        \"group-data-[size=sm]/avatar:size-2 group-data-[size=sm]/avatar:[&>svg]:hidden\",\r\n        \"group-data-[size=default]/avatar:size-2.5 group-data-[size=default]/avatar:[&>svg]:size-2\",\r\n        \"group-data-[size=lg]/avatar:size-3 group-data-[size=lg]/avatar:[&>svg]:size-2\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AvatarGroup({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"avatar-group\"\r\n      className={cn(\r\n        \"*:data-[slot=avatar]:ring-background group/avatar-group flex -space-x-2 *:data-[slot=avatar]:ring-2\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction AvatarGroupCount({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"avatar-group-count\"\r\n      className={cn(\r\n        \"bg-muted text-muted-foreground size-8 rounded-full text-sm group-has-data-[size=lg]/avatar-group:size-10 group-has-data-[size=sm]/avatar-group:size-6 [&>svg]:size-4 group-has-data-[size=lg]/avatar-group:[&>svg]:size-5 group-has-data-[size=sm]/avatar-group:[&>svg]:size-3 ring-background relative flex shrink-0 items-center justify-center ring-2\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Avatar,\r\n  
AvatarImage,\r\n  AvatarFallback,\r\n  AvatarGroup,\r\n  AvatarGroupCount,\r\n  AvatarBadge,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/badge.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport { Slot } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nexport const badgeVariants = cva(\r\n  \"h-5 gap-1 rounded-4xl border border-transparent px-2 py-0.5 text-xs font-medium transition-all has-data-[icon=inline-end]:pr-1.5 has-data-[icon=inline-start]:pl-1.5 [&>svg]:size-3! inline-flex items-center justify-center w-fit whitespace-nowrap shrink-0 [&>svg]:pointer-events-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive overflow-hidden group/badge\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"bg-primary text-primary-foreground [a]:hover:bg-primary/80\",\r\n        secondary:\r\n          \"bg-secondary text-secondary-foreground [a]:hover:bg-secondary/80\",\r\n        destructive:\r\n          \"bg-destructive/10 [a]:hover:bg-destructive/20 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 text-destructive dark:bg-destructive/20\",\r\n        outline:\r\n          \"border-border text-foreground [a]:hover:bg-muted [a]:hover:text-muted-foreground bg-input/30\",\r\n        ghost:\r\n          \"hover:bg-muted hover:text-muted-foreground dark:hover:bg-muted/50\",\r\n        link: \"text-primary underline-offset-4 hover:underline\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n    },\r\n  },\r\n);\r\n\r\nexport function Badge({\r\n  className,\r\n  variant = \"default\",\r\n  asChild = false,\r\n  ...props\r\n}: React.ComponentProps<\"span\"> &\r\n  VariantProps<typeof badgeVariants> & {\r\n    asChild?: boolean;\r\n  }): React.ReactElement {\r\n  const Comp = asChild ? Slot.Root : \"span\";\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"badge\"\r\n      data-variant={variant}\r\n      className={cn(badgeVariants({ variant }), className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/breadcrumb.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Slot } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport {\r\n  ArrowRight01Icon,\r\n  MoreHorizontalCircle01Icon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Breadcrumb({ className, ...props }: React.ComponentProps<\"nav\">) {\r\n  return (\r\n    <nav\r\n      aria-label=\"breadcrumb\"\r\n      data-slot=\"breadcrumb\"\r\n      className={cn(className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction BreadcrumbList({ className, ...props }: React.ComponentProps<\"ol\">) {\r\n  return (\r\n    <ol\r\n      data-slot=\"breadcrumb-list\"\r\n      className={cn(\r\n        \"text-muted-foreground gap-1.5 text-sm sm:gap-2.5 flex flex-wrap items-center break-words\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction BreadcrumbItem({ className, ...props }: React.ComponentProps<\"li\">) {\r\n  return (\r\n    <li\r\n      data-slot=\"breadcrumb-item\"\r\n      className={cn(\"gap-1.5 inline-flex items-center\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction BreadcrumbLink({\r\n  asChild,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"a\"> & {\r\n  asChild?: boolean;\r\n}) {\r\n  const Comp = asChild ? Slot.Root : \"a\";\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"breadcrumb-link\"\r\n      className={cn(\"hover:text-foreground transition-colors\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction BreadcrumbPage({ className, ...props }: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"breadcrumb-page\"\r\n      role=\"link\"\r\n      aria-disabled=\"true\"\r\n      aria-current=\"page\"\r\n      className={cn(\"text-foreground font-normal\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction BreadcrumbSeparator({\r\n  children,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"li\">) {\r\n  return (\r\n    <li\r\n      data-slot=\"breadcrumb-separator\"\r\n      role=\"presentation\"\r\n      aria-hidden=\"true\"\r\n      className={cn(\"[&>svg]:size-3.5\", className)}\r\n      {...props}\r\n    >\r\n      {children ?? <HugeiconsIcon icon={ArrowRight01Icon} strokeWidth={2} />}\r\n    </li>\r\n  );\r\n}\r\n\r\nfunction BreadcrumbEllipsis({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"breadcrumb-ellipsis\"\r\n      role=\"presentation\"\r\n      aria-hidden=\"true\"\r\n      className={cn(\r\n        \"size-5 [&>svg]:size-4 flex items-center justify-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon icon={MoreHorizontalCircle01Icon} strokeWidth={2} />\r\n      <span className=\"sr-only\">More</span>\r\n    </span>\r\n  );\r\n}\r\n\r\nexport {\r\n  Breadcrumb,\r\n  BreadcrumbList,\r\n  BreadcrumbItem,\r\n  BreadcrumbLink,\r\n  BreadcrumbPage,\r\n  BreadcrumbSeparator,\r\n  BreadcrumbEllipsis,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/button.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport { Slot } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nexport const buttonVariants = cva(\r\n  \"focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 rounded-4xl border border-transparent bg-clip-padding text-sm font-medium focus-visible:ring-[3px] aria-invalid:ring-[3px] [&_svg:not([class*='size-'])]:size-4 inline-flex items-center justify-center whitespace-nowrap transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none shrink-0 [&_svg]:shrink-0 outline-none group/button select-none cursor-pointer\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"bg-primary text-primary-foreground hover:bg-primary/80\",\r\n        dark: \"bg-foreground text-background hover:bg-foreground/85 dark:bg-foreground dark:text-background\",\r\n        outline:\r\n          \"border-border bg-input/30 hover:bg-input/50 hover:text-foreground aria-expanded:bg-muted aria-expanded:text-foreground\",\r\n        secondary:\r\n          \"bg-secondary text-secondary-foreground hover:bg-secondary/80 aria-expanded:bg-secondary aria-expanded:text-secondary-foreground\",\r\n        ghost:\r\n          \"hover:bg-muted hover:text-foreground dark:hover:bg-muted/50 aria-expanded:bg-muted aria-expanded:text-foreground\",\r\n        destructive:\r\n          \"bg-destructive/10 hover:bg-destructive/20 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/20 text-destructive focus-visible:border-destructive/40 dark:hover:bg-destructive/30\",\r\n        link: \"text-primary underline-offset-4 hover:underline\",\r\n      },\r\n      size: {\r\n        default:\r\n          \"h-9 gap-1.5 px-3 has-data-[icon=inline-end]:pr-2.5 has-data-[icon=inline-start]:pl-2.5\",\r\n        xs: \"h-6 gap-1 px-2.5 text-xs has-data-[icon=inline-end]:pr-2 has-data-[icon=inline-start]:pl-2 [&_svg:not([class*='size-'])]:size-3\",\r\n        sm: \"h-8 gap-1 px-3 has-data-[icon=inline-end]:pr-2 has-data-[icon=inline-start]:pl-2\",\r\n        lg: \"h-10 gap-1.5 px-4 has-data-[icon=inline-end]:pr-3 has-data-[icon=inline-start]:pl-3\",\r\n        icon: \"size-9\",\r\n        \"icon-xs\": \"size-6 [&_svg:not([class*='size-'])]:size-3\",\r\n        \"icon-sm\": \"size-8\",\r\n        \"icon-lg\": \"size-10\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n      size: \"default\",\r\n    },\r\n  },\r\n);\r\n\r\nexport function Button({\r\n  className,\r\n  variant = \"default\",\r\n  size = \"default\",\r\n  asChild = false,\r\n  ...props\r\n}: React.ComponentProps<\"button\"> &\r\n  VariantProps<typeof buttonVariants> & {\r\n    asChild?: boolean;\r\n  }): React.ReactElement {\r\n  const Comp = asChild ? Slot.Root : \"button\";\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"button\"\r\n      data-variant={variant}\r\n      data-size={size}\r\n      className={cn(buttonVariants({ variant, size, className }))}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/calendar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport * as React from \"react\";\r\nimport {\r\n  type DayButton,\r\n  DayPicker,\r\n  getDefaultClassNames,\r\n} from \"react-day-picker\";\r\n\r\nimport { Button, buttonVariants } from \"@/components/ui/button\";\r\nimport { cn } from \"@/lib/utils\";\r\nimport {\r\n  ArrowDownIcon,\r\n  ArrowLeftIcon,\r\n  ArrowRightIcon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Calendar({\r\n  className,\r\n  classNames,\r\n  showOutsideDays = true,\r\n  captionLayout = \"label\",\r\n  buttonVariant = \"ghost\",\r\n  formatters,\r\n  components,\r\n  ...props\r\n}: React.ComponentProps<typeof DayPicker> & {\r\n  buttonVariant?: React.ComponentProps<typeof Button>[\"variant\"];\r\n}) {\r\n  const defaultClassNames = getDefaultClassNames();\r\n\r\n  return (\r\n    <DayPicker\r\n      showOutsideDays={showOutsideDays}\r\n      className={cn(\r\n        \"p-3 [--cell-radius:var(--radius-4xl)] [--cell-size:--spacing(8)] bg-background group/calendar [[data-slot=card-content]_&]:bg-transparent [[data-slot=popover-content]_&]:bg-transparent\",\r\n        String.raw`rtl:**:[.rdp-button\\_next>svg]:rotate-180`,\r\n        String.raw`rtl:**:[.rdp-button\\_previous>svg]:rotate-180`,\r\n        className,\r\n      )}\r\n      captionLayout={captionLayout}\r\n      formatters={{\r\n        formatMonthDropdown: (date) =>\r\n          date.toLocaleString(\"default\", { month: \"short\" }),\r\n        ...formatters,\r\n      }}\r\n      classNames={{\r\n        root: cn(\"w-fit\", defaultClassNames.root),\r\n        months: cn(\r\n          \"flex gap-4 flex-col md:flex-row relative\",\r\n          defaultClassNames.months,\r\n        ),\r\n        month: cn(\"flex flex-col w-full gap-4\", defaultClassNames.month),\r\n        nav: cn(\r\n          \"flex items-center gap-1 w-full absolute top-0 inset-x-0 justify-between\",\r\n          defaultClassNames.nav,\r\n        ),\r\n        button_previous: cn(\r\n          buttonVariants({ variant: buttonVariant }),\r\n          \"size-(--cell-size) aria-disabled:opacity-50 p-0 select-none\",\r\n          defaultClassNames.button_previous,\r\n        ),\r\n        button_next: cn(\r\n          buttonVariants({ variant: buttonVariant }),\r\n          \"size-(--cell-size) aria-disabled:opacity-50 p-0 select-none\",\r\n          defaultClassNames.button_next,\r\n        ),\r\n        month_caption: cn(\r\n          \"flex items-center justify-center h-(--cell-size) w-full px-(--cell-size)\",\r\n          defaultClassNames.month_caption,\r\n        ),\r\n        dropdowns: cn(\r\n          \"w-full flex items-center text-sm font-medium justify-center h-(--cell-size) gap-1.5\",\r\n          defaultClassNames.dropdowns,\r\n        ),\r\n        dropdown_root: cn(\r\n          \"relative cn-calendar-dropdown-root rounded-(--cell-radius)\",\r\n          defaultClassNames.dropdown_root,\r\n        ),\r\n        dropdown: cn(\r\n          \"absolute bg-popover inset-0 opacity-0\",\r\n          defaultClassNames.dropdown,\r\n        ),\r\n        caption_label: cn(\r\n          \"select-none font-medium\",\r\n          captionLayout === \"label\"\r\n            ? 
\"text-sm\"\r\n            : \"cn-calendar-caption-label rounded-(--cell-radius) flex items-center gap-1 text-sm  [&>svg]:text-muted-foreground [&>svg]:size-3.5\",\r\n          defaultClassNames.caption_label,\r\n        ),\r\n        table: \"w-full border-collapse\",\r\n        weekdays: cn(\"flex\", defaultClassNames.weekdays),\r\n        weekday: cn(\r\n          \"text-muted-foreground rounded-(--cell-radius) flex-1 font-normal text-[0.8rem] select-none\",\r\n          defaultClassNames.weekday,\r\n        ),\r\n        week: cn(\"flex w-full mt-2\", defaultClassNames.week),\r\n        week_number_header: cn(\r\n          \"select-none w-(--cell-size)\",\r\n          defaultClassNames.week_number_header,\r\n        ),\r\n        week_number: cn(\r\n          \"text-[0.8rem] select-none text-muted-foreground\",\r\n          defaultClassNames.week_number,\r\n        ),\r\n        day: cn(\r\n          \"relative w-full rounded-(--cell-radius) h-full p-0 text-center [&:last-child[data-selected=true]_button]:rounded-r-(--cell-radius) group/day aspect-square select-none\",\r\n          props.showWeekNumber\r\n            ? \"[&:nth-child(2)[data-selected=true]_button]:rounded-l-(--cell-radius)\"\r\n            : \"[&:first-child[data-selected=true]_button]:rounded-l-(--cell-radius)\",\r\n          defaultClassNames.day,\r\n        ),\r\n        range_start: cn(\r\n          \"rounded-l-(--cell-radius) bg-muted relative after:bg-muted after:absolute after:inset-y-0 after:w-4 after:right-0 -z-0 isolate\",\r\n          defaultClassNames.range_start,\r\n        ),\r\n        range_middle: cn(\"rounded-none\", defaultClassNames.range_middle),\r\n        range_end: cn(\r\n          \"rounded-r-(--cell-radius) bg-muted relative after:bg-muted-200 after:absolute after:inset-y-0 after:w-4 after:left-0 -z-0 isolate\",\r\n          defaultClassNames.range_end,\r\n        ),\r\n        today: cn(\r\n          \"bg-muted text-foreground rounded-(--cell-radius) data-[selected=true]:rounded-none\",\r\n          defaultClassNames.today,\r\n        ),\r\n        outside: cn(\r\n          \"text-muted-foreground aria-selected:text-muted-foreground\",\r\n          defaultClassNames.outside,\r\n        ),\r\n        disabled: cn(\r\n          \"text-muted-foreground opacity-50\",\r\n          defaultClassNames.disabled,\r\n        ),\r\n        hidden: cn(\"invisible\", defaultClassNames.hidden),\r\n        ...classNames,\r\n      }}\r\n      components={{\r\n        Root: ({ className, rootRef, ...props }) => {\r\n          return (\r\n            <div\r\n              data-slot=\"calendar\"\r\n              ref={rootRef}\r\n              className={cn(className)}\r\n              {...props}\r\n            />\r\n          );\r\n        },\r\n        Chevron: ({ className, orientation, ...props }) => {\r\n          if (orientation === \"left\") {\r\n            return (\r\n              <HugeiconsIcon\r\n                icon={ArrowLeftIcon}\r\n                strokeWidth={2}\r\n                className={cn(\"size-4\", className)}\r\n                {...props}\r\n              />\r\n            );\r\n          }\r\n\r\n          if (orientation === \"right\") {\r\n            return (\r\n              <HugeiconsIcon\r\n                icon={ArrowRightIcon}\r\n                strokeWidth={2}\r\n                className={cn(\"size-4\", className)}\r\n                {...props}\r\n              />\r\n            );\r\n          }\r\n\r\n          return (\r\n            <HugeiconsIcon\r\n              
icon={ArrowDownIcon}\r\n              strokeWidth={2}\r\n              className={cn(\"size-4\", className)}\r\n              {...props}\r\n            />\r\n          );\r\n        },\r\n        DayButton: CalendarDayButton,\r\n        WeekNumber: ({ children, ...props }) => {\r\n          return (\r\n            <td {...props}>\r\n              <div className=\"flex size-(--cell-size) items-center justify-center text-center\">\r\n                {children}\r\n              </div>\r\n            </td>\r\n          );\r\n        },\r\n        ...components,\r\n      }}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CalendarDayButton({\r\n  className,\r\n  day,\r\n  modifiers,\r\n  ...props\r\n}: React.ComponentProps<typeof DayButton>) {\r\n  const defaultClassNames = getDefaultClassNames();\r\n\r\n  const ref = React.useRef<HTMLButtonElement>(null);\r\n  React.useEffect(() => {\r\n    if (modifiers.focused) ref.current?.focus();\r\n  }, [modifiers.focused]);\r\n\r\n  return (\r\n    <Button\r\n      ref={ref}\r\n      variant=\"ghost\"\r\n      size=\"icon\"\r\n      data-day={day.date.toLocaleDateString()}\r\n      data-selected-single={\r\n        modifiers.selected &&\r\n        !modifiers.range_start &&\r\n        !modifiers.range_end &&\r\n        !modifiers.range_middle\r\n      }\r\n      data-range-start={modifiers.range_start}\r\n      data-range-end={modifiers.range_end}\r\n      data-range-middle={modifiers.range_middle}\r\n      className={cn(\r\n        \"data-[selected-single=true]:bg-primary data-[selected-single=true]:text-primary-foreground data-[range-middle=true]:bg-muted data-[range-middle=true]:text-foreground data-[range-start=true]:bg-primary data-[range-start=true]:text-primary-foreground data-[range-end=true]:bg-primary data-[range-end=true]:text-primary-foreground group-data-[focused=true]/day:border-ring group-data-[focused=true]/day:ring-ring/50 dark:hover:text-foreground relative isolate z-10 flex aspect-square size-auto w-full min-w-(--cell-size) flex-col gap-1 border-0 leading-none font-normal group-data-[focused=true]/day:relative group-data-[focused=true]/day:z-10 group-data-[focused=true]/day:ring-[3px] data-[range-end=true]:rounded-(--cell-radius) data-[range-end=true]:rounded-r-(--cell-radius) data-[range-middle=true]:rounded-none data-[range-start=true]:rounded-(--cell-radius) data-[range-start=true]:rounded-l-(--cell-radius) [&>span]:text-xs [&>span]:opacity-70\",\r\n        defaultClassNames.day,\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Calendar, CalendarDayButton };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Card({\r\n  className,\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & { size?: \"default\" | \"sm\" }) {\r\n  return (\r\n    <div\r\n      data-slot=\"card\"\r\n      data-size={size}\r\n      className={cn(\r\n        \"ring-foreground/10 bg-card corner-squircle text-card-foreground gap-6 overflow-hidden rounded-4xl py-6 text-sm ring-1 has-[>img:first-child]:pt-0 data-[size=sm]:gap-4 data-[size=sm]:py-4 *:[img:first-child]:rounded-t-xl *:[img:last-child]:rounded-b-xl group/card  flex flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardHeader({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-header\"\r\n      className={cn(\r\n        \"gap-2 rounded-t-xl px-6 group-data-[size=sm]/card:px-4 [.border-b]:pb-6 group-data-[size=sm]/card:[.border-b]:pb-4 group/card-header @container/card-header grid auto-rows-min items-start has-data-[slot=card-action]:grid-cols-[1fr_auto] has-data-[slot=card-description]:grid-rows-[auto_auto]\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardTitle({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-title\"\r\n      className={cn(\"text-base font-medium\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardDescription({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-description\"\r\n      className={cn(\"text-muted-foreground text-sm\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardAction({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-action\"\r\n      className={cn(\r\n        \"col-start-2 row-span-2 row-start-1 self-start justify-self-end\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardContent({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-content\"\r\n      className={cn(\"px-6 group-data-[size=sm]/card:px-4\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CardFooter({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"card-footer\"\r\n      className={cn(\r\n        \"rounded-b-xl px-6 group-data-[size=sm]/card:px-4 [.border-t]:pt-6 group-data-[size=sm]/card:[.border-t]:pt-4 flex items-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Card,\r\n  CardHeader,\r\n  CardFooter,\r\n  CardTitle,\r\n  CardAction,\r\n  CardDescription,\r\n  CardContent,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/chart.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport * as React from \"react\";\r\nimport * as RechartsPrimitive from \"recharts\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\n// Format: { THEME_NAME: CSS_SELECTOR }\r\nconst THEMES = { light: \"\", dark: \".dark\" } as const;\r\n\r\nexport type ChartConfig = {\r\n  [k in string]: {\r\n    label?: React.ReactNode;\r\n    icon?: React.ComponentType;\r\n  } & (\r\n    | { color?: string; theme?: never }\r\n    | { color?: never; theme: Record<keyof typeof THEMES, string> }\r\n  );\r\n};\r\n\r\ntype ChartContextProps = {\r\n  config: ChartConfig;\r\n};\r\n\r\nconst ChartContext = React.createContext<ChartContextProps | null>(null);\r\n\r\nfunction useChart() {\r\n  const context = React.useContext(ChartContext);\r\n\r\n  if (!context) {\r\n    throw new Error(\"useChart must be used within a <ChartContainer />\");\r\n  }\r\n\r\n  return context;\r\n}\r\n\r\nfunction ChartContainer({\r\n  id,\r\n  className,\r\n  children,\r\n  config,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  config: ChartConfig;\r\n  children: React.ComponentProps<\r\n    typeof RechartsPrimitive.ResponsiveContainer\r\n  >[\"children\"];\r\n}) {\r\n  const uniqueId = React.useId();\r\n  const chartId = `chart-${id || uniqueId.replace(/:/g, \"\")}`;\r\n  const containerRef = React.useRef<HTMLDivElement | null>(null);\r\n  const [containerSize, setContainerSize] = React.useState<{\r\n    width: number;\r\n    height: number;\r\n  } | null>(null);\r\n\r\n  React.useEffect(() => {\r\n    const element = containerRef.current;\r\n    if (!element) return;\r\n\r\n    const updateSizeState = () => {\r\n      const { width, height } = element.getBoundingClientRect();\r\n      const nextSize =\r\n        width > 0 && height > 0\r\n          ? 
{\r\n              width: Math.round(width),\r\n              height: Math.round(height),\r\n            }\r\n          : null;\r\n\r\n      setContainerSize((currentSize) => {\r\n        if (!nextSize) {\r\n          // Keep the last valid size once mounted to avoid unmount/remount thrash.\r\n          return currentSize;\r\n        }\r\n        if (\r\n          currentSize &&\r\n          currentSize.width === nextSize.width &&\r\n          currentSize.height === nextSize.height\r\n        ) {\r\n          return currentSize;\r\n        }\r\n        return nextSize;\r\n      });\r\n    };\r\n\r\n    updateSizeState();\r\n\r\n    if (typeof ResizeObserver === \"undefined\") {\r\n      const recheckSize = () => {\r\n        if (document.visibilityState === \"visible\") {\r\n          updateSizeState();\r\n        }\r\n      };\r\n\r\n      window.addEventListener(\"resize\", recheckSize);\r\n      window.addEventListener(\"orientationchange\", recheckSize);\r\n      document.addEventListener(\"visibilitychange\", recheckSize);\r\n\r\n      return () => {\r\n        window.removeEventListener(\"resize\", recheckSize);\r\n        window.removeEventListener(\"orientationchange\", recheckSize);\r\n        document.removeEventListener(\"visibilitychange\", recheckSize);\r\n      };\r\n    }\r\n\r\n    const observer = new ResizeObserver(() => {\r\n      updateSizeState();\r\n    });\r\n    observer.observe(element);\r\n\r\n    return () => observer.disconnect();\r\n  }, []);\r\n\r\n  return (\r\n    <ChartContext.Provider value={{ config }}>\r\n      <div\r\n        ref={containerRef}\r\n        data-slot=\"chart\"\r\n        data-chart={chartId}\r\n        className={cn(\r\n          \"[&_.recharts-cartesian-axis-tick_text]:fill-muted-foreground [&_.recharts-cartesian-grid_line[stroke='#ccc']]:stroke-border/50 [&_.recharts-curve.recharts-tooltip-cursor]:stroke-border [&_.recharts-polar-grid_[stroke='#ccc']]:stroke-border [&_.recharts-radial-bar-background-sector]:fill-muted [&_.recharts-rectangle.recharts-tooltip-cursor]:fill-muted [&_.recharts-reference-line_[stroke='#ccc']]:stroke-border flex min-w-0 aspect-video justify-center text-xs [&_.recharts-dot[stroke='#fff']]:stroke-transparent [&_.recharts-layer]:outline-hidden [&_.recharts-sector]:outline-hidden [&_.recharts-sector[stroke='#fff']]:stroke-transparent [&_.recharts-surface]:outline-hidden\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      >\r\n        <ChartStyle id={chartId} config={config} />\r\n        {containerSize ? 
(\r\n          <RechartsPrimitive.ResponsiveContainer\r\n            width=\"100%\"\r\n            height=\"100%\"\r\n            minWidth={0}\r\n            minHeight={1}\r\n            initialDimension={containerSize}\r\n          >\r\n            {children}\r\n          </RechartsPrimitive.ResponsiveContainer>\r\n        ) : null}\r\n      </div>\r\n    </ChartContext.Provider>\r\n  );\r\n}\r\n\r\nconst ChartStyle = ({ id, config }: { id: string; config: ChartConfig }) => {\r\n  const colorConfig = Object.entries(config).filter(\r\n    ([, config]) => config.theme || config.color,\r\n  );\r\n\r\n  if (!colorConfig.length) {\r\n    return null;\r\n  }\r\n\r\n  return (\r\n    <style\r\n      dangerouslySetInnerHTML={{\r\n        __html: Object.entries(THEMES)\r\n          .map(\r\n            ([theme, prefix]) => `\r\n${prefix} [data-chart=${id}] {\r\n${colorConfig\r\n  .map(([key, itemConfig]) => {\r\n    const color =\r\n      itemConfig.theme?.[theme as keyof typeof itemConfig.theme] ||\r\n      itemConfig.color;\r\n    return color ? `  --color-${key}: ${color};` : null;\r\n  })\r\n  .join(\"\\n\")}\r\n}\r\n`,\r\n          )\r\n          .join(\"\\n\"),\r\n      }}\r\n    />\r\n  );\r\n};\r\n\r\nconst ChartTooltip = RechartsPrimitive.Tooltip;\r\n\r\nfunction ChartTooltipContent({\r\n  active,\r\n  payload,\r\n  className,\r\n  indicator = \"dot\",\r\n  hideLabel = false,\r\n  hideIndicator = false,\r\n  label,\r\n  labelFormatter,\r\n  labelClassName,\r\n  formatter,\r\n  color,\r\n  nameKey,\r\n  labelKey,\r\n}: Partial<RechartsPrimitive.TooltipContentProps<any, any>> &\r\n  React.ComponentProps<\"div\"> & {\r\n    hideLabel?: boolean;\r\n    hideIndicator?: boolean;\r\n    indicator?: \"line\" | \"dot\" | \"dashed\";\r\n    nameKey?: string;\r\n    labelKey?: string;\r\n  }) {\r\n  const { config } = useChart();\r\n\r\n  const tooltipLabel = React.useMemo(() => {\r\n    if (hideLabel || !payload?.length) {\r\n      return null;\r\n    }\r\n\r\n    const [item] = payload;\r\n    const key = `${labelKey || item?.dataKey || item?.name || \"value\"}`;\r\n    const itemConfig = getPayloadConfigFromPayload(config, item, key);\r\n    const value =\r\n      !labelKey && typeof label === \"string\"\r\n        ? config[label as keyof typeof config]?.label || label\r\n        : itemConfig?.label;\r\n\r\n    if (labelFormatter) {\r\n      return (\r\n        <div className={cn(\"font-medium\", labelClassName)}>\r\n          {labelFormatter(value, payload)}\r\n        </div>\r\n      );\r\n    }\r\n\r\n    if (!value) {\r\n      return null;\r\n    }\r\n\r\n    return <div className={cn(\"font-medium\", labelClassName)}>{value}</div>;\r\n  }, [\r\n    label,\r\n    labelFormatter,\r\n    payload,\r\n    hideLabel,\r\n    labelClassName,\r\n    config,\r\n    labelKey,\r\n  ]);\r\n\r\n  if (!active || !payload?.length) {\r\n    return null;\r\n  }\r\n\r\n  const nestLabel = payload.length === 1 && indicator !== \"dot\";\r\n\r\n  return (\r\n    <div\r\n      className={cn(\r\n        \"border-border/50 corner-squircle bg-background gap-1.5 rounded-lg border px-2.5 py-1.5 text-xs shadow-xl grid min-w-[8rem] items-start\",\r\n        className,\r\n      )}\r\n    >\r\n      {!nestLabel ? 
tooltipLabel : null}\r\n      <div className=\"grid gap-1.5\">\r\n        {payload\n          .filter((item) => item.type !== \"none\")\n          .map((item, index) => {\n            const key = `${nameKey || item.name || item.dataKey || \"value\"}`;\n            const itemConfig = getPayloadConfigFromPayload(config, item, key);\n            const indicatorColor = color || item.payload.fill || item.color;\n            let customContent: React.ReactNode = null;\n            let formattedValue: React.ReactNode =\n              item.value != null && typeof item.value !== \"object\"\n                ? String(item.value)\n                : item.value;\n            let formattedLabel: React.ReactNode = itemConfig?.label || item.name;\n\n            if (formatter && item?.value !== undefined && item.name) {\n              const result = formatter(\n                item.value,\n                item.name,\n                item,\n                index,\n                item.payload,\n              );\n\n              if (Array.isArray(result)) {\n                formattedValue = result[0];\n                formattedLabel = result[1];\n              } else {\n                customContent = result;\n              }\n            }\n\n            return (\n              <div\n                key={item.dataKey}\n                className={cn(\n                  \"[&>svg]:text-muted-foreground flex w-full flex-wrap items-stretch gap-2 [&>svg]:h-2.5 [&>svg]:w-2.5\",\n                  indicator === \"dot\" && \"items-center\",\n                )}\n              >\n                {customContent ?? (\n                  <>\n                    {itemConfig?.icon ? (\n                      <itemConfig.icon />\n                    ) : (\n                      !hideIndicator && (\n                        <div\n                          className={cn(\n                            \"shrink-0 rounded-[2px] border-(--color-border) bg-(--color-bg)\",\n                            {\n                              \"h-2.5 w-2.5\": indicator === \"dot\",\n                              \"w-1\": indicator === \"line\",\n                              \"w-0 border-[1.5px] border-dashed bg-transparent\":\n                                indicator === \"dashed\",\n                              \"my-0.5\": nestLabel && indicator === \"dashed\",\n                            },\n                          )}\n                          style={\n                            {\n                              \"--color-bg\": indicatorColor,\n                              \"--color-border\": indicatorColor,\n                            } as React.CSSProperties\n                          }\n                        />\n                      )\n                    )}\n                    <div\n                      className={cn(\n                        \"flex flex-1 justify-between gap-3 leading-none\",\n                        nestLabel ? \"items-end\" : \"items-center\",\n                      )}\n                    >\n                      <div className=\"grid gap-1.5\">\n                        {nestLabel ? 
tooltipLabel : null}\n                        <span className=\"text-muted-foreground\">{formattedLabel}</span>\n                      </div>\n                      {formattedValue != null && (\n                        <span className=\"text-foreground font-mono font-medium tabular-nums\">\n                          {formattedValue}\n                        </span>\n                      )}\n                    </div>\n                  </>\n                )}\n              </div>\n            );\n          })}\n      </div>\n    </div>\r\n  );\r\n}\r\n\r\nconst ChartLegend = RechartsPrimitive.Legend;\r\n\r\nfunction ChartLegendContent({\r\n  className,\r\n  hideIcon = false,\r\n  payload,\r\n  verticalAlign = \"bottom\",\r\n  nameKey,\r\n}: React.ComponentProps<\"div\"> &\r\n  Pick<RechartsPrimitive.DefaultLegendContentProps, \"payload\" | \"verticalAlign\"> & {\r\n    hideIcon?: boolean;\r\n    nameKey?: string;\r\n  }) {\r\n  const { config } = useChart();\r\n\r\n  if (!payload?.length) {\r\n    return null;\r\n  }\r\n\r\n  return (\r\n    <div\r\n      className={cn(\r\n        \"flex items-center justify-center gap-4\",\r\n        verticalAlign === \"top\" ? \"pb-3\" : \"pt-3\",\r\n        className,\r\n      )}\r\n    >\r\n      {payload\r\n        .filter((item) => item.type !== \"none\")\r\n        .map((item) => {\r\n          const key = `${nameKey || item.dataKey || \"value\"}`;\r\n          const itemConfig = getPayloadConfigFromPayload(config, item, key);\r\n\r\n          return (\r\n            <div\r\n              key={item.value}\r\n              className={cn(\r\n                \"[&>svg]:text-muted-foreground flex items-center gap-1.5 [&>svg]:h-3 [&>svg]:w-3\",\r\n              )}\r\n            >\r\n              {itemConfig?.icon && !hideIcon ? (\r\n                <itemConfig.icon />\r\n              ) : (\r\n                <div\r\n                  className=\"h-2 w-2 shrink-0 rounded-[2px]\"\r\n                  style={{\r\n                    backgroundColor: item.color,\r\n                  }}\r\n                />\r\n              )}\r\n              {itemConfig?.label}\r\n            </div>\r\n          );\r\n        })}\r\n    </div>\r\n  );\r\n}\r\n\r\nfunction getPayloadConfigFromPayload(\r\n  config: ChartConfig,\r\n  payload: unknown,\r\n  key: string,\r\n) {\r\n  if (typeof payload !== \"object\" || payload === null) {\r\n    return undefined;\r\n  }\r\n\r\n  const payloadPayload =\r\n    \"payload\" in payload &&\r\n    typeof payload.payload === \"object\" &&\r\n    payload.payload !== null\r\n      ? payload.payload\r\n      : undefined;\r\n\r\n  let configLabelKey: string = key;\r\n\r\n  if (\r\n    key in payload &&\r\n    typeof payload[key as keyof typeof payload] === \"string\"\r\n  ) {\r\n    configLabelKey = payload[key as keyof typeof payload] as string;\r\n  } else if (\r\n    payloadPayload &&\r\n    key in payloadPayload &&\r\n    typeof payloadPayload[key as keyof typeof payloadPayload] === \"string\"\r\n  ) {\r\n    configLabelKey = payloadPayload[\r\n      key as keyof typeof payloadPayload\r\n    ] as string;\r\n  }\r\n\r\n  return configLabelKey in config\r\n    ? config[configLabelKey]\r\n    : config[key as keyof typeof config];\r\n}\r\n\r\nexport {\r\n  ChartContainer,\r\n  ChartTooltip,\r\n  ChartTooltipContent,\r\n  ChartLegend,\r\n  ChartLegendContent,\r\n  ChartStyle,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/checkbox.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Checkbox as CheckboxPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { Tick02Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Checkbox({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CheckboxPrimitive.Root>) {\r\n  return (\r\n    <CheckboxPrimitive.Root\r\n      data-slot=\"checkbox\"\r\n      className={cn(\r\n        \"border-input dark:bg-input/30 data-checked:bg-primary data-checked:text-primary-foreground dark:data-checked:bg-primary data-checked:border-primary aria-invalid:aria-checked:border-primary aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 flex size-4 items-center justify-center rounded-[6px] border transition-shadow group-has-disabled/field:opacity-50 focus-visible:ring-[3px] aria-invalid:ring-[3px] peer relative shrink-0 outline-none after:absolute after:-inset-x-3 after:-inset-y-2 disabled:cursor-not-allowed disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <CheckboxPrimitive.Indicator\r\n        data-slot=\"checkbox-indicator\"\r\n        className=\"[&>svg]:size-3.5 grid place-content-center text-current transition-none\"\r\n      >\r\n        <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n      </CheckboxPrimitive.Indicator>\r\n    </CheckboxPrimitive.Root>\r\n  );\r\n}\r\n\r\nexport { Checkbox };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/collapsible.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport { Collapsible as CollapsiblePrimitive } from \"radix-ui\";\n\nfunction Collapsible({\n  ...props\n}: React.ComponentProps<typeof CollapsiblePrimitive.Root>) {\n  return <CollapsiblePrimitive.Root data-slot=\"collapsible\" {...props} />;\n}\n\nfunction CollapsibleTrigger({\n  ...props\n}: React.ComponentProps<typeof CollapsiblePrimitive.CollapsibleTrigger>) {\n  return (\n    <CollapsiblePrimitive.CollapsibleTrigger\n      data-slot=\"collapsible-trigger\"\n      {...props}\n    />\n  );\n}\n\nfunction CollapsibleContent({\n  className,\n  ...props\n}: React.ComponentProps<typeof CollapsiblePrimitive.CollapsibleContent>) {\n  return (\n    <CollapsiblePrimitive.CollapsibleContent\n      data-slot=\"collapsible-content\"\n      className={cn(\n        \"overflow-hidden data-[state=open]:animate-collapsible-down data-[state=closed]:animate-collapsible-up [--duration:150ms]\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nexport { Collapsible, CollapsibleTrigger, CollapsibleContent };\n"
  },
  {
    "path": "studio/frontend/src/components/ui/combobox.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { Combobox as ComboboxPrimitive } from \"@base-ui/react\";\r\nimport * as React from \"react\";\r\nimport { createContext, useContext, useState } from \"react\";\r\n\r\nimport { Button } from \"@/components/ui/button\";\r\nimport { useDialogPortalContainer } from \"@/components/ui/dialog\";\r\nimport {\r\n  InputGroup,\r\n  InputGroupAddon,\r\n  InputGroupButton,\r\n  InputGroupInput,\r\n} from \"@/components/ui/input-group\";\r\nimport { cn } from \"@/lib/utils\";\r\nimport {\r\n  ArrowDown01Icon,\r\n  Cancel01Icon,\r\n  Tick02Icon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nconst ComboboxOpenContext = createContext(false);\r\ntype ComboboxRootProps = ComboboxPrimitive.Root.Props<string, false>;\r\n\r\nfunction Combobox({\r\n  onOpenChange,\r\n  children,\r\n  ...props\r\n}: ComboboxRootProps): React.ReactElement {\r\n  const [isOpen, setIsOpen] = useState(false);\r\n  return (\r\n    <ComboboxOpenContext.Provider value={isOpen}>\r\n      <ComboboxPrimitive.Root\r\n        onOpenChange={(open, eventDetails) => {\r\n          setIsOpen(open);\r\n          onOpenChange?.(open, eventDetails);\r\n        }}\r\n        {...props}\r\n      >\r\n        {children}\r\n      </ComboboxPrimitive.Root>\r\n    </ComboboxOpenContext.Provider>\r\n  );\r\n}\r\n\r\nfunction ComboboxValue({\r\n  ...props\r\n}: ComboboxPrimitive.Value.Props): React.ReactElement {\r\n  return <ComboboxPrimitive.Value data-slot=\"combobox-value\" {...props} />;\r\n}\r\n\r\nfunction ComboboxTrigger({\r\n  className,\r\n  children,\r\n  ...props\r\n}: ComboboxPrimitive.Trigger.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Trigger\r\n      data-slot=\"combobox-trigger\"\r\n      className={cn(\"[&_svg:not([class*='size-'])]:size-4\", className)}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <HugeiconsIcon\r\n        icon={ArrowDown01Icon}\r\n        strokeWidth={2}\r\n        className=\"text-muted-foreground size-4 pointer-events-none\"\r\n      />\r\n    </ComboboxPrimitive.Trigger>\r\n  );\r\n}\r\n\r\nfunction ComboboxClear({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.Clear.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Clear\r\n      data-slot=\"combobox-clear\"\r\n      render={<InputGroupButton variant=\"ghost\" size=\"icon-xs\" />}\r\n      className={cn(className)}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon\r\n        icon={Cancel01Icon}\r\n        strokeWidth={2}\r\n        className=\"pointer-events-none\"\r\n      />\r\n    </ComboboxPrimitive.Clear>\r\n  );\r\n}\r\n\r\nfunction ComboboxInput({\r\n  className,\r\n  children,\r\n  disabled = false,\r\n  showTrigger = true,\r\n  showClear = false,\r\n  ...props\r\n}: ComboboxPrimitive.Input.Props & {\r\n  showTrigger?: boolean;\r\n  showClear?: boolean;\r\n}): React.ReactElement {\r\n  const isOpen = useContext(ComboboxOpenContext);\r\n\r\n  return (\r\n    <InputGroup\r\n      className={cn(\"w-auto\", className)}\r\n      style={{\r\n        borderRadius: isOpen ? \"12px\" : undefined,\r\n        transition: isOpen\r\n          ? 
\"border-radius 0ms\"\r\n          : \"border-radius 150ms cubic-bezier(0.645, 0.045, 0.355, 1)\",\r\n      }}\r\n    >\r\n      <ComboboxPrimitive.Input\r\n        render={<InputGroupInput disabled={disabled} />}\r\n        {...props}\r\n      />\r\n      <InputGroupAddon align=\"inline-end\">\r\n        {showTrigger && (\r\n          <InputGroupButton\r\n            size=\"icon-xs\"\r\n            variant=\"ghost\"\r\n            asChild\r\n            data-slot=\"input-group-button\"\r\n            className=\"group-has-data-[slot=combobox-clear]/input-group:hidden data-pressed:bg-transparent\"\r\n            disabled={disabled}\r\n          >\r\n            <ComboboxTrigger />\r\n          </InputGroupButton>\r\n        )}\r\n        {showClear && <ComboboxClear disabled={disabled} />}\r\n      </InputGroupAddon>\r\n      {children}\r\n    </InputGroup>\r\n  );\r\n}\r\n\r\nfunction ComboboxContent({\n  className,\n  side = \"bottom\",\n  sideOffset = 6,\n  align = \"start\",\r\n  alignOffset = 0,\r\n  anchor,\r\n  container,\r\n  ...props\r\n}: ComboboxPrimitive.Popup.Props &\r\n  Pick<\r\n    ComboboxPrimitive.Positioner.Props,\r\n    \"side\" | \"align\" | \"sideOffset\" | \"alignOffset\" | \"anchor\"\r\n  > & {\r\n    container?: HTMLElement | null;\r\n  }): React.ReactElement {\r\n  const dialogContainer = useDialogPortalContainer();\r\n  return (\r\n    <ComboboxPrimitive.Portal container={container ?? dialogContainer ?? undefined}>\r\n      <ComboboxPrimitive.Positioner\n        side={side}\n        sideOffset={sideOffset}\n        align={align}\n        alignOffset={alignOffset}\n        anchor={anchor}\n        className=\"isolate z-[120] pointer-events-auto\"\n      >\n        <ComboboxPrimitive.Popup\n          data-slot=\"combobox-content\"\n          data-chips={!!anchor}\n          className={cn(\n            \"bg-popover text-popover-foreground data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 shadow-border ring-1 ring-border *:data-[slot=input-group]:bg-input/30 max-h-72 min-w-36 overflow-hidden rounded-xl corner-squircle duration-100 *:data-[slot=input-group]:m-1 *:data-[slot=input-group]:mb-0 *:data-[slot=input-group]:h-9 *:data-[slot=input-group]:border-none *:data-[slot=input-group]:shadow-none group/combobox-content relative pointer-events-auto max-h-(--available-height) w-(--anchor-width) max-w-(--available-width) min-w-[calc(var(--anchor-width)+--spacing(7))] origin-(--transform-origin) data-[chips=true]:min-w-(--anchor-width)\",\n            className,\n          )}\n          {...props}\n        />\n      </ComboboxPrimitive.Positioner>\r\n    </ComboboxPrimitive.Portal>\r\n  );\r\n}\r\n\r\nfunction ComboboxList({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.List.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.List\r\n      data-slot=\"combobox-list\"\r\n      className={cn(\r\n        \"no-scrollbar max-h-[min(calc(--spacing(72)---spacing(9)),calc(var(--available-height)---spacing(9)))] scroll-py-1 overflow-y-auto p-1 data-empty:p-0 overflow-y-auto overscroll-contain\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ComboboxItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: ComboboxPrimitive.Item.Props): React.ReactElement {\r\n  return (\r\n    
<ComboboxPrimitive.Item\r\n      data-slot=\"combobox-item\"\r\n      className={cn(\r\n        \"data-highlighted:bg-accent data-highlighted:text-accent-foreground not-data-[variant=destructive]:data-highlighted:**:text-accent-foreground gap-2 rounded-xl corner-squircle py-2 pr-2 pl-3 text-sm [&[aria-selected=true]]:pr-7 [&_svg:not([class*='size-'])]:size-4 relative flex w-full cursor-pointer items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <ComboboxPrimitive.ItemIndicator\r\n        render={\r\n          <span className=\"pointer-events-none absolute right-2 flex size-4 items-center justify-center\" />\r\n        }\r\n      >\r\n        <HugeiconsIcon\r\n          icon={Tick02Icon}\r\n          strokeWidth={2}\r\n          className=\"pointer-events-none\"\r\n        />\r\n      </ComboboxPrimitive.ItemIndicator>\r\n    </ComboboxPrimitive.Item>\r\n  );\r\n}\r\n\r\nfunction ComboboxGroup({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.Group.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Group\r\n      data-slot=\"combobox-group\"\r\n      className={cn(className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ComboboxLabel({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.GroupLabel.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.GroupLabel\r\n      data-slot=\"combobox-label\"\r\n      className={cn(\"text-muted-foreground px-3.5 py-2.5 text-xs\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ComboboxCollection({\r\n  ...props\r\n}: ComboboxPrimitive.Collection.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Collection data-slot=\"combobox-collection\" {...props} />\r\n  );\r\n}\r\n\r\nfunction ComboboxEmpty({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.Empty.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Empty\r\n      data-slot=\"combobox-empty\"\r\n      className={cn(\r\n        \"text-muted-foreground hidden w-full justify-center py-2 text-center text-sm group-data-empty/combobox-content:flex\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ComboboxSeparator({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.Separator.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Separator\r\n      data-slot=\"combobox-separator\"\r\n      className={cn(\"bg-border/50 -mx-1 my-1 h-px\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ComboboxChips({\r\n  className,\r\n  ...props\r\n}: React.ComponentPropsWithRef<typeof ComboboxPrimitive.Chips> &\r\n  ComboboxPrimitive.Chips.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Chips\r\n      data-slot=\"combobox-chips\"\r\n      className={cn(\r\n        \"bg-input/30 border-input focus-within:border-ring focus-within:ring-ring/50 has-aria-invalid:ring-destructive/20 dark:has-aria-invalid:ring-destructive/40 has-aria-invalid:border-destructive dark:has-aria-invalid:border-destructive/50 flex min-h-9 flex-wrap items-center gap-1.5 rounded-4xl border bg-clip-padding px-2.5 py-1.5 text-sm transition-colors focus-within:ring-[3px] has-aria-invalid:ring-[3px] has-data-[slot=combobox-chip]:px-1.5\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction 
ComboboxChip({\r\n  className,\r\n  children,\r\n  showRemove = true,\r\n  ...props\r\n}: ComboboxPrimitive.Chip.Props & {\r\n  showRemove?: boolean;\r\n}): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Chip\r\n      data-slot=\"combobox-chip\"\r\n      className={cn(\r\n        \"bg-muted-foreground/10 text-foreground flex h-[calc(--spacing(5.5))] w-fit items-center justify-center gap-1 rounded-4xl px-2 text-xs font-medium whitespace-nowrap has-data-[slot=combobox-chip-remove]:pr-0 has-disabled:pointer-events-none has-disabled:cursor-not-allowed has-disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      {showRemove && (\r\n        <ComboboxPrimitive.ChipRemove\r\n          render={<Button variant=\"ghost\" size=\"icon-xs\" />}\r\n          className=\"-ml-1 opacity-50 hover:opacity-100\"\r\n          data-slot=\"combobox-chip-remove\"\r\n        >\r\n          <HugeiconsIcon\r\n            icon={Cancel01Icon}\r\n            strokeWidth={2}\r\n            className=\"pointer-events-none\"\r\n          />\r\n        </ComboboxPrimitive.ChipRemove>\r\n      )}\r\n    </ComboboxPrimitive.Chip>\r\n  );\r\n}\r\n\r\nfunction ComboboxChipsInput({\r\n  className,\r\n  ...props\r\n}: ComboboxPrimitive.Input.Props): React.ReactElement {\r\n  return (\r\n    <ComboboxPrimitive.Input\r\n      data-slot=\"combobox-chip-input\"\r\n      className={cn(\"min-w-16 flex-1 outline-none\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction useComboboxAnchor(): React.MutableRefObject<HTMLDivElement | null> {\r\n  return React.useRef<HTMLDivElement | null>(null);\r\n}\r\n\r\nexport {\r\n  Combobox,\r\n  ComboboxInput,\r\n  ComboboxContent,\r\n  ComboboxList,\r\n  ComboboxItem,\r\n  ComboboxGroup,\r\n  ComboboxLabel,\r\n  ComboboxCollection,\r\n  ComboboxEmpty,\r\n  ComboboxSeparator,\r\n  ComboboxChips,\r\n  ComboboxChip,\r\n  ComboboxChipsInput,\r\n  ComboboxTrigger,\r\n  ComboboxValue,\r\n  useComboboxAnchor,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/command.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Command as CommandPrimitive } from \"cmdk\";\r\nimport type * as React from \"react\";\r\n\r\nimport {\r\n  Dialog,\r\n  DialogContent,\r\n  DialogDescription,\r\n  DialogHeader,\r\n  DialogTitle,\r\n} from \"@/components/ui/dialog\";\r\nimport { InputGroup, InputGroupAddon } from \"@/components/ui/input-group\";\r\nimport { cn } from \"@/lib/utils\";\r\nimport { SearchIcon, Tick02Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Command({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive>) {\r\n  return (\r\n    <CommandPrimitive\r\n      data-slot=\"command\"\r\n      className={cn(\r\n        \"bg-popover text-popover-foreground rounded-4xl p-1 flex size-full flex-col overflow-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CommandDialog({\r\n  title = \"Command Palette\",\r\n  description = \"Search for a command to run...\",\r\n  children,\r\n  className,\r\n  showCloseButton = false,\r\n  ...props\r\n}: React.ComponentProps<typeof Dialog> & {\r\n  title?: string;\r\n  description?: string;\r\n  className?: string;\r\n  showCloseButton?: boolean;\r\n}) {\r\n  return (\r\n    <Dialog {...props}>\r\n      <DialogHeader className=\"sr-only\">\r\n        <DialogTitle>{title}</DialogTitle>\r\n        <DialogDescription>{description}</DialogDescription>\r\n      </DialogHeader>\r\n      <DialogContent\r\n        className={cn(\r\n          \"rounded-4xl! p-0 top-1/3 translate-y-0 overflow-hidden p-0\",\r\n          className,\r\n        )}\r\n        showCloseButton={showCloseButton}\r\n      >\r\n        {children}\r\n      </DialogContent>\r\n    </Dialog>\r\n  );\r\n}\r\n\r\nfunction CommandInput({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive.Input>) {\r\n  return (\r\n    <div data-slot=\"command-input-wrapper\" className=\"p-1 pb-0\">\r\n      <InputGroup className=\"bg-input/30 h-9\">\r\n        <CommandPrimitive.Input\r\n          data-slot=\"command-input\"\r\n          className={cn(\r\n            \"w-full text-sm outline-hidden disabled:cursor-not-allowed disabled:opacity-50\",\r\n            className,\r\n          )}\r\n          {...props}\r\n        />\r\n        <InputGroupAddon>\r\n          <HugeiconsIcon\r\n            icon={SearchIcon}\r\n            strokeWidth={2}\r\n            className=\"size-4 shrink-0 opacity-50\"\r\n          />\r\n        </InputGroupAddon>\r\n      </InputGroup>\r\n    </div>\r\n  );\r\n}\r\n\r\nfunction CommandList({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive.List>) {\r\n  return (\r\n    <CommandPrimitive.List\r\n      data-slot=\"command-list\"\r\n      className={cn(\r\n        \"no-scrollbar max-h-72 scroll-py-1 outline-none overflow-x-hidden overflow-y-auto\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CommandEmpty({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive.Empty>) {\r\n  return (\r\n    <CommandPrimitive.Empty\r\n      data-slot=\"command-empty\"\r\n      className={cn(\"py-6 text-center text-sm\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CommandGroup({\r\n  className,\r\n  ...props\r\n}: 
React.ComponentProps<typeof CommandPrimitive.Group>) {\r\n  return (\r\n    <CommandPrimitive.Group\r\n      data-slot=\"command-group\"\r\n      className={cn(\r\n        \"text-foreground [&_[cmdk-group-heading]]:text-muted-foreground overflow-hidden p-1 [&_[cmdk-group-heading]]:px-3 [&_[cmdk-group-heading]]:py-2 [&_[cmdk-group-heading]]:text-xs [&_[cmdk-group-heading]]:font-medium\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CommandSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive.Separator>) {\r\n  return (\r\n    <CommandPrimitive.Separator\r\n      data-slot=\"command-separator\"\r\n      className={cn(\"bg-border/50 my-1 h-px\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction CommandItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof CommandPrimitive.Item>) {\r\n  return (\r\n    <CommandPrimitive.Item\r\n      data-slot=\"command-item\"\r\n      className={cn(\r\n        \"data-selected:bg-muted data-selected:text-foreground data-selected:*:[svg]:text-foreground relative flex cursor-default items-center gap-2 rounded-lg px-3 py-2 text-sm outline-hidden select-none [&_svg:not([class*='size-'])]:size-4 [[data-slot=dialog-content]_&]:rounded-2xl group/command-item data-[disabled=true]:pointer-events-none data-[disabled=true]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <HugeiconsIcon\r\n        icon={Tick02Icon}\r\n        strokeWidth={2}\r\n        className=\"ml-auto opacity-0 group-has-[[data-slot=command-shortcut]]/command-item:hidden group-data-[checked=true]/command-item:opacity-100\"\r\n      />\r\n    </CommandPrimitive.Item>\r\n  );\r\n}\r\n\r\nfunction CommandShortcut({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"command-shortcut\"\r\n      className={cn(\r\n        \"text-muted-foreground group-data-selected/command-item:text-foreground ml-auto text-xs tracking-widest\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Command,\r\n  CommandDialog,\r\n  CommandInput,\r\n  CommandList,\r\n  CommandEmpty,\r\n  CommandGroup,\r\n  CommandItem,\r\n  CommandShortcut,\r\n  CommandSeparator,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/confetti.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\r\n  GlobalOptions as ConfettiGlobalOptions,\r\n  CreateTypes as ConfettiInstance,\r\n  Options as ConfettiOptions,\r\n} from \"canvas-confetti\";\r\nimport confetti from \"canvas-confetti\";\r\nimport type { ReactNode } from \"react\";\r\nimport type React from \"react\";\r\nimport {\r\n  createContext,\r\n  forwardRef,\r\n  useCallback,\r\n  useEffect,\r\n  useImperativeHandle,\r\n  useMemo,\r\n  useRef,\r\n} from \"react\";\r\n\r\nimport { Button } from \"@/components/ui/button\";\r\n\r\ntype Api = {\r\n  fire: (options?: ConfettiOptions) => void;\r\n};\r\n\r\ntype Props = React.ComponentPropsWithRef<\"canvas\"> & {\r\n  options?: ConfettiOptions;\r\n  globalOptions?: ConfettiGlobalOptions;\r\n  manualstart?: boolean;\r\n  children?: ReactNode;\r\n};\r\n\r\nexport type ConfettiRef = Api | null;\r\n\r\nconst ConfettiContext = createContext<Api>({} as Api);\r\n\r\n// Define component first\r\nconst ConfettiComponent = forwardRef<ConfettiRef, Props>((props, ref) => {\r\n  const {\r\n    options,\r\n    globalOptions = { resize: true, useWorker: true },\r\n    manualstart = false,\r\n    children,\r\n    ...rest\r\n  } = props;\r\n  const instanceRef = useRef<ConfettiInstance | null>(null);\r\n\r\n  const canvasRef = useCallback(\r\n    (node: HTMLCanvasElement) => {\r\n      if (node !== null) {\r\n        if (instanceRef.current) return;\r\n        instanceRef.current = confetti.create(node, {\r\n          ...globalOptions,\r\n          resize: true,\r\n        });\r\n      } else {\r\n        if (instanceRef.current) {\r\n          instanceRef.current.reset();\r\n          instanceRef.current = null;\r\n        }\r\n      }\r\n    },\r\n    [globalOptions],\r\n  );\r\n\r\n  const fire = useCallback(\r\n    async (opts = {}) => {\r\n      try {\r\n        await instanceRef.current?.({ ...options, ...opts });\r\n      } catch (error) {\r\n        console.error(\"Confetti error:\", error);\r\n      }\r\n    },\r\n    [options],\r\n  );\r\n\r\n  const api = useMemo(\r\n    () => ({\r\n      fire,\r\n    }),\r\n    [fire],\r\n  );\r\n\r\n  useImperativeHandle(ref, () => api, [api]);\r\n\r\n  useEffect(() => {\r\n    if (!manualstart) {\r\n      (async () => {\r\n        try {\r\n          await fire();\r\n        } catch (error) {\r\n          console.error(\"Confetti effect error:\", error);\r\n        }\r\n      })();\r\n    }\r\n  }, [manualstart, fire]);\r\n\r\n  return (\r\n    <ConfettiContext.Provider value={api}>\r\n      <canvas ref={canvasRef} {...rest} />\r\n      {children}\r\n    </ConfettiContext.Provider>\r\n  );\r\n});\r\n\r\n// Set display name immediately\r\nConfettiComponent.displayName = \"Confetti\";\r\n\r\n// Export as Confetti\r\nexport const Confetti = ConfettiComponent;\r\n\r\ninterface ConfettiButtonProps extends React.ComponentProps<\"button\"> {\r\n  options?: ConfettiOptions &\r\n    ConfettiGlobalOptions & { canvas?: HTMLCanvasElement };\r\n}\r\n\r\nconst ConfettiButtonComponent = ({\r\n  options,\r\n  children,\r\n  ...props\r\n}: ConfettiButtonProps) => {\r\n  const handleClick = async (event: React.MouseEvent<HTMLButtonElement>) => {\r\n    try {\r\n      const rect = event.currentTarget.getBoundingClientRect();\r\n      const x = rect.left + rect.width / 2;\r\n      const y = rect.top + rect.height / 2;\r\n      await confetti({\r\n        ...options,\r\n        origin: {\r\n          
x: x / window.innerWidth,\r\n          y: y / window.innerHeight,\r\n        },\r\n      });\r\n    } catch (error) {\r\n      console.error(\"Confetti button error:\", error);\r\n    }\r\n  };\r\n\r\n  return (\r\n    <Button onClick={handleClick} {...props}>\r\n      {children}\r\n    </Button>\r\n  );\r\n};\r\n\r\nConfettiButtonComponent.displayName = \"ConfettiButton\";\r\n\r\nexport const ConfettiButton = ConfettiButtonComponent;\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/context-menu.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { ContextMenu as ContextMenuPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { ArrowRight01Icon, Tick02Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction ContextMenu({\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Root>) {\r\n  return <ContextMenuPrimitive.Root data-slot=\"context-menu\" {...props} />;\r\n}\r\n\r\nfunction ContextMenuTrigger({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Trigger>) {\r\n  return (\r\n    <ContextMenuPrimitive.Trigger\r\n      data-slot=\"context-menu-trigger\"\r\n      className={cn(\"select-none\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Group>) {\r\n  return (\r\n    <ContextMenuPrimitive.Group data-slot=\"context-menu-group\" {...props} />\r\n  );\r\n}\r\n\r\nfunction ContextMenuPortal({\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Portal>) {\r\n  return (\r\n    <ContextMenuPrimitive.Portal data-slot=\"context-menu-portal\" {...props} />\r\n  );\r\n}\r\n\r\nfunction ContextMenuSub({\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Sub>) {\r\n  return <ContextMenuPrimitive.Sub data-slot=\"context-menu-sub\" {...props} />;\r\n}\r\n\r\nfunction ContextMenuRadioGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.RadioGroup>) {\r\n  return (\r\n    <ContextMenuPrimitive.RadioGroup\r\n      data-slot=\"context-menu-radio-group\"\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Content> & {\r\n  side?: \"top\" | \"right\" | \"bottom\" | \"left\";\r\n}) {\r\n  return (\r\n    <ContextMenuPrimitive.Portal>\r\n      <ContextMenuPrimitive.Content\r\n        data-slot=\"context-menu-content\"\r\n        className={cn(\r\n          \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 ring-foreground/5 bg-popover text-popover-foreground min-w-48 rounded-2xl p-1 shadow-2xl ring-1 duration-100 z-50 max-h-(--radix-context-menu-content-available-height) origin-(--radix-context-menu-content-transform-origin) overflow-x-hidden overflow-y-auto\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </ContextMenuPrimitive.Portal>\r\n  );\r\n}\r\n\r\nfunction ContextMenuItem({\r\n  className,\r\n  inset,\r\n  variant = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Item> & {\r\n  inset?: boolean;\r\n  variant?: \"default\" | \"destructive\";\r\n}) {\r\n  return (\r\n    <ContextMenuPrimitive.Item\r\n      data-slot=\"context-menu-item\"\r\n      data-inset={inset}\r\n      data-variant={variant}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-[variant=destructive]:text-destructive data-[variant=destructive]:focus:bg-destructive/10 
dark:data-[variant=destructive]:focus:bg-destructive/20 data-[variant=destructive]:focus:text-destructive data-[variant=destructive]:*:[svg]:text-destructive focus:*:[svg]:text-accent-foreground gap-2.5 rounded-xl px-3 py-2 text-sm [&_svg:not([class*='size-'])]:size-4 group/context-menu-item relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuSubTrigger({\r\n  className,\r\n  inset,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.SubTrigger> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <ContextMenuPrimitive.SubTrigger\r\n      data-slot=\"context-menu-sub-trigger\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-open:bg-accent data-open:text-accent-foreground rounded-xl px-3 py-2 text-sm [&_svg:not([class*='size-'])]:size-4 flex cursor-default items-center outline-hidden select-none data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <HugeiconsIcon\r\n        icon={ArrowRight01Icon}\r\n        strokeWidth={2}\r\n        className=\"ml-auto\"\r\n      />\r\n    </ContextMenuPrimitive.SubTrigger>\r\n  );\r\n}\r\n\r\nfunction ContextMenuSubContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.SubContent>) {\r\n  return (\r\n    <ContextMenuPrimitive.SubContent\r\n      data-slot=\"context-menu-sub-content\"\r\n      className={cn(\r\n        \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 bg-popover text-popover-foreground min-w-32 rounded-md border p-1 shadow-lg duration-100 z-50 origin-(--radix-context-menu-content-transform-origin) overflow-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuCheckboxItem({\r\n  className,\r\n  children,\r\n  checked,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.CheckboxItem>) {\r\n  return (\r\n    <ContextMenuPrimitive.CheckboxItem\r\n      data-slot=\"context-menu-checkbox-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground gap-2 rounded-xl py-2 pr-8 pl-3 text-sm [&_svg:not([class*='size-'])]:size-4 relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      checked={checked}\r\n      {...props}\r\n    >\r\n      <span className=\"absolute right-2 pointer-events-none\">\r\n        <ContextMenuPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </ContextMenuPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </ContextMenuPrimitive.CheckboxItem>\r\n  );\r\n}\r\n\r\nfunction ContextMenuRadioItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.RadioItem>) {\r\n  return (\r\n    <ContextMenuPrimitive.RadioItem\r\n      
data-slot=\"context-menu-radio-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground gap-2 rounded-sm py-1.5 pr-8 pl-2 text-sm [&_svg:not([class*='size-'])]:size-4 relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <span className=\"absolute right-2 pointer-events-none\">\r\n        <ContextMenuPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </ContextMenuPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </ContextMenuPrimitive.RadioItem>\r\n  );\r\n}\r\n\r\nfunction ContextMenuLabel({\r\n  className,\r\n  inset,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Label> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <ContextMenuPrimitive.Label\r\n      data-slot=\"context-menu-label\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"text-muted-foreground px-3 py-2.5 text-xs data-[inset]:pl-8\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ContextMenuPrimitive.Separator>) {\r\n  return (\r\n    <ContextMenuPrimitive.Separator\r\n      data-slot=\"context-menu-separator\"\r\n      className={cn(\"bg-border/50 -mx-1 my-1 h-px\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ContextMenuShortcut({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"context-menu-shortcut\"\r\n      className={cn(\r\n        \"text-muted-foreground group-focus/context-menu-item:text-accent-foreground ml-auto text-xs tracking-widest\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  ContextMenu,\r\n  ContextMenuTrigger,\r\n  ContextMenuContent,\r\n  ContextMenuItem,\r\n  ContextMenuCheckboxItem,\r\n  ContextMenuRadioItem,\r\n  ContextMenuLabel,\r\n  ContextMenuSeparator,\r\n  ContextMenuShortcut,\r\n  ContextMenuGroup,\r\n  ContextMenuPortal,\r\n  ContextMenuSub,\r\n  ContextMenuSubContent,\r\n  ContextMenuSubTrigger,\r\n  ContextMenuRadioGroup,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/data-table.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  type ColumnDef,\n  type SortingState,\n  flexRender,\n  getCoreRowModel,\n  getSortedRowModel,\n  useReactTable,\n} from \"@tanstack/react-table\";\nimport { useState } from \"react\";\n\nimport {\n  Table,\n  TableBody,\n  TableCell,\n  TableHead,\n  TableHeader,\n  TableRow,\n} from \"@/components/ui/table\";\nimport { cn } from \"@/lib/utils\";\n\ninterface DataTableProps<TData, TValue> {\n  columns: ColumnDef<TData, TValue>[];\n  data: TData[];\n  className?: string;\n  onRowClick?: (row: TData, rowIndex: number, rowId: string) => void;\n  getRowClassName?: (\n    row: TData,\n    rowIndex: number,\n    rowId: string,\n  ) => string | undefined;\n}\n\nexport function DataTable<TData, TValue>({\n  columns,\n  data,\n  className,\n  onRowClick,\n  getRowClassName,\n}: DataTableProps<TData, TValue>) {\n  const [sorting, setSorting] = useState<SortingState>([]);\n\n  // eslint-disable-next-line react-hooks/incompatible-library\n  const table = useReactTable({\n    data,\n    columns,\n    getCoreRowModel: getCoreRowModel(),\n    getSortedRowModel: getSortedRowModel(),\n    onSortingChange: setSorting,\n    state: { sorting },\n  });\n\n  return (\n    <div className={cn(\"w-full\", className)}>\n      <Table>\n        <TableHeader className=\"sticky top-0 z-10\">\n          {table.getHeaderGroups().map((headerGroup) => (\n            <TableRow\n              key={headerGroup.id}\n              className=\"bg-muted/60 hover:bg-muted/60 border-b border-border/60\"\n            >\n              {headerGroup.headers.map((header) => (\n                <TableHead\n                  key={header.id}\n                  className=\"border-r border-border/40 last:border-r-0 h-11 px-4 text-xs\"\n                  style={{\n                    width:\n                      header.getSize() !== 150 ? header.getSize() : undefined,\n                  }}\n                >\n                  {header.isPlaceholder\n                    ? null\n                    : flexRender(\n                        header.column.columnDef.header,\n                        header.getContext(),\n                      )}\n                </TableHead>\n              ))}\n            </TableRow>\n          ))}\n        </TableHeader>\n        <TableBody>\n          {table.getRowModel().rows.length ? (\n            table.getRowModel().rows.map((row, idx) => (\n              <TableRow\n                key={row.id}\n                data-state={row.getIsSelected() ? \"selected\" : undefined}\n                className={cn(\n                  \"transition-colors border-b border-border/30\",\n                  idx % 2 === 0\n                    ? 
\"bg-background\"\n                    : \"bg-muted/20\",\n                  \"hover:bg-primary/[0.03]\",\n                  getRowClassName?.(row.original, idx, row.id),\n                )}\n                onClick={() => onRowClick?.(row.original, idx, row.id)}\n              >\n                {row.getVisibleCells().map((cell) => (\n                  <TableCell\n                    key={cell.id}\n                    className=\"border-r border-border/20 last:border-r-0 text-[13px] py-3 px-4 align-top whitespace-normal\"\n                  >\n                    {flexRender(cell.column.columnDef.cell, cell.getContext())}\n                  </TableCell>\n                ))}\n              </TableRow>\n            ))\n          ) : (\n            <TableRow>\n              <TableCell\n                colSpan={columns.length}\n                className=\"h-32 text-center text-muted-foreground text-sm\"\n              >\n                No results.\n              </TableCell>\n            </TableRow>\n          )}\n        </TableBody>\n      </Table>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\n\nimport { Dialog as DialogPrimitive } from \"radix-ui\";\nimport type * as React from \"react\";\nimport { createContext, useContext } from \"react\";\n\nimport { Button } from \"@/components/ui/button\";\nimport { cn } from \"@/lib/utils\";\nimport { Cancel01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\n\nconst DialogPortalContainerContext = createContext<HTMLElement | null>(null);\n\nexport function useDialogPortalContainer(): HTMLElement | null {\n  return useContext(DialogPortalContainerContext);\n}\n\nfunction Dialog({\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Root>) {\n  return <DialogPrimitive.Root data-slot=\"dialog\" {...props} />;\n}\n\nfunction DialogTrigger({\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Trigger>) {\n  return <DialogPrimitive.Trigger data-slot=\"dialog-trigger\" {...props} />;\n}\n\nfunction DialogPortal({\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Portal>) {\n  return <DialogPrimitive.Portal data-slot=\"dialog-portal\" {...props} />;\n}\n\nfunction DialogClose({\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Close>) {\n  return <DialogPrimitive.Close data-slot=\"dialog-close\" {...props} />;\n}\n\nfunction DialogOverlay({\n  className,\n  position = \"fixed\",\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Overlay> & {\n  position?: \"fixed\" | \"absolute\";\n}) {\n  return (\n    <DialogPrimitive.Overlay\n      data-slot=\"dialog-overlay\"\n      className={cn(\n        \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 bg-black/80 duration-100  inset-0 isolate z-50\",\n        position === \"fixed\" ? \"fixed\" : \"absolute\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nfunction DialogContent({\n  className,\n  children,\n  showCloseButton = true,\n  container,\n  position = \"fixed\",\n  overlayClassName,\n  overlayPosition,\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Content> & {\n  showCloseButton?: boolean;\n  container?: HTMLElement | null;\n  position?: \"fixed\" | \"absolute\";\n  overlayClassName?: string;\n  overlayPosition?: \"fixed\" | \"absolute\";\n}) {\n  const resolvedContainer = container ?? null;\n  return (\n    <DialogPortalContainerContext.Provider value={resolvedContainer}>\n      <DialogPortal container={resolvedContainer ?? undefined}>\n        <DialogOverlay\n          className={overlayClassName}\n          position={overlayPosition ?? position}\n        />\n        <DialogPrimitive.Content\n          data-slot=\"dialog-content\"\n          className={cn(\n            \"bg-background data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 ring-foreground/5 grid max-w-[calc(100%-2rem)] gap-6 rounded-4xl p-6 text-sm ring-1 duration-100 sm:max-w-md top-1/2 left-1/2 z-50 w-full -translate-x-1/2 -translate-y-1/2\",\n            position === \"fixed\" ? 
\"fixed\" : \"absolute\",\n            className,\n          )}\n          {...props}\n        >\n          {children}\n          {showCloseButton && (\n            <DialogPrimitive.Close data-slot=\"dialog-close\" asChild>\n              <Button\n                variant=\"ghost\"\n                className=\"absolute top-4 right-4\"\n                size=\"icon-sm\"\n              >\n                <HugeiconsIcon icon={Cancel01Icon} strokeWidth={2} />\n                <span className=\"sr-only\">Close</span>\n              </Button>\n            </DialogPrimitive.Close>\n          )}\n        </DialogPrimitive.Content>\n      </DialogPortal>\n    </DialogPortalContainerContext.Provider>\n  );\n}\n\nfunction DialogHeader({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"dialog-header\"\n      className={cn(\"gap-2 flex flex-col\", className)}\n      {...props}\n    />\n  );\n}\n\nfunction DialogFooter({\n  className,\n  showCloseButton = false,\n  children,\n  ...props\n}: React.ComponentProps<\"div\"> & {\n  showCloseButton?: boolean;\n}) {\n  return (\n    <div\n      data-slot=\"dialog-footer\"\n      className={cn(\n        \"flex flex-col-reverse gap-2 sm:flex-row sm:justify-end\",\n        className,\n      )}\n      {...props}\n    >\n      {children}\n      {showCloseButton && (\n        <DialogPrimitive.Close asChild>\n          <Button variant=\"outline\">Close</Button>\n        </DialogPrimitive.Close>\n      )}\n    </div>\n  );\n}\n\nfunction DialogTitle({\n  className,\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Title>) {\n  return (\n    <DialogPrimitive.Title\n      data-slot=\"dialog-title\"\n      className={cn(\"text-base leading-none font-medium\", className)}\n      {...props}\n    />\n  );\n}\n\nfunction DialogDescription({\n  className,\n  ...props\n}: React.ComponentProps<typeof DialogPrimitive.Description>) {\n  return (\n    <DialogPrimitive.Description\n      data-slot=\"dialog-description\"\n      className={cn(\n        \"text-muted-foreground *:[a]:hover:text-foreground text-sm *:[a]:underline *:[a]:underline-offset-3\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nexport {\n  Dialog,\n  DialogClose,\n  DialogContent,\n  DialogDescription,\n  DialogFooter,\n  DialogHeader,\n  DialogOverlay,\n  DialogPortal,\n  DialogPortalContainerContext,\n  DialogTitle,\n  DialogTrigger,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/ui/dropdown-menu.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { DropdownMenu as DropdownMenuPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { ArrowRight01Icon, Tick02Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction DropdownMenu({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Root>) {\r\n  return <DropdownMenuPrimitive.Root data-slot=\"dropdown-menu\" {...props} />;\r\n}\r\n\r\nfunction DropdownMenuPortal({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Portal>) {\r\n  return (\r\n    <DropdownMenuPrimitive.Portal data-slot=\"dropdown-menu-portal\" {...props} />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuTrigger({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Trigger>) {\r\n  return (\r\n    <DropdownMenuPrimitive.Trigger\r\n      data-slot=\"dropdown-menu-trigger\"\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuContent({\r\n  className,\r\n  align = \"start\",\r\n  sideOffset = 4,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Content>) {\r\n  return (\r\n    <DropdownMenuPrimitive.Portal>\r\n      <DropdownMenuPrimitive.Content\r\n        data-slot=\"dropdown-menu-content\"\r\n        sideOffset={sideOffset}\r\n        align={align}\r\n        className={cn(\r\n          \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 shadow-border ring-1 ring-border bg-popover text-popover-foreground min-w-48 rounded-lg p-1 duration-100 z-50 max-h-(--radix-dropdown-menu-content-available-height) w-(--radix-dropdown-menu-trigger-width) origin-(--radix-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto data-[state=closed]:overflow-hidden\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </DropdownMenuPrimitive.Portal>\r\n  );\r\n}\r\n\r\nfunction DropdownMenuGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Group>) {\r\n  return (\r\n    <DropdownMenuPrimitive.Group data-slot=\"dropdown-menu-group\" {...props} />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuItem({\r\n  className,\r\n  inset,\r\n  variant = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & {\r\n  inset?: boolean;\r\n  variant?: \"default\" | \"destructive\";\r\n}) {\r\n  return (\r\n    <DropdownMenuPrimitive.Item\r\n      data-slot=\"dropdown-menu-item\"\r\n      data-inset={inset}\r\n      data-variant={variant}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-[variant=destructive]:text-destructive data-[variant=destructive]:focus:bg-destructive/10 dark:data-[variant=destructive]:focus:bg-destructive/20 data-[variant=destructive]:focus:text-destructive data-[variant=destructive]:*:[svg]:text-destructive not-data-[variant=destructive]:focus:**:text-accent-foreground gap-2.5 rounded-lg px-3 py-2 text-sm [&_svg:not([class*='size-'])]:size-4 group/dropdown-menu-item relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none 
data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuCheckboxItem({\r\n  className,\r\n  children,\r\n  checked,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.CheckboxItem>) {\r\n  return (\r\n    <DropdownMenuPrimitive.CheckboxItem\r\n      data-slot=\"dropdown-menu-checkbox-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground focus:**:text-accent-foreground gap-2.5 rounded-lg py-2 pr-8 pl-3 text-sm [&_svg:not([class*='size-'])]:size-4 relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      checked={checked}\r\n      {...props}\r\n    >\r\n      <span\r\n        className=\"pointer-events-none absolute right-2 flex items-center justify-center pointer-events-none\"\r\n        data-slot=\"dropdown-menu-checkbox-item-indicator\"\r\n      >\r\n        <DropdownMenuPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </DropdownMenuPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </DropdownMenuPrimitive.CheckboxItem>\r\n  );\r\n}\r\n\r\nfunction DropdownMenuRadioGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.RadioGroup>) {\r\n  return (\r\n    <DropdownMenuPrimitive.RadioGroup\r\n      data-slot=\"dropdown-menu-radio-group\"\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuRadioItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.RadioItem>) {\r\n  return (\r\n    <DropdownMenuPrimitive.RadioItem\r\n      data-slot=\"dropdown-menu-radio-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground focus:**:text-accent-foreground gap-2.5 rounded-lg py-2 pr-8 pl-3 text-sm [&_svg:not([class*='size-'])]:size-4 relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <span\r\n        className=\"pointer-events-none absolute right-2 flex items-center justify-center pointer-events-none\"\r\n        data-slot=\"dropdown-menu-radio-item-indicator\"\r\n      >\r\n        <DropdownMenuPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </DropdownMenuPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </DropdownMenuPrimitive.RadioItem>\r\n  );\r\n}\r\n\r\nfunction DropdownMenuLabel({\r\n  className,\r\n  inset,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Label> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <DropdownMenuPrimitive.Label\r\n      data-slot=\"dropdown-menu-label\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"text-muted-foreground px-3 py-2.5 text-xs data-[inset]:pl-8\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Separator>) {\r\n  return (\r\n    <DropdownMenuPrimitive.Separator\r\n      data-slot=\"dropdown-menu-separator\"\r\n      className={cn(\"bg-border/50 
-mx-1 my-1 h-px\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuShortcut({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"dropdown-menu-shortcut\"\r\n      className={cn(\r\n        \"text-muted-foreground group-focus/dropdown-menu-item:text-accent-foreground ml-auto text-xs tracking-widest\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction DropdownMenuSub({\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.Sub>) {\r\n  return <DropdownMenuPrimitive.Sub data-slot=\"dropdown-menu-sub\" {...props} />;\r\n}\r\n\r\nfunction DropdownMenuSubTrigger({\r\n  className,\r\n  inset,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.SubTrigger> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <DropdownMenuPrimitive.SubTrigger\r\n      data-slot=\"dropdown-menu-sub-trigger\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-open:bg-accent data-open:text-accent-foreground not-data-[variant=destructive]:focus:**:text-accent-foreground gap-2 rounded-lg px-3 py-2 text-sm [&_svg:not([class*='size-'])]:size-4 flex cursor-default items-center outline-hidden select-none data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <HugeiconsIcon\r\n        icon={ArrowRight01Icon}\r\n        strokeWidth={2}\r\n        className=\"ml-auto\"\r\n      />\r\n    </DropdownMenuPrimitive.SubTrigger>\r\n  );\r\n}\r\n\r\nfunction DropdownMenuSubContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof DropdownMenuPrimitive.SubContent>) {\r\n  return (\r\n    <DropdownMenuPrimitive.SubContent\r\n      data-slot=\"dropdown-menu-sub-content\"\r\n      className={cn(\r\n        \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 shadow-border ring-1 ring-border bg-popover text-popover-foreground min-w-36 rounded-lg p-1 duration-100 z-50 origin-(--radix-dropdown-menu-content-transform-origin) overflow-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  DropdownMenu,\r\n  DropdownMenuPortal,\r\n  DropdownMenuTrigger,\r\n  DropdownMenuContent,\r\n  DropdownMenuGroup,\r\n  DropdownMenuLabel,\r\n  DropdownMenuItem,\r\n  DropdownMenuCheckboxItem,\r\n  DropdownMenuRadioGroup,\r\n  DropdownMenuRadioItem,\r\n  DropdownMenuSeparator,\r\n  DropdownMenuShortcut,\r\n  DropdownMenuSub,\r\n  DropdownMenuSubTrigger,\r\n  DropdownMenuSubContent,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/empty.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cva, type VariantProps } from \"class-variance-authority\"\n\nimport { cn } from \"@/lib/utils\"\n\nfunction Empty({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"empty\"\n      className={cn(\n        \"gap-4 rounded-lg border-dashed p-12 flex w-full min-w-0 flex-1 flex-col items-center justify-center text-center text-balance\",\n        className\n      )}\n      {...props}\n    />\n  )\n}\n\nfunction EmptyHeader({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"empty-header\"\n      className={cn(\n        \"gap-2 flex max-w-sm flex-col items-center\",\n        className\n      )}\n      {...props}\n    />\n  )\n}\n\nconst emptyMediaVariants = cva(\n  \"mb-2 flex shrink-0 items-center justify-center [&_svg]:pointer-events-none [&_svg]:shrink-0\",\n  {\n    variants: {\n      variant: {\n        default: \"bg-transparent\",\n        icon: \"bg-muted text-foreground flex size-10 shrink-0 items-center justify-center rounded-lg [&_svg:not([class*='size-'])]:size-6\",\n      },\n    },\n    defaultVariants: {\n      variant: \"default\",\n    },\n  }\n)\n\nfunction EmptyMedia({\n  className,\n  variant = \"default\",\n  ...props\n}: React.ComponentProps<\"div\"> & VariantProps<typeof emptyMediaVariants>) {\n  return (\n    <div\n      data-slot=\"empty-icon\"\n      data-variant={variant}\n      className={cn(emptyMediaVariants({ variant, className }))}\n      {...props}\n    />\n  )\n}\n\nfunction EmptyTitle({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"empty-title\"\n      className={cn(\"text-lg font-medium tracking-tight\", className)}\n      {...props}\n    />\n  )\n}\n\nfunction EmptyDescription({ className, ...props }: React.ComponentProps<\"p\">) {\n  return (\n    <div\n      data-slot=\"empty-description\"\n      className={cn(\n        \"text-sm/relaxed text-muted-foreground [&>a:hover]:text-primary [&>a]:underline [&>a]:underline-offset-4\",\n        className\n      )}\n      {...props}\n    />\n  )\n}\n\nfunction EmptyContent({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"empty-content\"\n      className={cn(\n        \"gap-4 text-sm flex w-full max-w-sm min-w-0 flex-col items-center text-balance\",\n        className\n      )}\n      {...props}\n    />\n  )\n}\n\nexport {\n  Empty,\n  EmptyHeader,\n  EmptyTitle,\n  EmptyDescription,\n  EmptyContent,\n  EmptyMedia,\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/field.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport { useMemo } from \"react\";\r\n\r\nimport { Label } from \"@/components/ui/label\";\r\nimport { Separator } from \"@/components/ui/separator\";\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction FieldSet({ className, ...props }: React.ComponentProps<\"fieldset\">) {\r\n  return (\r\n    <fieldset\r\n      data-slot=\"field-set\"\r\n      className={cn(\r\n        \"gap-6 has-[>[data-slot=checkbox-group]]:gap-3 has-[>[data-slot=radio-group]]:gap-3 flex flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldLegend({\r\n  className,\r\n  variant = \"legend\",\r\n  ...props\r\n}: React.ComponentProps<\"legend\"> & { variant?: \"legend\" | \"label\" }) {\r\n  return (\r\n    <legend\r\n      data-slot=\"field-legend\"\r\n      data-variant={variant}\r\n      className={cn(\r\n        \"mb-3 font-medium data-[variant=label]:text-sm data-[variant=legend]:text-base\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldGroup({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"field-group\"\r\n      className={cn(\r\n        \"gap-7 data-[slot=checkbox-group]:gap-3 [&>[data-slot=field-group]]:gap-4 group/field-group @container/field-group flex w-full flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nconst fieldVariants = cva(\r\n  \"data-[invalid=true]:text-destructive gap-3 group/field flex w-full\",\r\n  {\r\n    variants: {\r\n      orientation: {\r\n        vertical: \"flex-col [&>*]:w-full [&>.sr-only]:w-auto\",\r\n        horizontal:\r\n          \"flex-row items-center [&>[data-slot=field-label]]:flex-auto has-[>[data-slot=field-content]]:items-start has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px\",\r\n        responsive:\r\n          \"flex-col [&>*]:w-full [&>.sr-only]:w-auto @md/field-group:flex-row @md/field-group:items-center @md/field-group:[&>*]:w-auto @md/field-group:[&>[data-slot=field-label]]:flex-auto @md/field-group:has-[>[data-slot=field-content]]:items-start @md/field-group:has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      orientation: \"vertical\",\r\n    },\r\n  },\r\n);\r\n\r\nfunction Field({\r\n  className,\r\n  orientation = \"vertical\",\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & VariantProps<typeof fieldVariants>) {\r\n  return (\r\n    <div\r\n      role=\"group\"\r\n      data-slot=\"field\"\r\n      data-orientation={orientation}\r\n      className={cn(fieldVariants({ orientation }), className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldContent({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"field-content\"\r\n      className={cn(\r\n        \"gap-1 group/field-content flex flex-1 flex-col leading-snug\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldLabel({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof Label>) {\r\n  return (\r\n    <Label\r\n      data-slot=\"field-label\"\r\n      className={cn(\r\n        \"has-data-checked:bg-primary/5 
has-data-checked:border-primary/50 dark:has-data-checked:bg-primary/10 gap-2 group-data-[disabled=true]/field:opacity-50 has-[>[data-slot=field]]:rounded-xl has-[>[data-slot=field]]:border [&>*]:data-[slot=field]:p-4 group/field-label peer/field-label flex w-fit leading-snug\",\r\n        \"has-[>[data-slot=field]]:w-full has-[>[data-slot=field]]:flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldTitle({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"field-label\"\r\n      className={cn(\r\n        \"gap-2 text-sm font-medium group-data-[disabled=true]/field:opacity-50 flex w-fit items-center leading-snug\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldDescription({ className, ...props }: React.ComponentProps<\"p\">) {\r\n  return (\r\n    <p\r\n      data-slot=\"field-description\"\r\n      className={cn(\r\n        \"text-muted-foreground text-left text-sm [[data-variant=legend]+&]:-mt-1.5 leading-normal font-normal group-has-[[data-orientation=horizontal]]/field:text-balance\",\r\n        \"last:mt-0 nth-last-2:-mt-1\",\r\n        \"[&>a:hover]:text-primary [&>a]:underline [&>a]:underline-offset-4\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction FieldSeparator({\r\n  children,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  children?: React.ReactNode;\r\n}) {\r\n  return (\r\n    <div\r\n      data-slot=\"field-separator\"\r\n      data-content={!!children}\r\n      className={cn(\r\n        \"-my-2 h-5 text-sm group-data-[variant=outline]/field-group:-mb-2 relative\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <Separator className=\"absolute inset-0 top-1/2\" />\r\n      {children && (\r\n        <span\r\n          className=\"text-muted-foreground px-2 bg-background relative mx-auto block w-fit\"\r\n          data-slot=\"field-separator-content\"\r\n        >\r\n          {children}\r\n        </span>\r\n      )}\r\n    </div>\r\n  );\r\n}\r\n\r\nfunction FieldError({\r\n  className,\r\n  children,\r\n  errors,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  errors?: Array<{ message?: string } | undefined>;\r\n}) {\r\n  const content = useMemo(() => {\r\n    if (children) {\r\n      return children;\r\n    }\r\n\r\n    if (!errors?.length) {\r\n      return null;\r\n    }\r\n\r\n    const uniqueErrors = [\r\n      ...new Map(errors.map((error) => [error?.message, error])).values(),\r\n    ];\r\n\r\n    if (uniqueErrors?.length === 1) {\r\n      return uniqueErrors[0]?.message;\r\n    }\r\n\r\n    return (\r\n      <ul className=\"ml-4 flex list-disc flex-col gap-1\">\r\n        {uniqueErrors.map(\r\n          (error, index) =>\r\n            error?.message && <li key={index}>{error.message}</li>,\r\n        )}\r\n      </ul>\r\n    );\r\n  }, [children, errors]);\r\n\r\n  if (!content) {\r\n    return null;\r\n  }\r\n\r\n  return (\r\n    <div\r\n      role=\"alert\"\r\n      data-slot=\"field-error\"\r\n      className={cn(\"text-destructive text-sm font-normal\", className)}\r\n      {...props}\r\n    >\r\n      {content}\r\n    </div>\r\n  );\r\n}\r\n\r\nexport {\r\n  Field,\r\n  FieldLabel,\r\n  FieldDescription,\r\n  FieldError,\r\n  FieldGroup,\r\n  FieldLegend,\r\n  FieldSeparator,\r\n  FieldSet,\r\n  FieldContent,\r\n  FieldTitle,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/hover-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { HoverCard as HoverCardPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction HoverCard({\r\n  ...props\r\n}: React.ComponentProps<typeof HoverCardPrimitive.Root>) {\r\n  return <HoverCardPrimitive.Root data-slot=\"hover-card\" {...props} />;\r\n}\r\n\r\nfunction HoverCardTrigger({\r\n  ...props\r\n}: React.ComponentProps<typeof HoverCardPrimitive.Trigger>) {\r\n  return (\r\n    <HoverCardPrimitive.Trigger data-slot=\"hover-card-trigger\" {...props} />\r\n  );\r\n}\r\n\r\nfunction HoverCardContent({\r\n  className,\r\n  align = \"center\",\r\n  sideOffset = 4,\r\n  ...props\r\n}: React.ComponentProps<typeof HoverCardPrimitive.Content>) {\r\n  return (\r\n    <HoverCardPrimitive.Portal data-slot=\"hover-card-portal\">\r\n      <HoverCardPrimitive.Content\r\n        data-slot=\"hover-card-content\"\r\n        align={align}\r\n        sideOffset={sideOffset}\r\n        className={cn(\r\n          \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 ring-foreground/5 bg-popover text-popover-foreground w-72 rounded-2xl p-4 text-sm shadow-2xl ring-1 duration-100 z-50 origin-(--radix-hover-card-content-transform-origin) outline-hidden\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </HoverCardPrimitive.Portal>\r\n  );\r\n}\r\n\r\nexport { HoverCard, HoverCardTrigger, HoverCardContent };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/input-group.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport type * as React from \"react\";\r\n\r\nimport { Button } from \"@/components/ui/button\";\r\nimport { Input } from \"@/components/ui/input\";\r\nimport { Textarea } from \"@/components/ui/textarea\";\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction InputGroup({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"input-group\"\r\n      role=\"group\"\r\n      className={cn(\r\n        \"border-input bg-input/30 has-[[data-slot=input-group-control]:focus-visible]:border-ring has-[[data-slot=input-group-control]:focus-visible]:ring-ring/50 has-[[data-slot][aria-invalid=true]]:ring-destructive/20 has-[[data-slot][aria-invalid=true]]:border-destructive dark:has-[[data-slot][aria-invalid=true]]:ring-destructive/40 h-9 rounded-4xl border transition-colors has-data-[align=block-end]:rounded-2xl has-data-[align=block-start]:rounded-2xl has-[[data-slot=input-group-control]:focus-visible]:ring-[3px] has-[[data-slot][aria-invalid=true]]:ring-[3px] has-[textarea]:rounded-xl has-[>[data-align=block-end]]:h-auto has-[>[data-align=block-end]]:flex-col has-[>[data-align=block-start]]:h-auto has-[>[data-align=block-start]]:flex-col has-[>[data-align=block-end]]:[&>input]:pt-3 has-[>[data-align=block-start]]:[&>input]:pb-3 has-[>[data-align=inline-end]]:[&>input]:pr-1.5 has-[>[data-align=inline-start]]:[&>input]:pl-1.5 [[data-slot=combobox-content]_&]:focus-within:border-inherit [[data-slot=combobox-content]_&]:focus-within:ring-0 group/input-group relative flex w-full min-w-0 items-center outline-none has-[>textarea]:h-auto\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nconst inputGroupAddonVariants = cva(\r\n  \"text-muted-foreground **:data-[slot=kbd]:bg-muted-foreground/10 h-auto gap-2 py-2 text-sm font-medium group-data-[disabled=true]/input-group:opacity-50 **:data-[slot=kbd]:rounded-4xl **:data-[slot=kbd]:px-1.5 [&>svg:not([class*='size-'])]:size-4 flex cursor-text items-center justify-center select-none\",\r\n  {\r\n    variants: {\r\n      align: {\r\n        \"inline-start\":\r\n          \"pl-3 has-[>button]:ml-[-0.25rem] has-[>kbd]:ml-[-0.15rem] order-first\",\r\n        \"inline-end\":\r\n          \"pr-3 has-[>button]:mr-[-0.25rem] has-[>kbd]:mr-[-0.15rem] order-last\",\r\n        \"block-start\":\r\n          \"px-3 pt-3 group-has-[>input]/input-group:pt-3 [.border-b]:pb-3 order-first w-full justify-start\",\r\n        \"block-end\":\r\n          \"px-3 pb-3 group-has-[>input]/input-group:pb-3 [.border-t]:pt-3 order-last w-full justify-start\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      align: \"inline-start\",\r\n    },\r\n  },\r\n);\r\n\r\nfunction InputGroupAddon({\r\n  className,\r\n  align = \"inline-start\",\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & VariantProps<typeof inputGroupAddonVariants>) {\r\n  return (\r\n    <div\r\n      role=\"group\"\r\n      data-slot=\"input-group-addon\"\r\n      data-align={align}\r\n      className={cn(inputGroupAddonVariants({ align }), className)}\r\n      onClick={(e) => {\r\n        if ((e.target as HTMLElement).closest(\"button\")) {\r\n          return;\r\n        }\r\n        e.currentTarget.parentElement?.querySelector(\"input\")?.focus();\r\n      }}\r\n      {...props}\r\n    />\r\n  
);\r\n}\r\n\r\nconst inputGroupButtonVariants = cva(\r\n  \"gap-2 rounded-4xl text-sm shadow-none flex items-center\",\r\n  {\r\n    variants: {\r\n      size: {\r\n        xs: \"h-6 gap-1 px-1.5 [&>svg:not([class*='size-'])]:size-3.5\",\r\n        sm: \"\",\r\n        \"icon-xs\": \"size-6 p-0 has-[>svg]:p-0\",\r\n        \"icon-sm\": \"size-8 p-0 has-[>svg]:p-0\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      size: \"xs\",\r\n    },\r\n  },\r\n);\r\n\r\nfunction InputGroupButton({\r\n  className,\r\n  type = \"button\",\r\n  variant = \"ghost\",\r\n  size = \"xs\",\r\n  ...props\r\n}: Omit<React.ComponentProps<typeof Button>, \"size\"> &\r\n  VariantProps<typeof inputGroupButtonVariants>) {\r\n  return (\r\n    <Button\r\n      type={type}\r\n      data-size={size}\r\n      variant={variant}\r\n      className={cn(inputGroupButtonVariants({ size }), className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction InputGroupText({ className, ...props }: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      className={cn(\r\n        \"text-muted-foreground gap-2 text-sm [&_svg:not([class*='size-'])]:size-4 flex items-center [&_svg]:pointer-events-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction InputGroupInput({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"input\">) {\r\n  return (\r\n    <Input\r\n      data-slot=\"input-group-control\"\r\n      className={cn(\r\n        \"rounded-none border-0 bg-transparent shadow-none ring-0 focus-visible:ring-0 aria-invalid:ring-0 dark:bg-transparent flex-1\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction InputGroupTextarea({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"textarea\">) {\r\n  return (\r\n    <Textarea\r\n      data-slot=\"input-group-control\"\r\n      className={cn(\r\n        \"rounded-none border-0 bg-transparent py-2 shadow-none ring-0 focus-visible:ring-0 aria-invalid:ring-0 dark:bg-transparent flex-1 resize-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  InputGroup,\r\n  InputGroupAddon,\r\n  InputGroupButton,\r\n  InputGroupText,\r\n  InputGroupInput,\r\n  InputGroupTextarea,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/input.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Input({ className, type, ...props }: React.ComponentProps<\"input\">) {\r\n  return (\r\n    <input\r\n      type={type}\r\n      data-slot=\"input\"\r\n      className={cn(\r\n        \"bg-input/30 border-input focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 h-9 rounded-4xl border px-3 py-1 text-base transition-colors file:h-7 file:text-sm file:font-medium focus-visible:ring-[3px] aria-invalid:ring-[3px] md:text-sm file:text-foreground placeholder:text-muted-foreground w-full min-w-0 outline-none file:inline-flex file:border-0 file:bg-transparent disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Input };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/label.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Label as LabelPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Label({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof LabelPrimitive.Root>) {\r\n  return (\r\n    <LabelPrimitive.Root\r\n      data-slot=\"label\"\r\n      className={cn(\r\n        \"gap-2 text-sm leading-none font-medium group-data-[disabled=true]:opacity-50 peer-disabled:opacity-50 flex items-center select-none group-data-[disabled=true]:pointer-events-none peer-disabled:cursor-not-allowed\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Label };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/light-rays.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { motion } from \"motion/react\";\r\nimport { type CSSProperties, useState } from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\ninterface LightRaysProps extends React.HTMLAttributes<HTMLDivElement> {\r\n  ref?: React.Ref<HTMLDivElement>;\r\n  count?: number;\r\n  color?: string;\r\n  blur?: number;\r\n  speed?: number;\r\n  length?: string;\r\n}\r\n\r\ntype LightRay = {\r\n  id: string;\r\n  left: number;\r\n  rotate: number;\r\n  width: number;\r\n  swing: number;\r\n  delay: number;\r\n  duration: number;\r\n  intensity: number;\r\n};\r\n\r\nconst createRays = (count: number, cycle: number): LightRay[] => {\r\n  if (count <= 0) return [];\r\n\r\n  return Array.from({ length: count }, (_, index) => {\r\n    const left = 8 + Math.random() * 84;\r\n    const rotate = -28 + Math.random() * 56;\r\n    const width = 160 + Math.random() * 160;\r\n    const swing = 0.8 + Math.random() * 1.8;\r\n    const delay = Math.random() * cycle;\r\n    const duration = cycle * (0.75 + Math.random() * 0.5);\r\n    const intensity = 0.6 + Math.random() * 0.5;\r\n\r\n    return {\r\n      id: `${index}-${Math.round(left * 10)}`,\r\n      left,\r\n      rotate,\r\n      width,\r\n      swing,\r\n      delay,\r\n      duration,\r\n      intensity,\r\n    };\r\n  });\r\n};\r\n\r\nconst Ray = ({\r\n  left,\r\n  rotate,\r\n  width,\r\n  swing,\r\n  delay,\r\n  duration,\r\n  intensity,\r\n}: LightRay) => {\r\n  return (\r\n    <motion.div\r\n      className=\"pointer-events-none absolute -top-[12%] left-[var(--ray-left)] h-[var(--light-rays-length)] w-[var(--ray-width)] origin-top -translate-x-1/2 rounded-full bg-gradient-to-b from-[color-mix(in_srgb,var(--light-rays-color)_70%,transparent)] to-transparent opacity-0 mix-blend-screen blur-[var(--light-rays-blur)]\"\r\n      style={\r\n        {\r\n          \"--ray-left\": `${left}%`,\r\n          \"--ray-width\": `${width}px`,\r\n        } as CSSProperties\r\n      }\r\n      initial={{ rotate: rotate }}\r\n      animate={{\r\n        opacity: [0, intensity, 0],\r\n        rotate: [rotate - swing, rotate + swing, rotate - swing],\r\n      }}\r\n      transition={{\r\n        duration: duration,\r\n        repeat: Number.POSITIVE_INFINITY,\r\n        ease: \"easeInOut\",\r\n        delay: delay,\r\n        repeatDelay: duration * 0.1,\r\n      }}\r\n    />\r\n  );\r\n};\r\n\r\nexport function LightRays({\r\n  className,\r\n  style,\r\n  count = 7,\r\n  color = \"rgba(160, 210, 255, 0.2)\",\r\n  blur = 36,\r\n  speed = 14,\r\n  length = \"70vh\",\r\n  ref,\r\n  ...props\r\n}: LightRaysProps) {\r\n  const cycleDuration = Math.max(speed, 0.1);\r\n  const [rays] = useState(() => createRays(count, cycleDuration));\r\n\r\n  return (\r\n    <div\r\n      ref={ref}\r\n      className={cn(\r\n        \"pointer-events-none absolute inset-0 isolate overflow-hidden rounded-[inherit]\",\r\n        className,\r\n      )}\r\n      style={\r\n        {\r\n          \"--light-rays-color\": color,\r\n          \"--light-rays-blur\": `${blur}px`,\r\n          \"--light-rays-length\": length,\r\n          ...style,\r\n        } as CSSProperties\r\n      }\r\n      {...props}\r\n    >\r\n      <div className=\"absolute inset-0 overflow-hidden\">\r\n        <div\r\n          aria-hidden\r\n          className=\"absolute inset-0 opacity-60\"\r\n          style={\r\n            {\r\n              
background:\r\n                \"radial-gradient(circle at 20% 15%, color-mix(in srgb, var(--light-rays-color) 45%, transparent), transparent 70%)\",\r\n            } as CSSProperties\r\n          }\r\n        />\r\n        <div\r\n          aria-hidden\r\n          className=\"absolute inset-0 opacity-60\"\r\n          style={\r\n            {\r\n              background:\r\n                \"radial-gradient(circle at 80% 10%, color-mix(in srgb, var(--light-rays-color) 35%, transparent), transparent 75%)\",\r\n            } as CSSProperties\r\n          }\r\n        />\r\n        {rays.map((ray) => (\r\n          <Ray key={ray.id} {...ray} />\r\n        ))}\r\n      </div>\r\n    </div>\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/menubar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Menubar as MenubarPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { ArrowRight01Icon, Tick02Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Menubar({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Root>) {\r\n  return (\r\n    <MenubarPrimitive.Root\r\n      data-slot=\"menubar\"\r\n      className={cn(\r\n        \"bg-background h-9 rounded-2xl border p-1 flex items-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction MenubarMenu({\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Menu>) {\r\n  return <MenubarPrimitive.Menu data-slot=\"menubar-menu\" {...props} />;\r\n}\r\n\r\nfunction MenubarGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Group>) {\r\n  return <MenubarPrimitive.Group data-slot=\"menubar-group\" {...props} />;\r\n}\r\n\r\nfunction MenubarPortal({\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Portal>) {\r\n  return <MenubarPrimitive.Portal data-slot=\"menubar-portal\" {...props} />;\r\n}\r\n\r\nfunction MenubarRadioGroup({\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.RadioGroup>) {\r\n  return (\r\n    <MenubarPrimitive.RadioGroup data-slot=\"menubar-radio-group\" {...props} />\r\n  );\r\n}\r\n\r\nfunction MenubarTrigger({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Trigger>) {\r\n  return (\r\n    <MenubarPrimitive.Trigger\r\n      data-slot=\"menubar-trigger\"\r\n      className={cn(\r\n        \"hover:bg-muted aria-expanded:bg-muted rounded-xl px-2.5 py-1 text-sm font-medium flex items-center outline-hidden select-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction MenubarContent({\r\n  className,\r\n  align = \"start\",\r\n  alignOffset = -4,\r\n  sideOffset = 8,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Content>) {\r\n  return (\r\n    <MenubarPortal>\r\n      <MenubarPrimitive.Content\r\n        data-slot=\"menubar-content\"\r\n        align={align}\r\n        alignOffset={alignOffset}\r\n        sideOffset={sideOffset}\r\n        className={cn(\r\n          \"bg-popover text-popover-foreground data-open:animate-in data-open:fade-in-0 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 ring-foreground/5 min-w-48 rounded-2xl p-1 shadow-2xl ring-1 duration-100 z-50 origin-(--radix-menubar-content-transform-origin) overflow-hidden\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </MenubarPortal>\r\n  );\r\n}\r\n\r\nfunction MenubarItem({\r\n  className,\r\n  inset,\r\n  variant = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Item> & {\r\n  inset?: boolean;\r\n  variant?: \"default\" | \"destructive\";\r\n}) {\r\n  return (\r\n    <MenubarPrimitive.Item\r\n      data-slot=\"menubar-item\"\r\n      data-inset={inset}\r\n      data-variant={variant}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-[variant=destructive]:text-destructive 
data-[variant=destructive]:focus:bg-destructive/10 dark:data-[variant=destructive]:focus:bg-destructive/20 data-[variant=destructive]:focus:text-destructive data-[variant=destructive]:*:[svg]:!text-destructive not-data-[variant=destructive]:focus:**:text-accent-foreground gap-2.5 rounded-xl px-3 py-2 text-sm data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg:not([class*='size-'])]:size-4 group/menubar-item relative flex cursor-default items-center outline-hidden select-none data-[disabled]:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction MenubarCheckboxItem({\r\n  className,\r\n  children,\r\n  checked,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.CheckboxItem>) {\r\n  return (\r\n    <MenubarPrimitive.CheckboxItem\r\n      data-slot=\"menubar-checkbox-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground focus:**:text-accent-foreground gap-2.5 rounded-xl py-2 pr-3 pl-8 text-sm data-disabled:opacity-50 relative flex cursor-default items-center outline-hidden select-none data-disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      checked={checked}\r\n      {...props}\r\n    >\r\n      <span className=\"left-2 size-4 [&_svg:not([class*='size-'])]:size-4 pointer-events-none absolute flex items-center justify-center\">\r\n        <MenubarPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </MenubarPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </MenubarPrimitive.CheckboxItem>\r\n  );\r\n}\r\n\r\nfunction MenubarRadioItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.RadioItem>) {\r\n  return (\r\n    <MenubarPrimitive.RadioItem\r\n      data-slot=\"menubar-radio-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground focus:**:text-accent-foreground gap-2.5 rounded-xl py-2 pr-3 pl-8 text-sm data-disabled:opacity-50 [&_svg:not([class*='size-'])]:size-4 relative flex cursor-default items-center outline-hidden select-none data-disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <span className=\"left-2 size-4 [&_svg:not([class*='size-'])]:size-4 pointer-events-none absolute flex items-center justify-center\">\r\n        <MenubarPrimitive.ItemIndicator>\r\n          <HugeiconsIcon icon={Tick02Icon} strokeWidth={2} />\r\n        </MenubarPrimitive.ItemIndicator>\r\n      </span>\r\n      {children}\r\n    </MenubarPrimitive.RadioItem>\r\n  );\r\n}\r\n\r\nfunction MenubarLabel({\r\n  className,\r\n  inset,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Label> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <MenubarPrimitive.Label\r\n      data-slot=\"menubar-label\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"text-muted-foreground px-3.5 py-2.5 text-xs data-[inset]:pl-8\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction MenubarSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Separator>) {\r\n  return (\r\n    <MenubarPrimitive.Separator\r\n      data-slot=\"menubar-separator\"\r\n      className={cn(\"bg-border/50 -mx-1 my-1 h-px\", className)}\r\n      {...props}\r\n    />\r\n  
);\r\n}\r\n\r\nfunction MenubarShortcut({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      data-slot=\"menubar-shortcut\"\r\n      className={cn(\r\n        \"text-muted-foreground group-focus/menubar-item:text-accent-foreground text-xs tracking-widest ml-auto\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction MenubarSub({\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.Sub>) {\r\n  return <MenubarPrimitive.Sub data-slot=\"menubar-sub\" {...props} />;\r\n}\r\n\r\nfunction MenubarSubTrigger({\r\n  className,\r\n  inset,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.SubTrigger> & {\r\n  inset?: boolean;\r\n}) {\r\n  return (\r\n    <MenubarPrimitive.SubTrigger\r\n      data-slot=\"menubar-sub-trigger\"\r\n      data-inset={inset}\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground data-open:bg-accent data-open:text-accent-foreground gap-2 rounded-xl px-3 py-2 text-sm data-[inset]:pl-8 [&_svg:not([class*='size-'])]:size-4 flex cursor-default items-center outline-none select-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <HugeiconsIcon\r\n        icon={ArrowRight01Icon}\r\n        strokeWidth={2}\r\n        className=\"ml-auto size-4\"\r\n      />\r\n    </MenubarPrimitive.SubTrigger>\r\n  );\r\n}\r\n\r\nfunction MenubarSubContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof MenubarPrimitive.SubContent>) {\r\n  return (\r\n    <MenubarPrimitive.SubContent\r\n      data-slot=\"menubar-sub-content\"\r\n      className={cn(\r\n        \"bg-popover text-popover-foreground data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 ring-foreground/5 min-w-32 rounded-2xl p-1 shadow-2xl ring-1 duration-100 z-50 origin-(--radix-menubar-content-transform-origin) overflow-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Menubar,\r\n  MenubarPortal,\r\n  MenubarMenu,\r\n  MenubarTrigger,\r\n  MenubarContent,\r\n  MenubarGroup,\r\n  MenubarSeparator,\r\n  MenubarLabel,\r\n  MenubarItem,\r\n  MenubarShortcut,\r\n  MenubarCheckboxItem,\r\n  MenubarRadioGroup,\r\n  MenubarRadioItem,\r\n  MenubarSub,\r\n  MenubarSubTrigger,\r\n  MenubarSubContent,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/navigation-menu.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { cva } from \"class-variance-authority\";\r\nimport { NavigationMenu as NavigationMenuPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { ArrowDown01Icon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nexport function NavigationMenu({\r\n  className,\r\n  children,\r\n  viewport = true,\r\n  ...props\r\n}: React.ComponentProps<typeof NavigationMenuPrimitive.Root> & {\r\n  viewport?: boolean;\r\n}): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.Root\r\n      data-slot=\"navigation-menu\"\r\n      data-viewport={viewport}\r\n      className={cn(\r\n        \"max-w-max group/navigation-menu relative flex max-w-max flex-1 items-center justify-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      {viewport && <NavigationMenuViewport />}\r\n    </NavigationMenuPrimitive.Root>\r\n  );\r\n}\r\n\r\nexport function NavigationMenuList({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.List\r\n>): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.List\r\n      data-slot=\"navigation-menu-list\"\r\n      className={cn(\r\n        \"gap-0 group flex flex-1 list-none items-center justify-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport function NavigationMenuItem({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Item\r\n>): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.Item\r\n      data-slot=\"navigation-menu-item\"\r\n      className={cn(\"relative\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport const navigationMenuTriggerStyle = cva(\r\n  \"bg-background hover:bg-muted focus:bg-muted data-open:hover:bg-muted data-open:focus:bg-muted data-open:bg-muted/50 focus-visible:ring-ring/50 data-popup-open:bg-muted/50 data-popup-open:hover:bg-muted rounded-2xl px-4.5 py-2.5 text-sm font-medium transition-all focus-visible:ring-[3px] focus-visible:outline-1 disabled:opacity-50 group/navigation-menu-trigger inline-flex h-9 w-max items-center justify-center disabled:pointer-events-none outline-none\",\r\n);\r\n\r\nexport function NavigationMenuTrigger({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Trigger\r\n>): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.Trigger\r\n      data-slot=\"navigation-menu-trigger\"\r\n      className={cn(navigationMenuTriggerStyle(), \"group\", className)}\r\n      {...props}\r\n    >\r\n      {children}{\" \"}\r\n      <HugeiconsIcon\r\n        icon={ArrowDown01Icon}\r\n        strokeWidth={2}\r\n        className=\"relative top-[1px] ml-1 size-3 transition duration-300 group-data-open/navigation-menu-trigger:rotate-180 group-data-popup-open/navigation-menu-trigger:rotate-180\"\r\n        aria-hidden=\"true\"\r\n      />\r\n    </NavigationMenuPrimitive.Trigger>\r\n  );\r\n}\r\n\r\nexport function NavigationMenuContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Content\r\n>): React.ReactElement {\r\n  return (\r\n    
<NavigationMenuPrimitive.Content\r\n      data-slot=\"navigation-menu-content\"\r\n      className={cn(\r\n        \"data-[motion^=from-]:animate-in data-[motion^=to-]:animate-out data-[motion^=from-]:fade-in data-[motion^=to-]:fade-out data-[motion=from-end]:slide-in-from-right-52 data-[motion=from-start]:slide-in-from-left-52 data-[motion=to-end]:slide-out-to-right-52 data-[motion=to-start]:slide-out-to-left-52 group-data-[viewport=false]/navigation-menu:bg-popover group-data-[viewport=false]/navigation-menu:text-popover-foreground group-data-[viewport=false]/navigation-menu:data-open:animate-in group-data-[viewport=false]/navigation-menu:data-closed:animate-out group-data-[viewport=false]/navigation-menu:data-closed:zoom-out-95 group-data-[viewport=false]/navigation-menu:data-open:zoom-in-95 group-data-[viewport=false]/navigation-menu:data-open:fade-in-0 group-data-[viewport=false]/navigation-menu:data-closed:fade-out-0 group-data-[viewport=false]/navigation-menu:ring-foreground/5 p-2.5 pr-3 ease-[cubic-bezier(0.22,1,0.36,1)] group-data-[viewport=false]/navigation-menu:rounded-2xl group-data-[viewport=false]/navigation-menu:shadow-2xl group-data-[viewport=false]/navigation-menu:ring-1 group-data-[viewport=false]/navigation-menu:duration-300 top-0 left-0 w-full group-data-[viewport=false]/navigation-menu:top-full group-data-[viewport=false]/navigation-menu:mt-1.5 group-data-[viewport=false]/navigation-menu:overflow-hidden **:data-[slot=navigation-menu-link]:focus:ring-0 **:data-[slot=navigation-menu-link]:focus:outline-none md:absolute md:w-auto\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport function NavigationMenuViewport({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Viewport\r\n>): React.ReactElement {\r\n  return (\r\n    <div\r\n      className={cn(\r\n        \"absolute top-full left-0 isolate z-50 flex justify-center\",\r\n      )}\r\n    >\r\n      <NavigationMenuPrimitive.Viewport\r\n        data-slot=\"navigation-menu-viewport\"\r\n        className={cn(\r\n          \"bg-popover text-popover-foreground data-open:animate-in data-closed:animate-out data-closed:zoom-out-95 data-open:zoom-in-90 ring-foreground/5 rounded-2xl shadow-2xl ring-1 duration-100 origin-top-center relative mt-1.5 h-[var(--radix-navigation-menu-viewport-height)] w-full overflow-hidden md:w-[var(--radix-navigation-menu-viewport-width)]\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </div>\r\n  );\r\n}\r\n\r\nexport function NavigationMenuLink({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Link\r\n>): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.Link\r\n      data-slot=\"navigation-menu-link\"\r\n      className={cn(\r\n        \"data-[active=true]:focus:bg-muted data-[active=true]:hover:bg-muted data-[active=true]:bg-muted/50 focus-visible:ring-ring/50 hover:bg-muted focus:bg-muted flex items-center gap-1.5 rounded-xl p-3 text-sm transition-all outline-none focus-visible:ring-[3px] focus-visible:outline-1 [&_svg:not([class*='size-'])]:size-4\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport function NavigationMenuIndicator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\r\n  typeof NavigationMenuPrimitive.Indicator\r\n>): React.ReactElement {\r\n  return (\r\n    <NavigationMenuPrimitive.Indicator\r\n      data-slot=\"navigation-menu-indicator\"\r\n      
className={cn(\r\n        \"data-[state=visible]:animate-in data-[state=hidden]:animate-out data-[state=hidden]:fade-out data-[state=visible]:fade-in top-full z-[1] flex h-1.5 items-end justify-center overflow-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <div className=\"bg-border rounded-tl-sm shadow-md relative top-[60%] h-2 w-2 rotate-45\" />\r\n    </NavigationMenuPrimitive.Indicator>\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/pagination.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\n\r\nimport { Button } from \"@/components/ui/button\";\r\nimport { cn } from \"@/lib/utils\";\r\nimport {\r\n  ArrowLeft01Icon,\r\n  ArrowRight01Icon,\r\n  MoreHorizontalCircle01Icon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction Pagination({ className, ...props }: React.ComponentProps<\"nav\">) {\r\n  return (\r\n    <nav\r\n      aria-label=\"pagination\"\r\n      data-slot=\"pagination\"\r\n      className={cn(\"mx-auto flex w-full justify-center\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction PaginationContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"ul\">) {\r\n  return (\r\n    <ul\r\n      data-slot=\"pagination-content\"\r\n      className={cn(\"gap-1 flex items-center\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction PaginationItem({ ...props }: React.ComponentProps<\"li\">) {\r\n  return <li data-slot=\"pagination-item\" {...props} />;\r\n}\r\n\r\ntype PaginationLinkProps = {\r\n  isActive?: boolean;\r\n} & Pick<React.ComponentProps<typeof Button>, \"size\"> &\r\n  React.ComponentProps<\"a\">;\r\n\r\nfunction PaginationLink({\r\n  className,\r\n  isActive,\r\n  size = \"icon\",\r\n  ...props\r\n}: PaginationLinkProps) {\r\n  return (\r\n    <Button\r\n      asChild\r\n      variant={isActive ? \"outline\" : \"ghost\"}\r\n      size={size}\r\n      className={cn(className)}\r\n    >\r\n      <a\r\n        aria-current={isActive ? \"page\" : undefined}\r\n        data-slot=\"pagination-link\"\r\n        data-active={isActive}\r\n        {...props}\r\n      />\r\n    </Button>\r\n  );\r\n}\r\n\r\nfunction PaginationPrevious({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof PaginationLink>) {\r\n  return (\r\n    <PaginationLink\r\n      aria-label=\"Go to previous page\"\r\n      size=\"default\"\r\n      className={cn(\"pl-2!\", className)}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon\r\n        icon={ArrowLeft01Icon}\r\n        strokeWidth={2}\r\n        data-icon=\"inline-start\"\r\n      />\r\n      <span className=\"hidden sm:block\">Previous</span>\r\n    </PaginationLink>\r\n  );\r\n}\r\n\r\nfunction PaginationNext({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof PaginationLink>) {\r\n  return (\r\n    <PaginationLink\r\n      aria-label=\"Go to next page\"\r\n      size=\"default\"\r\n      className={cn(\"pr-2!\", className)}\r\n      {...props}\r\n    >\r\n      <span className=\"hidden sm:block\">Next</span>\r\n      <HugeiconsIcon\r\n        icon={ArrowRight01Icon}\r\n        strokeWidth={2}\r\n        data-icon=\"inline-end\"\r\n      />\r\n    </PaginationLink>\r\n  );\r\n}\r\n\r\nfunction PaginationEllipsis({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"span\">) {\r\n  return (\r\n    <span\r\n      aria-hidden\r\n      data-slot=\"pagination-ellipsis\"\r\n      className={cn(\r\n        \"size-9 items-center justify-center [&_svg:not([class*='size-'])]:size-4 flex items-center justify-center\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon icon={MoreHorizontalCircle01Icon} strokeWidth={2} />\r\n      <span className=\"sr-only\">More pages</span>\r\n    </span>\r\n  );\r\n}\r\n\r\nexport {\r\n  Pagination,\r\n  PaginationContent,\r\n  
PaginationEllipsis,\r\n  PaginationItem,\r\n  PaginationLink,\r\n  PaginationNext,\r\n  PaginationPrevious,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/popover.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Popover as PopoverPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Popover({\r\n  ...props\r\n}: React.ComponentProps<typeof PopoverPrimitive.Root>) {\r\n  return <PopoverPrimitive.Root data-slot=\"popover\" {...props} />;\r\n}\r\n\r\nfunction PopoverTrigger({\r\n  ...props\r\n}: React.ComponentProps<typeof PopoverPrimitive.Trigger>) {\r\n  return <PopoverPrimitive.Trigger data-slot=\"popover-trigger\" {...props} />;\r\n}\r\n\r\nfunction PopoverContent({\r\n  className,\r\n  align = \"center\",\r\n  sideOffset = 4,\r\n  ...props\r\n}: React.ComponentProps<typeof PopoverPrimitive.Content>) {\r\n  return (\r\n    <PopoverPrimitive.Portal>\r\n      <PopoverPrimitive.Content\r\n        data-slot=\"popover-content\"\r\n        align={align}\r\n        sideOffset={sideOffset}\r\n        className={cn(\r\n          \"bg-popover text-popover-foreground data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 shadow-border ring-1 ring-border flex flex-col gap-4 rounded-lg p-4 text-sm duration-100 z-50 w-72 origin-(--radix-popover-content-transform-origin) outline-hidden\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </PopoverPrimitive.Portal>\r\n  );\r\n}\r\n\r\nfunction PopoverAnchor({\r\n  ...props\r\n}: React.ComponentProps<typeof PopoverPrimitive.Anchor>) {\r\n  return <PopoverPrimitive.Anchor data-slot=\"popover-anchor\" {...props} />;\r\n}\r\n\r\nfunction PopoverHeader({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"popover-header\"\r\n      className={cn(\"flex flex-col gap-1 text-sm\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction PopoverTitle({ className, ...props }: React.ComponentProps<\"h2\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"popover-title\"\r\n      className={cn(\"text-base font-medium\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction PopoverDescription({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"p\">) {\r\n  return (\r\n    <p\r\n      data-slot=\"popover-description\"\r\n      className={cn(\"text-muted-foreground\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Popover,\r\n  PopoverAnchor,\r\n  PopoverContent,\r\n  PopoverDescription,\r\n  PopoverHeader,\r\n  PopoverTitle,\r\n  PopoverTrigger,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/progress.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Progress as ProgressPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Progress({\r\n  className,\r\n  value,\r\n  ...props\r\n}: React.ComponentProps<typeof ProgressPrimitive.Root>) {\r\n  return (\r\n    <ProgressPrimitive.Root\r\n      data-slot=\"progress\"\r\n      className={cn(\r\n        \"bg-foreground/[0.06] h-3 rounded-4xl relative flex w-full items-center overflow-x-hidden\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <ProgressPrimitive.Indicator\r\n        data-slot=\"progress-indicator\"\r\n        className=\"bg-primary size-full flex-1 transition-all\"\r\n        style={{ transform: `translateX(-${100 - (value || 0)}%)` }}\r\n      />\r\n    </ProgressPrimitive.Root>\r\n  );\r\n}\r\n\r\nexport { Progress };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/radio-group.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { RadioGroup as RadioGroupPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { CircleIcon } from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nfunction RadioGroup({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof RadioGroupPrimitive.Root>) {\r\n  return (\r\n    <RadioGroupPrimitive.Root\r\n      data-slot=\"radio-group\"\r\n      className={cn(\"grid gap-3 w-full\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction RadioGroupItem({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof RadioGroupPrimitive.Item>) {\r\n  return (\r\n    <RadioGroupPrimitive.Item\r\n      data-slot=\"radio-group-item\"\r\n      className={cn(\r\n        \"border-input text-primary dark:bg-input/30 focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 data-checked:bg-primary data-checked:border-primary flex size-4 rounded-full transition-none focus-visible:ring-[3px] aria-invalid:ring-[3px] group/radio-group-item peer relative aspect-square shrink-0 border outline-none after:absolute after:-inset-x-3 after:-inset-y-2 disabled:cursor-not-allowed disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <RadioGroupPrimitive.Indicator\r\n        data-slot=\"radio-group-indicator\"\r\n        className=\"group-aria-invalid/radio-group-item:text-destructive flex size-4 items-center justify-center text-white\"\r\n      >\r\n        <HugeiconsIcon\r\n          icon={CircleIcon}\r\n          strokeWidth={2}\r\n          className=\"absolute top-1/2 left-1/2 size-2 -translate-x-1/2 -translate-y-1/2 fill-current\"\r\n        />\r\n      </RadioGroupPrimitive.Indicator>\r\n    </RadioGroupPrimitive.Item>\r\n  );\r\n}\r\n\r\nexport { RadioGroup, RadioGroupItem };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/resizable.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\nimport * as ResizablePrimitive from \"react-resizable-panels\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction ResizablePanelGroup({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ResizablePrimitive.Group>): React.ReactElement {\r\n  return (\r\n    <ResizablePrimitive.Group\r\n      data-slot=\"resizable-panel-group\"\r\n      className={cn(\r\n        \"flex h-full w-full data-[panel-group-direction=vertical]:flex-col\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction ResizablePanel({\r\n  ...props\r\n}: React.ComponentProps<typeof ResizablePrimitive.Panel>): React.ReactElement {\r\n  return <ResizablePrimitive.Panel data-slot=\"resizable-panel\" {...props} />;\r\n}\r\n\r\nfunction ResizableHandle({\r\n  withHandle,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof ResizablePrimitive.Separator> & {\r\n  withHandle?: boolean;\r\n}): React.ReactElement {\r\n  return (\r\n    <ResizablePrimitive.Separator\r\n      data-slot=\"resizable-handle\"\r\n      className={cn(\r\n        \"bg-border focus-visible:ring-ring relative flex w-px items-center justify-center after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:ring-1 focus-visible:ring-offset-1 focus-visible:outline-hidden data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:translate-x-0 data-[panel-group-direction=vertical]:after:-translate-y-1/2 [&[data-panel-group-direction=vertical]>div]:rotate-90\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {withHandle && (\r\n        <div className=\"bg-border h-6 w-1 rounded-lg z-10 flex shrink-0\" />\r\n      )}\r\n    </ResizablePrimitive.Separator>\r\n  );\r\n}\r\n\r\nexport { ResizablePanelGroup, ResizablePanel, ResizableHandle };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/scroll-area.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { ScrollArea as ScrollAreaPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction ScrollArea({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof ScrollAreaPrimitive.Root>) {\r\n  return (\r\n    <ScrollAreaPrimitive.Root\r\n      data-slot=\"scroll-area\"\r\n      className={cn(\"relative\", className)}\r\n      {...props}\r\n    >\r\n      <ScrollAreaPrimitive.Viewport\r\n        data-slot=\"scroll-area-viewport\"\r\n        className=\"focus-visible:ring-ring/50 size-full rounded-[inherit] transition-[color,box-shadow] outline-none focus-visible:ring-[3px] focus-visible:outline-1\"\r\n      >\r\n        {children}\r\n      </ScrollAreaPrimitive.Viewport>\r\n      <ScrollBar />\r\n      <ScrollAreaPrimitive.Corner />\r\n    </ScrollAreaPrimitive.Root>\r\n  );\r\n}\r\n\r\nfunction ScrollBar({\r\n  className,\r\n  orientation = \"vertical\",\r\n  ...props\r\n}: React.ComponentProps<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>) {\r\n  return (\r\n    <ScrollAreaPrimitive.ScrollAreaScrollbar\r\n      data-slot=\"scroll-area-scrollbar\"\r\n      data-orientation={orientation}\r\n      orientation={orientation}\r\n      className={cn(\r\n        \"data-horizontal:h-2.5 data-horizontal:flex-col data-horizontal:border-t data-horizontal:border-t-transparent data-vertical:h-full data-vertical:w-2.5 data-vertical:border-l data-vertical:border-l-transparent flex touch-none p-px transition-colors select-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <ScrollAreaPrimitive.ScrollAreaThumb\r\n        data-slot=\"scroll-area-thumb\"\r\n        className=\"rounded-full bg-border relative flex-1\"\r\n      />\r\n    </ScrollAreaPrimitive.ScrollAreaScrollbar>\r\n  );\r\n}\r\n\r\nexport { ScrollArea, ScrollBar };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/select.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport { Select as SelectPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\nimport { createContext, useContext, useState } from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\nimport { useDialogPortalContainer } from \"@/components/ui/dialog\";\r\nimport {\r\n  ArrowDown01Icon,\r\n  ArrowUp01Icon,\r\n  Tick02Icon,\r\n  UnfoldMoreIcon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\n\r\nconst SelectOpenContext = createContext(false);\r\n\r\nfunction Select({\r\n  onOpenChange,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Root>) {\r\n  const [isOpen, setIsOpen] = useState(false);\r\n  return (\r\n    <SelectOpenContext.Provider value={isOpen}>\r\n      <SelectPrimitive.Root\r\n        data-slot=\"select\"\r\n        onOpenChange={(open) => {\r\n          setIsOpen(open);\r\n          onOpenChange?.(open);\r\n        }}\r\n        {...props}\r\n      />\r\n    </SelectOpenContext.Provider>\r\n  );\r\n}\r\n\r\nfunction SelectGroup({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Group>) {\r\n  return (\r\n    <SelectPrimitive.Group\r\n      data-slot=\"select-group\"\r\n      className={cn(\"scroll-my-1 p-1\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction SelectValue({\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Value>) {\r\n  return <SelectPrimitive.Value data-slot=\"select-value\" {...props} />;\r\n}\r\n\r\nfunction SelectTrigger({\r\n  className,\r\n  size = \"default\",\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Trigger> & {\r\n  size?: \"sm\" | \"default\";\r\n}) {\r\n  const isOpen = useContext(SelectOpenContext);\r\n\r\n  return (\r\n    <SelectPrimitive.Trigger\r\n      data-slot=\"select-trigger\"\r\n      data-size={size}\r\n      style={{\r\n        borderRadius: isOpen ? \"12px\" : undefined,\r\n        transition: isOpen\r\n          ? 
\"border-radius 0ms\"\r\n          : \"border-radius 150ms cubic-bezier(0.645, 0.045, 0.355, 1)\",\r\n      }}\r\n      className={cn(\r\n        \"border-input data-[placeholder]:text-muted-foreground bg-input/30 dark:hover:bg-input/50 focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 gap-1.5 rounded-4xl border px-3 py-2 text-sm transition-colors focus-visible:ring-[3px] aria-invalid:ring-[3px] data-[size=default]:h-9 data-[size=sm]:h-8 *:data-[slot=select-value]:flex *:data-[slot=select-value]:gap-1.5 [&_svg:not([class*='size-'])]:size-4 flex w-fit items-center justify-between whitespace-nowrap outline-none disabled:cursor-not-allowed disabled:opacity-50 *:data-[slot=select-value]:line-clamp-1 *:data-[slot=select-value]:flex *:data-[slot=select-value]:items-center [&_svg]:pointer-events-none [&_svg]:shrink-0 cursor-pointer\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n      <SelectPrimitive.Icon asChild>\r\n        <HugeiconsIcon\r\n          icon={UnfoldMoreIcon}\r\n          strokeWidth={2}\r\n          className=\"text-muted-foreground size-4 pointer-events-none\"\r\n        />\r\n      </SelectPrimitive.Icon>\r\n    </SelectPrimitive.Trigger>\r\n  );\r\n}\r\n\r\nfunction SelectContent({\r\n  className,\r\n  children,\r\n  position = \"item-aligned\",\r\n  align = \"center\",\r\n  container,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Content> & {\r\n  container?: HTMLElement | null;\r\n}) {\r\n  const dialogContainer = useDialogPortalContainer();\r\n  return (\r\n    <SelectPrimitive.Portal container={container ?? dialogContainer ?? undefined}>\r\n      <SelectPrimitive.Content\r\n        data-slot=\"select-content\"\r\n        data-align-trigger={position === \"item-aligned\"}\r\n        className={cn(\r\n          \"bg-popover text-popover-foreground data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 shadow-border ring-1 ring-border min-w-36 rounded-xl p-1 corner-squircle duration-100 relative z-50 max-h-(--radix-select-content-available-height) origin-(--radix-select-content-transform-origin) overflow-x-hidden overflow-y-auto \",\r\n          position === \"popper\" &&\r\n            \"data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1\",\r\n          className,\r\n        )}\r\n        position={position}\r\n        align={align}\r\n        {...props}\r\n      >\r\n        <SelectScrollUpButton />\r\n        <SelectPrimitive.Viewport\r\n          data-position={position}\r\n          className={cn(\r\n            \"data-[position=popper]:h-[var(--radix-select-trigger-height)] data-[position=popper]:w-full data-[position=popper]:min-w-[var(--radix-select-trigger-width)]\",\r\n            position === \"popper\" && \"\",\r\n          )}\r\n        >\r\n          {children}\r\n        </SelectPrimitive.Viewport>\r\n        <SelectScrollDownButton />\r\n      </SelectPrimitive.Content>\r\n    </SelectPrimitive.Portal>\r\n  );\r\n}\r\n\r\nfunction SelectLabel({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Label>) {\r\n  return (\r\n    
<SelectPrimitive.Label\r\n      data-slot=\"select-label\"\r\n      className={cn(\"text-muted-foreground px-3 py-2.5 text-xs\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction SelectItem({\r\n  className,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Item>) {\r\n  return (\r\n    <SelectPrimitive.Item\r\n      data-slot=\"select-item\"\r\n      className={cn(\r\n        \"focus:bg-accent focus:text-accent-foreground not-data-[variant=destructive]:focus:**:text-accent-foreground gap-2.5 rounded-xl corner-squircle py-2 pr-8 pl-3 text-sm [&_svg:not([class*='size-'])]:size-4 *:[span]:last:flex *:[span]:last:items-center *:[span]:last:gap-2 relative flex w-full cursor-pointer items-center outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <span className=\"pointer-events-none absolute right-2 flex size-4 items-center justify-center\">\r\n        <SelectPrimitive.ItemIndicator>\r\n          <HugeiconsIcon\r\n            icon={Tick02Icon}\r\n            strokeWidth={2}\r\n            className=\"pointer-events-none\"\r\n          />\r\n        </SelectPrimitive.ItemIndicator>\r\n      </span>\r\n      <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>\r\n    </SelectPrimitive.Item>\r\n  );\r\n}\r\n\r\nfunction SelectSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.Separator>) {\r\n  return (\r\n    <SelectPrimitive.Separator\r\n      data-slot=\"select-separator\"\r\n      className={cn(\r\n        \"bg-border/50 -mx-1 my-1 h-px pointer-events-none\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction SelectScrollUpButton({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.ScrollUpButton>) {\r\n  return (\r\n    <SelectPrimitive.ScrollUpButton\r\n      data-slot=\"select-scroll-up-button\"\r\n      className={cn(\r\n        \"bg-popover z-10 flex cursor-default items-center justify-center py-1 [&_svg:not([class*='size-'])]:size-4\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon icon={ArrowUp01Icon} strokeWidth={2} />\r\n    </SelectPrimitive.ScrollUpButton>\r\n  );\r\n}\r\n\r\nfunction SelectScrollDownButton({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof SelectPrimitive.ScrollDownButton>) {\r\n  return (\r\n    <SelectPrimitive.ScrollDownButton\r\n      data-slot=\"select-scroll-down-button\"\r\n      className={cn(\r\n        \"bg-popover z-10 flex cursor-default items-center justify-center py-1 [&_svg:not([class*='size-'])]:size-4\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon icon={ArrowDown01Icon} strokeWidth={2} />\r\n    </SelectPrimitive.ScrollDownButton>\r\n  );\r\n}\r\n\r\nexport {\r\n  Select,\r\n  SelectContent,\r\n  SelectGroup,\r\n  SelectItem,\r\n  SelectLabel,\r\n  SelectScrollDownButton,\r\n  SelectScrollUpButton,\r\n  SelectSeparator,\r\n  SelectTrigger,\r\n  SelectValue,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/separator.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Separator as SeparatorPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Separator({\r\n  className,\r\n  orientation = \"horizontal\",\r\n  decorative = true,\r\n  ...props\r\n}: React.ComponentProps<typeof SeparatorPrimitive.Root>) {\r\n  return (\r\n    <SeparatorPrimitive.Root\r\n      data-slot=\"separator\"\r\n      decorative={decorative}\r\n      orientation={orientation}\r\n      className={cn(\r\n        \"bg-border shrink-0 data-[orientation=horizontal]:h-px data-[orientation=horizontal]:w-full data-[orientation=vertical]:w-px data-[orientation=vertical]:self-stretch\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Separator };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/sheet.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Dialog as SheetPrimitive } from \"radix-ui\";\nimport type * as React from \"react\";\nimport { useState } from \"react\";\n\nimport { Button } from \"@/components/ui/button\";\nimport { DialogPortalContainerContext } from \"@/components/ui/dialog\";\nimport { cn } from \"@/lib/utils\";\nimport { Cancel01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\n\nfunction Sheet({ ...props }: React.ComponentProps<typeof SheetPrimitive.Root>) {\n  return <SheetPrimitive.Root data-slot=\"sheet\" {...props} />;\n}\n\nfunction SheetTrigger({\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Trigger>) {\n  return <SheetPrimitive.Trigger data-slot=\"sheet-trigger\" {...props} />;\n}\n\nfunction SheetClose({\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Close>) {\n  return <SheetPrimitive.Close data-slot=\"sheet-close\" {...props} />;\n}\n\nfunction SheetPortal({\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Portal>) {\n  return <SheetPrimitive.Portal data-slot=\"sheet-portal\" {...props} />;\n}\n\nfunction SheetOverlay({\n  className,\n  position = \"fixed\",\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Overlay> & {\n  position?: \"fixed\" | \"absolute\";\n}) {\n  return (\n    <SheetPrimitive.Overlay\n      data-slot=\"sheet-overlay\"\n      className={cn(\n        \"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 bg-black/80 duration-100 data-ending-style:opacity-0 data-starting-style:opacity-0 supports-backdrop-filter:backdrop-blur-xs inset-0 z-50\",\n        position === \"fixed\" ? \"fixed\" : \"absolute\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n\nfunction SheetContent({\n  className,\n  children,\n  side = \"right\",\n  showCloseButton = true,\n  container,\n  position = \"fixed\",\n  overlayClassName,\n  overlayPosition,\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Content> & {\n  side?: \"top\" | \"right\" | \"bottom\" | \"left\";\n  showCloseButton?: boolean;\n  container?: HTMLElement | null;\n  position?: \"fixed\" | \"absolute\";\n  overlayClassName?: string;\n  overlayPosition?: \"fixed\" | \"absolute\";\n}) {\n  const [contentEl, setContentEl] = useState<HTMLDivElement | null>(null);\n  return (\n    <SheetPortal container={container ?? undefined}>\n      <SheetOverlay\n        className={overlayClassName}\n        position={overlayPosition ?? 
position}\n      />\n      <SheetPrimitive.Content\n        ref={setContentEl}\n        data-slot=\"sheet-content\"\n        data-side={side}\n        className={cn(\n          \"bg-background data-open:animate-in data-closed:animate-out data-[side=right]:data-closed:slide-out-to-right-10 data-[side=right]:data-open:slide-in-from-right-10 data-[side=left]:data-closed:slide-out-to-left-10 data-[side=left]:data-open:slide-in-from-left-10 data-[side=top]:data-closed:slide-out-to-top-10 data-[side=top]:data-open:slide-in-from-top-10 data-closed:fade-out-0 data-open:fade-in-0 data-[side=bottom]:data-closed:slide-out-to-bottom-10 data-[side=bottom]:data-open:slide-in-from-bottom-10 z-50 flex flex-col bg-clip-padding text-sm shadow-lg transition duration-200 ease-in-out data-[side=bottom]:inset-x-0 data-[side=bottom]:bottom-0 data-[side=bottom]:h-auto data-[side=bottom]:border-t data-[side=left]:inset-y-0 data-[side=left]:left-0 data-[side=left]:h-full data-[side=left]:w-3/4 data-[side=left]:border-r data-[side=right]:inset-y-0 data-[side=right]:right-0 data-[side=right]:h-full data-[side=right]:w-3/4 data-[side=right]:border-l data-[side=top]:inset-x-0 data-[side=top]:top-0 data-[side=top]:h-auto data-[side=top]:border-b data-[side=left]:sm:max-w-sm data-[side=right]:sm:max-w-sm\",\n          position === \"fixed\" ? \"fixed\" : \"absolute\",\n          className,\n        )}\n        {...props}\n      >\n        <DialogPortalContainerContext.Provider value={contentEl}>\n          {children}\n        </DialogPortalContainerContext.Provider>\n        {showCloseButton && (\n          <SheetPrimitive.Close data-slot=\"sheet-close\" asChild>\n            <Button\n              variant=\"ghost\"\n              className=\"absolute top-4 right-4\"\n              size=\"icon-sm\"\n            >\n              <HugeiconsIcon icon={Cancel01Icon} strokeWidth={2} />\n              <span className=\"sr-only\">Close</span>\n            </Button>\n          </SheetPrimitive.Close>\n        )}\n      </SheetPrimitive.Content>\n    </SheetPortal>\n  );\n}\n\nfunction SheetHeader({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"sheet-header\"\n      className={cn(\"gap-1.5 p-6 flex flex-col\", className)}\n      {...props}\n    />\n  );\n}\n\nfunction SheetFooter({ className, ...props }: React.ComponentProps<\"div\">) {\n  return (\n    <div\n      data-slot=\"sheet-footer\"\n      className={cn(\"gap-2 p-6 mt-auto flex flex-col\", className)}\n      {...props}\n    />\n  );\n}\n\nfunction SheetTitle({\n  className,\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Title>) {\n  return (\n    <SheetPrimitive.Title\n      data-slot=\"sheet-title\"\n      className={cn(\"text-foreground text-base font-medium\", className)}\n      {...props}\n    />\n  );\n}\n\nfunction SheetDescription({\n  className,\n  ...props\n}: React.ComponentProps<typeof SheetPrimitive.Description>) {\n  return (\n    <SheetPrimitive.Description\n      data-slot=\"sheet-description\"\n      className={cn(\"text-muted-foreground text-sm\", className)}\n      {...props}\n    />\n  );\n}\n\nexport {\n  Sheet,\n  SheetTrigger,\n  SheetClose,\n  SheetContent,\n  SheetHeader,\n  SheetFooter,\n  SheetTitle,\n  SheetDescription,\n};\n"
  },
  {
    "path": "studio/frontend/src/components/ui/shine-border.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport * as React from \"react\"\n\nimport { cn } from \"@/lib/utils\"\n\ninterface ShineBorderProps extends React.HTMLAttributes<HTMLDivElement> {\n  /**\n   * Width of the border in pixels\n   * @default 1\n   */\n  borderWidth?: number\n  /**\n   * Duration of the animation in seconds\n   * @default 14\n   */\n  duration?: number\n  /**\n   * Color of the border, can be a single color or an array of colors\n   * @default \"#000000\"\n   */\n  shineColor?: string | string[]\n}\n\n/**\n * Shine Border\n *\n * An animated background border effect component with configurable properties.\n */\nexport function ShineBorder({\n  borderWidth = 1,\n  duration = 14,\n  shineColor = \"#000000\",\n  className,\n  style,\n  ...props\n}: ShineBorderProps) {\n  return (\n    <div\n      style={\n        {\n          \"--border-width\": `${borderWidth}px`,\n          \"--duration\": `${duration}s`,\n          backgroundImage: `radial-gradient(transparent,transparent, ${\n            Array.isArray(shineColor) ? shineColor.join(\",\") : shineColor\n          },transparent,transparent)`,\n          backgroundSize: \"300% 300%\",\n          mask: `linear-gradient(#fff 0 0) content-box, linear-gradient(#fff 0 0)`,\n          WebkitMask: `linear-gradient(#fff 0 0) content-box, linear-gradient(#fff 0 0)`,\n          WebkitMaskComposite: \"xor\",\n          maskComposite: \"exclude\",\n          padding: \"var(--border-width)\",\n          ...style,\n        } as React.CSSProperties\n      }\n      className={cn(\n        \"motion-safe:animate-shine pointer-events-none absolute inset-0 size-full rounded-[inherit] will-change-[background-position]\",\n        className\n      )}\n      {...props}\n    />\n  )\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/sidebar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\"\r\n\r\nimport * as React from \"react\"\r\nimport { cva, type VariantProps } from \"class-variance-authority\"\r\nimport { Slot } from \"radix-ui\"\r\n\r\nimport { cn } from \"@/lib/utils\"\r\nimport { Button } from \"@/components/ui/button\"\r\nimport { Input } from \"@/components/ui/input\"\r\nimport { Separator } from \"@/components/ui/separator\"\r\nimport {\r\n  Sheet,\r\n  SheetContent,\r\n  SheetDescription,\r\n  SheetHeader,\r\n  SheetTitle,\r\n} from \"@/components/ui/sheet\"\r\nimport { Skeleton } from \"@/components/ui/skeleton\"\r\nimport {\r\n  Tooltip,\r\n  TooltipContent,\r\n  TooltipTrigger,\r\n} from \"@/components/ui/tooltip\"\r\nimport { useIsMobile } from \"@/hooks/use-mobile\"\r\nimport { HugeiconsIcon } from \"@hugeicons/react\"\r\nimport { SidebarLeftIcon } from \"@hugeicons/core-free-icons\"\r\n\r\nconst SIDEBAR_COOKIE_NAME = \"sidebar_state\"\r\nconst SIDEBAR_COOKIE_MAX_AGE = 60 * 60 * 24 * 7\r\nconst SIDEBAR_WIDTH = \"16rem\"\r\nconst SIDEBAR_WIDTH_MOBILE = \"18rem\"\r\nconst SIDEBAR_WIDTH_ICON = \"3rem\"\r\nconst SIDEBAR_KEYBOARD_SHORTCUT = \"b\"\r\n\r\ntype SidebarContextProps = {\r\n  state: \"expanded\" | \"collapsed\"\r\n  open: boolean\r\n  setOpen: (open: boolean) => void\r\n  openMobile: boolean\r\n  setOpenMobile: (open: boolean) => void\r\n  isMobile: boolean\r\n  toggleSidebar: () => void\r\n}\r\n\r\nconst SidebarContext = React.createContext<SidebarContextProps | null>(null)\r\n\r\nfunction useSidebar() {\r\n  const context = React.useContext(SidebarContext)\r\n  if (!context) {\r\n    throw new Error(\"useSidebar must be used within a SidebarProvider.\")\r\n  }\r\n\r\n  return context\r\n}\r\n\r\nfunction SidebarProvider({\r\n  defaultOpen = true,\r\n  open: openProp,\r\n  onOpenChange: setOpenProp,\r\n  className,\r\n  style,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  defaultOpen?: boolean\r\n  open?: boolean\r\n  onOpenChange?: (open: boolean) => void\r\n}) {\r\n  const isMobile = useIsMobile()\r\n  const [openMobile, setOpenMobile] = React.useState(false)\r\n\r\n  // This is the internal state of the sidebar.\r\n  // We use openProp and setOpenProp for control from outside the component.\r\n  const [_open, _setOpen] = React.useState(defaultOpen)\r\n  const open = openProp ?? _open\r\n  const setOpen = React.useCallback(\r\n    (value: boolean | ((value: boolean) => boolean)) => {\r\n      const openState = typeof value === \"function\" ? value(open) : value\r\n      if (setOpenProp) {\r\n        setOpenProp(openState)\r\n      } else {\r\n        _setOpen(openState)\r\n      }\r\n\r\n      // This sets the cookie to keep the sidebar state.\r\n      document.cookie = `${SIDEBAR_COOKIE_NAME}=${openState}; path=/; max-age=${SIDEBAR_COOKIE_MAX_AGE}`\r\n    },\r\n    [setOpenProp, open]\r\n  )\r\n\r\n  // Helper to toggle the sidebar.\r\n  const toggleSidebar = React.useCallback(() => {\r\n    return isMobile ? 
setOpenMobile((open) => !open) : setOpen((open) => !open)\r\n  }, [isMobile, setOpen, setOpenMobile])\r\n\r\n  // Adds a keyboard shortcut to toggle the sidebar.\r\n  React.useEffect(() => {\r\n    const handleKeyDown = (event: KeyboardEvent) => {\r\n      if (\r\n        event.key === SIDEBAR_KEYBOARD_SHORTCUT &&\r\n        (event.metaKey || event.ctrlKey)\r\n      ) {\r\n        event.preventDefault()\r\n        toggleSidebar()\r\n      }\r\n    }\r\n\r\n    window.addEventListener(\"keydown\", handleKeyDown)\r\n    return () => window.removeEventListener(\"keydown\", handleKeyDown)\r\n  }, [toggleSidebar])\r\n\r\n  // We add a state so that we can do data-state=\"expanded\" or \"collapsed\".\r\n  // This makes it easier to style the sidebar with Tailwind classes.\r\n  const state = open ? \"expanded\" : \"collapsed\"\r\n\r\n  const contextValue = React.useMemo<SidebarContextProps>(\r\n    () => ({\r\n      state,\r\n      open,\r\n      setOpen,\r\n      isMobile,\r\n      openMobile,\r\n      setOpenMobile,\r\n      toggleSidebar,\r\n    }),\r\n    [state, open, setOpen, isMobile, openMobile, setOpenMobile, toggleSidebar]\r\n  )\r\n\r\n  return (\r\n    <SidebarContext.Provider value={contextValue}>\r\n      <div\r\n        data-slot=\"sidebar-wrapper\"\r\n        style={\r\n          {\r\n            \"--sidebar-width\": SIDEBAR_WIDTH,\r\n            \"--sidebar-width-icon\": SIDEBAR_WIDTH_ICON,\r\n            ...style,\r\n          } as React.CSSProperties\r\n        }\r\n        className={cn(\r\n          \"group/sidebar-wrapper has-data-[variant=inset]:bg-sidebar flex min-h-svh w-full\",\r\n          className\r\n        )}\r\n        {...props}\r\n      >\r\n        {children}\r\n      </div>\r\n    </SidebarContext.Provider>\r\n  )\r\n}\r\n\r\nfunction Sidebar({\r\n  side = \"left\",\r\n  variant = \"sidebar\",\r\n  collapsible = \"offcanvas\",\r\n  className,\r\n  children,\r\n  dir,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  side?: \"left\" | \"right\"\r\n  variant?: \"sidebar\" | \"floating\" | \"inset\"\r\n  collapsible?: \"offcanvas\" | \"icon\" | \"none\"\r\n}) {\r\n  const { isMobile, state, openMobile, setOpenMobile } = useSidebar()\r\n\r\n  if (collapsible === \"none\") {\r\n    return (\r\n      <div\r\n        data-slot=\"sidebar\"\r\n        className={cn(\r\n          \"bg-sidebar text-sidebar-foreground flex h-full w-(--sidebar-width) flex-col\",\r\n          className\r\n        )}\r\n        {...props}\r\n      >\r\n        {children}\r\n      </div>\r\n    )\r\n  }\r\n\r\n  if (isMobile) {\r\n    return (\r\n      <Sheet open={openMobile} onOpenChange={setOpenMobile} {...props}>\r\n        <SheetContent\r\n          dir={dir}\r\n          data-sidebar=\"sidebar\"\r\n          data-slot=\"sidebar\"\r\n          data-mobile=\"true\"\r\n          className=\"bg-sidebar text-sidebar-foreground w-(--sidebar-width) p-0 [&>button]:hidden\"\r\n          style={\r\n            {\r\n              \"--sidebar-width\": SIDEBAR_WIDTH_MOBILE,\r\n            } as React.CSSProperties\r\n          }\r\n          side={side}\r\n        >\r\n          <SheetHeader className=\"sr-only\">\r\n            <SheetTitle>Sidebar</SheetTitle>\r\n            <SheetDescription>Displays the mobile sidebar.</SheetDescription>\r\n          </SheetHeader>\r\n          <div className=\"flex h-full w-full flex-col\">{children}</div>\r\n        </SheetContent>\r\n      </Sheet>\r\n    )\r\n  }\r\n\r\n  return (\r\n    <div\r\n      className=\"group peer text-sidebar-foreground 
hidden md:block\"\r\n      data-state={state}\r\n      data-collapsible={state === \"collapsed\" ? collapsible : \"\"}\r\n      data-variant={variant}\r\n      data-side={side}\r\n      data-slot=\"sidebar\"\r\n    >\r\n      {/* This is what handles the sidebar gap on desktop */}\r\n      <div\r\n        data-slot=\"sidebar-gap\"\r\n        className={cn(\r\n          \"transition-[width] duration-200 ease-linear relative w-(--sidebar-width) bg-transparent\",\r\n          \"group-data-[collapsible=offcanvas]:w-0\",\r\n          \"group-data-[side=right]:rotate-180\",\r\n          variant === \"floating\" || variant === \"inset\"\r\n            ? \"group-data-[collapsible=icon]:w-[calc(var(--sidebar-width-icon)+(--spacing(4)))]\"\r\n            : \"group-data-[collapsible=icon]:w-(--sidebar-width-icon)\"\r\n        )}\r\n      />\r\n      <div\r\n        data-slot=\"sidebar-container\"\r\n        data-side={side}\r\n        className={cn(\r\n          \"fixed inset-y-0 z-10 hidden h-svh w-(--sidebar-width) transition-[left,right,width] duration-200 ease-linear data-[side=left]:left-0 data-[side=left]:group-data-[collapsible=offcanvas]:left-[calc(var(--sidebar-width)*-1)] data-[side=right]:right-0 data-[side=right]:group-data-[collapsible=offcanvas]:right-[calc(var(--sidebar-width)*-1)] md:flex\",\r\n          // Adjust the padding for floating and inset variants.\r\n          variant === \"floating\" || variant === \"inset\"\r\n            ? \"p-2 group-data-[collapsible=icon]:w-[calc(var(--sidebar-width-icon)+(--spacing(4))+2px)]\"\r\n            : \"group-data-[collapsible=icon]:w-(--sidebar-width-icon) group-data-[side=left]:border-r group-data-[side=right]:border-l\",\r\n          className\r\n        )}\r\n        {...props}\r\n      >\r\n        <div\r\n          data-sidebar=\"sidebar\"\r\n          data-slot=\"sidebar-inner\"\r\n          className=\"bg-sidebar group-data-[variant=floating]:ring-sidebar-border group-data-[variant=floating]:rounded-lg group-data-[variant=floating]:shadow-sm group-data-[variant=floating]:ring-1 flex size-full flex-col\"\r\n        >\r\n          {children}\r\n        </div>\r\n      </div>\r\n    </div>\r\n  )\r\n}\r\n\r\nfunction SidebarTrigger({\r\n  className,\r\n  onClick,\r\n  ...props\r\n}: React.ComponentProps<typeof Button>) {\r\n  const { toggleSidebar } = useSidebar()\r\n\r\n  return (\r\n    <Button\r\n      data-sidebar=\"trigger\"\r\n      data-slot=\"sidebar-trigger\"\r\n      variant=\"ghost\"\r\n      size=\"icon-sm\"\r\n      className={cn(className)}\r\n      onClick={(event) => {\r\n        onClick?.(event)\r\n        toggleSidebar()\r\n      }}\r\n      {...props}\r\n    >\r\n      <HugeiconsIcon icon={SidebarLeftIcon} strokeWidth={2} />\r\n      <span className=\"sr-only\">Toggle Sidebar</span>\r\n    </Button>\r\n  )\r\n}\r\n\r\nfunction SidebarRail({ className, ...props }: React.ComponentProps<\"button\">) {\r\n  const { toggleSidebar } = useSidebar()\r\n\r\n  return (\r\n    <button\r\n      data-sidebar=\"rail\"\r\n      data-slot=\"sidebar-rail\"\r\n      aria-label=\"Toggle Sidebar\"\r\n      tabIndex={-1}\r\n      onClick={toggleSidebar}\r\n      title=\"Toggle Sidebar\"\r\n      className={cn(\r\n        \"hover:after:bg-sidebar-border absolute inset-y-0 z-20 hidden w-4 transition-all ease-linear group-data-[side=left]:-right-4 group-data-[side=right]:left-0 after:absolute after:inset-y-0 after:start-1/2 after:w-[2px] sm:flex ltr:-translate-x-1/2 rtl:-translate-x-1/2\",\r\n        \"in-data-[side=left]:cursor-w-resize 
in-data-[side=right]:cursor-e-resize\",\r\n        \"[[data-side=left][data-state=collapsed]_&]:cursor-e-resize [[data-side=right][data-state=collapsed]_&]:cursor-w-resize\",\r\n        \"hover:group-data-[collapsible=offcanvas]:bg-sidebar group-data-[collapsible=offcanvas]:translate-x-0 group-data-[collapsible=offcanvas]:after:left-full\",\r\n        \"[[data-side=left][data-collapsible=offcanvas]_&]:-right-2\",\r\n        \"[[data-side=right][data-collapsible=offcanvas]_&]:-left-2\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarInset({ className, ...props }: React.ComponentProps<\"main\">) {\r\n  return (\r\n    <main\r\n      data-slot=\"sidebar-inset\"\r\n      className={cn(\r\n        \"bg-background md:peer-data-[variant=inset]:m-2 md:peer-data-[variant=inset]:ml-0 md:peer-data-[variant=inset]:rounded-xl md:peer-data-[variant=inset]:shadow-sm md:peer-data-[variant=inset]:peer-data-[state=collapsed]:ml-2 relative flex w-full flex-1 flex-col\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarInput({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof Input>) {\r\n  return (\r\n    <Input\r\n      data-slot=\"sidebar-input\"\r\n      data-sidebar=\"input\"\r\n      className={cn(\"bg-background h-8 w-full shadow-none\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarHeader({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-header\"\r\n      data-sidebar=\"header\"\r\n      className={cn(\"gap-2 p-2 flex flex-col\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarFooter({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-footer\"\r\n      data-sidebar=\"footer\"\r\n      className={cn(\"gap-2 p-2 flex flex-col\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarSeparator({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof Separator>) {\r\n  return (\r\n    <Separator\r\n      data-slot=\"sidebar-separator\"\r\n      data-sidebar=\"separator\"\r\n      className={cn(\"bg-sidebar-border mx-2 w-auto\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarContent({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-content\"\r\n      data-sidebar=\"content\"\r\n      className={cn(\r\n        \"no-scrollbar gap-2 flex min-h-0 flex-1 flex-col overflow-auto group-data-[collapsible=icon]:overflow-hidden\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarGroup({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-group\"\r\n      data-sidebar=\"group\"\r\n      className={cn(\r\n        \"p-2 relative flex w-full min-w-0 flex-col\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarGroupLabel({\r\n  className,\r\n  asChild = false,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & { asChild?: boolean }) {\r\n  const Comp = asChild ? 
Slot.Root : \"div\"\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"sidebar-group-label\"\r\n      data-sidebar=\"group-label\"\r\n      className={cn(\r\n        \"text-sidebar-foreground/70 ring-sidebar-ring h-8 rounded-md px-2 text-xs font-medium transition-[margin,opacity] duration-200 ease-linear group-data-[collapsible=icon]:-mt-8 group-data-[collapsible=icon]:opacity-0 focus-visible:ring-2 [&>svg]:size-4 flex shrink-0 items-center outline-hidden [&>svg]:shrink-0\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarGroupAction({\r\n  className,\r\n  asChild = false,\r\n  ...props\r\n}: React.ComponentProps<\"button\"> & { asChild?: boolean }) {\r\n  const Comp = asChild ? Slot.Root : \"button\"\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"sidebar-group-action\"\r\n      data-sidebar=\"group-action\"\r\n      className={cn(\r\n        \"text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 w-5 rounded-md p-0 focus-visible:ring-2 [&>svg]:size-4 flex aspect-square items-center justify-center outline-hidden transition-transform group-data-[collapsible=icon]:hidden after:absolute after:-inset-2 md:after:hidden [&>svg]:shrink-0\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarGroupContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-group-content\"\r\n      data-sidebar=\"group-content\"\r\n      className={cn(\"text-sm w-full\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenu({ className, ...props }: React.ComponentProps<\"ul\">) {\r\n  return (\r\n    <ul\r\n      data-slot=\"sidebar-menu\"\r\n      data-sidebar=\"menu\"\r\n      className={cn(\"gap-1 flex w-full min-w-0 flex-col\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenuItem({ className, ...props }: React.ComponentProps<\"li\">) {\r\n  return (\r\n    <li\r\n      data-slot=\"sidebar-menu-item\"\r\n      data-sidebar=\"menu-item\"\r\n      className={cn(\"group/menu-item relative\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nconst sidebarMenuButtonVariants = cva(\r\n  \"ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground active:bg-sidebar-accent active:text-sidebar-accent-foreground data-active:bg-sidebar-accent data-active:text-sidebar-accent-foreground data-open:hover:bg-sidebar-accent data-open:hover:text-sidebar-accent-foreground gap-2 rounded-lg corner-squircle p-2 text-left text-sm transition-[width,height,padding] group-has-data-[sidebar=menu-action]/menu-item:pr-8 group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! 
focus-visible:ring-2 data-active:font-medium peer/menu-button flex w-full items-center overflow-hidden outline-hidden group/menu-button disabled:pointer-events-none disabled:opacity-50 aria-disabled:pointer-events-none aria-disabled:opacity-50 [&>span:last-child]:truncate [&_svg]:size-4 [&_svg]:shrink-0\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"hover:bg-sidebar-accent hover:text-sidebar-accent-foreground\",\r\n        outline: \"bg-background hover:bg-sidebar-accent hover:text-sidebar-accent-foreground shadow-[0_0_0_1px_hsl(var(--sidebar-border))] hover:shadow-[0_0_0_1px_hsl(var(--sidebar-accent))]\",\r\n      },\r\n      size: {\r\n        default: \"h-9 text-sm\",\r\n        sm: \"h-8 text-xs\",\r\n        lg: \"h-12 text-sm group-data-[collapsible=icon]:p-0!\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n      size: \"default\",\r\n    },\r\n  }\r\n)\r\n\r\nfunction SidebarMenuButton({\r\n  asChild = false,\r\n  isActive = false,\r\n  variant = \"default\",\r\n  size = \"default\",\r\n  tooltip,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"button\"> & {\r\n  asChild?: boolean\r\n  isActive?: boolean\r\n  tooltip?: string | React.ComponentProps<typeof TooltipContent>\r\n} & VariantProps<typeof sidebarMenuButtonVariants>) {\r\n  const Comp = asChild ? Slot.Root : \"button\"\r\n  const { isMobile, state } = useSidebar()\r\n\r\n  const button = (\r\n    <Comp\r\n      data-slot=\"sidebar-menu-button\"\r\n      data-sidebar=\"menu-button\"\r\n      data-size={size}\r\n      data-active={isActive}\r\n      className={cn(sidebarMenuButtonVariants({ variant, size }), className)}\r\n      {...props}\r\n    />\r\n  )\r\n\r\n  if (!tooltip) {\r\n    return button\r\n  }\r\n\r\n  if (typeof tooltip === \"string\") {\r\n    tooltip = {\r\n      children: tooltip,\r\n    }\r\n  }\r\n\r\n  return (\r\n    <Tooltip>\r\n      <TooltipTrigger asChild>{button}</TooltipTrigger>\r\n      <TooltipContent\r\n        side=\"right\"\r\n        align=\"center\"\r\n        hidden={state !== \"collapsed\" || isMobile}\r\n        {...tooltip}\r\n      />\r\n    </Tooltip>\r\n  )\r\n}\r\n\r\nfunction SidebarMenuAction({\r\n  className,\r\n  asChild = false,\r\n  showOnHover = false,\r\n  ...props\r\n}: React.ComponentProps<\"button\"> & {\r\n  asChild?: boolean\r\n  showOnHover?: boolean\r\n}) {\r\n  const Comp = asChild ? 
Slot.Root : \"button\"\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"sidebar-menu-action\"\r\n      data-sidebar=\"menu-action\"\r\n      className={cn(\r\n        \"text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 aspect-square w-5 rounded-md p-0 peer-data-[size=default]/menu-button:top-2 peer-data-[size=lg]/menu-button:top-2.5 peer-data-[size=sm]/menu-button:top-1 focus-visible:ring-2 [&>svg]:size-4 flex items-center justify-center outline-hidden transition-transform group-data-[collapsible=icon]:hidden after:absolute after:-inset-2 md:after:hidden [&>svg]:shrink-0\",\r\n        showOnHover &&\r\n          \"peer-data-active/menu-button:text-sidebar-accent-foreground group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-open:opacity-100 md:opacity-0\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenuBadge({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-menu-badge\"\r\n      data-sidebar=\"menu-badge\"\r\n      className={cn(\r\n        \"text-sidebar-foreground peer-hover/menu-button:text-sidebar-accent-foreground peer-data-active/menu-button:text-sidebar-accent-foreground pointer-events-none absolute right-1 flex h-5 min-w-5 rounded-md px-1 text-xs font-medium peer-data-[size=default]/menu-button:top-1.5 peer-data-[size=lg]/menu-button:top-2.5 peer-data-[size=sm]/menu-button:top-1 flex items-center justify-center tabular-nums select-none group-data-[collapsible=icon]:hidden\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenuSkeleton({\r\n  className,\r\n  showIcon = false,\r\n  ...props\r\n}: React.ComponentProps<\"div\"> & {\r\n  showIcon?: boolean\r\n}) {\r\n  // Random width between 50 to 90%.\r\n  const [width] = React.useState(() => {\r\n    return `${Math.floor(Math.random() * 40) + 50}%`\r\n  })\r\n\r\n  return (\r\n    <div\r\n      data-slot=\"sidebar-menu-skeleton\"\r\n      data-sidebar=\"menu-skeleton\"\r\n      className={cn(\"h-8 gap-2 rounded-md px-2 flex items-center\", className)}\r\n      {...props}\r\n    >\r\n      {showIcon && (\r\n        <Skeleton\r\n          className=\"size-4 rounded-md\"\r\n          data-sidebar=\"menu-skeleton-icon\"\r\n        />\r\n      )}\r\n      <Skeleton\r\n        className=\"h-4 max-w-(--skeleton-width) flex-1\"\r\n        data-sidebar=\"menu-skeleton-text\"\r\n        style={\r\n          {\r\n            \"--skeleton-width\": width,\r\n          } as React.CSSProperties\r\n        }\r\n      />\r\n    </div>\r\n  )\r\n}\r\n\r\nfunction SidebarMenuSub({ className, ...props }: React.ComponentProps<\"ul\">) {\r\n  return (\r\n    <ul\r\n      data-slot=\"sidebar-menu-sub\"\r\n      data-sidebar=\"menu-sub\"\r\n      className={cn(\"border-sidebar-border mx-3.5 translate-x-px gap-1 border-l px-2.5 py-0.5 group-data-[collapsible=icon]:hidden flex min-w-0 flex-col\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenuSubItem({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"li\">) {\r\n  return (\r\n    <li\r\n      data-slot=\"sidebar-menu-sub-item\"\r\n      data-sidebar=\"menu-sub-item\"\r\n      className={cn(\"group/menu-sub-item relative\", className)}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nfunction SidebarMenuSubButton({\r\n  asChild = 
false,\r\n  size = \"md\",\r\n  isActive = false,\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"a\"> & {\r\n  asChild?: boolean\r\n  size?: \"sm\" | \"md\"\r\n  isActive?: boolean\r\n}) {\r\n  const Comp = asChild ? Slot.Root : \"a\"\r\n\r\n  return (\r\n    <Comp\r\n      data-slot=\"sidebar-menu-sub-button\"\r\n      data-sidebar=\"menu-sub-button\"\r\n      data-size={size}\r\n      data-active={isActive}\r\n      className={cn(\r\n        \"text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground active:bg-sidebar-accent active:text-sidebar-accent-foreground [&>svg]:text-sidebar-accent-foreground data-active:bg-sidebar-accent data-active:text-sidebar-accent-foreground h-7 gap-2 rounded-md px-2 focus-visible:ring-2 data-[size=md]:text-sm data-[size=sm]:text-xs [&>svg]:size-4 flex min-w-0 -translate-x-px items-center overflow-hidden outline-hidden group-data-[collapsible=icon]:hidden disabled:pointer-events-none disabled:opacity-50 aria-disabled:pointer-events-none aria-disabled:opacity-50 [&>span:last-child]:truncate [&>svg]:shrink-0\",\r\n        className\r\n      )}\r\n      {...props}\r\n    />\r\n  )\r\n}\r\n\r\nexport {\r\n  Sidebar,\r\n  SidebarContent,\r\n  SidebarFooter,\r\n  SidebarGroup,\r\n  SidebarGroupAction,\r\n  SidebarGroupContent,\r\n  SidebarGroupLabel,\r\n  SidebarHeader,\r\n  SidebarInput,\r\n  SidebarInset,\r\n  SidebarMenu,\r\n  SidebarMenuAction,\r\n  SidebarMenuBadge,\r\n  SidebarMenuButton,\r\n  SidebarMenuItem,\r\n  SidebarMenuSkeleton,\r\n  SidebarMenuSub,\r\n  SidebarMenuSubButton,\r\n  SidebarMenuSubItem,\r\n  SidebarProvider,\r\n  SidebarRail,\r\n  SidebarSeparator,\r\n  SidebarTrigger,\r\n  useSidebar,\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/skeleton.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Skeleton({ className, ...props }: React.ComponentProps<\"div\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"skeleton\"\r\n      className={cn(\"bg-muted rounded-xl animate-pulse\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Skeleton };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/slider.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Slider as SliderPrimitive } from \"radix-ui\";\nimport * as React from \"react\";\n\nimport { cn } from \"@/lib/utils\";\n\nconst THUMB_SIZE_PX = 16;\n\nfunction getThumbInBoundsOffset(width: number, percent: number) {\n  const halfWidth = width / 2;\n  const halfPercent = 50;\n\n  if (percent <= 0) return halfWidth;\n  if (percent >= 100) return -halfWidth;\n\n  return halfWidth - (percent / halfPercent) * halfWidth;\n}\n\nfunction Slider({\n  className,\n  defaultValue,\n  value,\n  min = 0,\n  max = 100,\n  orientation = \"horizontal\",\n  onValueChange,\n  ...props\n}: React.ComponentProps<typeof SliderPrimitive.Root>) {\n  const isControlled = Array.isArray(value);\n  const [uncontrolledValues, setUncontrolledValues] =\n    React.useState<number[]>(() =>\n      Array.isArray(defaultValue) ? defaultValue : [min, max],\n    );\n\n  const values = isControlled ? value : uncontrolledValues;\n  const handleValueChange = React.useCallback(\n    (nextValues: number[]) => {\n      if (!isControlled) {\n        setUncontrolledValues(nextValues);\n      }\n      onValueChange?.(nextValues);\n    },\n    [isControlled, onValueChange],\n  );\n  const isSingleThumbHorizontal =\n    values.length === 1 && orientation === \"horizontal\";\n  const fillPercent = isSingleThumbHorizontal\n    ? Math.min(\n        100,\n        Math.max(\n          0,\n          max === min ? 0 : (((values[0] ?? min) - min) / (max - min)) * 100,\n        ),\n      )\n    : null;\n  const fillWidth =\n    fillPercent === null\n      ? undefined\n      : fillPercent <= 0\n        ? \"0%\"\n        : `calc(${fillPercent}% + ${getThumbInBoundsOffset(THUMB_SIZE_PX, fillPercent)}px)`;\n\n  return (\n    <SliderPrimitive.Root\n      data-slot=\"slider\"\n      defaultValue={defaultValue}\n      value={value}\n      min={min}\n      max={max}\n      orientation={orientation}\n      onValueChange={handleValueChange}\n      className={cn(\n        \"data-vertical:min-h-40 relative flex w-full touch-none items-center select-none data-disabled:opacity-50 data-vertical:h-full data-vertical:w-auto data-vertical:flex-col\",\n        className,\n      )}\n      {...props}\n    >\n      <SliderPrimitive.Track\n        data-slot=\"slider-track\"\n        className=\"bg-muted rounded-4xl data-horizontal:h-3 data-horizontal:w-full data-vertical:h-full data-vertical:w-3 bg-muted relative grow overflow-hidden data-horizontal:w-full data-vertical:h-full cursor-pointer\"\n      >\n        <SliderPrimitive.Range\n          data-slot=\"slider-range\"\n          className={cn(\n            \"bg-primary absolute select-none data-horizontal:h-full data-vertical:w-full\",\n            isSingleThumbHorizontal && \"opacity-0\",\n          )}\n        />\n        {isSingleThumbHorizontal && (\n          <div\n            aria-hidden={true}\n            className={cn(\n              \"absolute inset-y-0 left-0 bg-primary pointer-events-none\",\n              fillPercent === 100 ? 
\"rounded-4xl\" : \"rounded-l-4xl\",\n            )}\n            style={{ width: fillWidth }}\n          />\n        )}\n      </SliderPrimitive.Track>\n      {Array.from({ length: values.length }, (_, index) => (\n        <SliderPrimitive.Thumb\n          data-slot=\"slider-thumb\"\n          key={index}\n          className=\"border-primary ring-ring/50 relative z-10 size-4 rounded-4xl border bg-white shadow-sm block shrink-0 select-none cursor-pointer disabled:pointer-events-none disabled:opacity-50 transition-transform duration-100 ease-out hover:scale-110 hover:ring-4 active:scale-95 focus-visible:ring-4 focus-visible:outline-hidden\"\n        />\n      ))}\n    </SliderPrimitive.Root>\n  );\n}\n\nexport { Slider };\n"
  },
  {
    "path": "studio/frontend/src/components/ui/sonner.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\r\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\r\n\r\nimport {\r\n  Alert02Icon,\r\n  CheckmarkCircle02Icon,\r\n  InformationCircleIcon,\r\n  Loading03Icon,\r\n  MultiplicationSignCircleIcon,\r\n} from \"@hugeicons/core-free-icons\";\r\nimport { HugeiconsIcon } from \"@hugeicons/react\";\r\nimport { useTheme } from \"next-themes\";\r\nimport { Toaster as Sonner, type ToasterProps } from \"sonner\";\r\n\r\nconst Toaster = ({ ...props }: ToasterProps) => {\r\n  const { theme = \"system\" } = useTheme();\r\n\r\n  return (\r\n    <Sonner\r\n      theme={theme as ToasterProps[\"theme\"]}\r\n      className=\"toaster group\"\r\n      duration={5000}\r\n      icons={{\r\n        success: (\r\n          <HugeiconsIcon\r\n            icon={CheckmarkCircle02Icon}\r\n            strokeWidth={2}\r\n            className=\"size-4\"\r\n          />\r\n        ),\r\n        info: (\r\n          <HugeiconsIcon\r\n            icon={InformationCircleIcon}\r\n            strokeWidth={2}\r\n            className=\"size-4\"\r\n          />\r\n        ),\r\n        warning: (\r\n          <HugeiconsIcon\r\n            icon={Alert02Icon}\r\n            strokeWidth={2}\r\n            className=\"size-4\"\r\n          />\r\n        ),\r\n        error: (\r\n          <HugeiconsIcon\r\n            icon={MultiplicationSignCircleIcon}\r\n            strokeWidth={2}\r\n            className=\"size-4\"\r\n          />\r\n        ),\r\n        loading: (\r\n          <HugeiconsIcon\r\n            icon={Loading03Icon}\r\n            strokeWidth={2}\r\n            className=\"size-4 animate-spin\"\r\n          />\r\n        ),\r\n      }}\r\n      style={\r\n        {\r\n          \"--normal-bg\": \"var(--popover)\",\r\n          \"--normal-text\": \"var(--popover-foreground)\",\r\n          \"--normal-border\": \"var(--border)\",\r\n          \"--border-radius\": \"var(--radius)\",\r\n        } as React.CSSProperties\r\n      }\r\n      toastOptions={{\r\n        classNames: {\r\n          toast: \"cn-toast\",\r\n          description: \"!text-muted-foreground\",\r\n          closeButton: \"!top-3 !right-3 !translate-y-0\",\r\n        },\r\n      }}\r\n      {...props}\r\n    />\r\n  );\r\n};\r\n\r\nexport { Toaster };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/sparkles-text.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { motion } from \"motion/react\";\r\nimport {\r\n  type CSSProperties,\r\n  type ReactElement,\r\n  type ReactNode,\r\n  useEffect,\r\n  useState,\r\n} from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\ntype SparkleItem = {\r\n  id: string;\r\n  x: string;\r\n  y: string;\r\n  color: string;\r\n  delay: number;\r\n  scale: number;\r\n  lifespan: number;\r\n};\r\n\r\nfunction Sparkle({ id, x, y, color, delay, scale }: SparkleItem): ReactElement {\r\n  return (\r\n    <motion.svg\r\n      key={id}\r\n      className=\"pointer-events-none absolute z-20\"\r\n      initial={{ opacity: 0, left: x, top: y }}\r\n      animate={{\r\n        opacity: [0, 1, 0],\r\n        scale: [0, scale, 0],\r\n        rotate: [75, 120, 150],\r\n      }}\r\n      transition={{ duration: 0.8, repeat: Number.POSITIVE_INFINITY, delay }}\r\n      width=\"21\"\r\n      height=\"21\"\r\n      viewBox=\"0 0 21 21\"\r\n    >\r\n      <title>Sparkle</title>\r\n      <path\r\n        d=\"M9.82531 0.843845C10.0553 0.215178 10.9446 0.215178 11.1746 0.843845L11.8618 2.72026C12.4006 4.19229 12.3916 6.39157 13.5 7.5C14.6084 8.60843 16.8077 8.59935 18.2797 9.13822L20.1561 9.82534C20.7858 10.0553 20.7858 10.9447 20.1561 11.1747L18.2797 11.8618C16.8077 12.4007 14.6084 12.3916 13.5 13.5C12.3916 14.6084 12.4006 16.8077 11.8618 18.2798L11.1746 20.1562C10.9446 20.7858 10.0553 20.7858 9.82531 20.1562L9.13819 18.2798C8.59932 16.8077 8.60843 14.6084 7.5 13.5C6.39157 12.3916 4.19225 12.4007 2.72023 11.8618L0.843814 11.1747C0.215148 10.9447 0.215148 10.0553 0.843814 9.82534L2.72023 9.13822C4.19225 8.59935 6.39157 8.60843 7.5 7.5C8.60843 6.39157 8.59932 4.19229 9.13819 2.72026L9.82531 0.843845Z\"\r\n        fill={color}\r\n      />\r\n    </motion.svg>\r\n  );\r\n}\r\n\r\ninterface SparklesTextProps {\r\n  /**\r\n   * @default <div />\r\n   * @type ReactElement\r\n   * @description\r\n   * The component to be rendered as the text\r\n   * */\r\n  as?: ReactElement;\r\n\r\n  /**\r\n   * @default \"\"\r\n   * @type string\r\n   * @description\r\n   * The className of the text\r\n   */\r\n  className?: string;\r\n\r\n  /**\r\n   * @required\r\n   * @type ReactNode\r\n   * @description\r\n   * The content to be displayed\r\n   * */\r\n  children: ReactNode;\r\n\r\n  /**\r\n   * @default 10\r\n   * @type number\r\n   * @description\r\n   * The count of sparkles\r\n   * */\r\n  sparklesCount?: number;\r\n\r\n  /**\r\n   * @default \"{first: '#9E7AFF', second: '#FE8BBB'}\"\r\n   * @type string\r\n   * @description\r\n   * The colors of the sparkles\r\n   * */\r\n  colors?: {\r\n    first: string;\r\n    second: string;\r\n  };\r\n}\r\n\r\nexport function SparklesText({\r\n  children,\r\n  colors = { first: \"#9E7AFF\", second: \"#FE8BBB\" },\r\n  className,\r\n  sparklesCount = 10,\r\n  ...props\r\n}: SparklesTextProps): ReactElement {\r\n  const [sparkles, setSparkles] = useState<SparkleItem[]>([]);\r\n\r\n  useEffect(() => {\r\n    const generateStar = (): SparkleItem => {\r\n      const starX = `${Math.random() * 100}%`;\r\n      const starY = `${Math.random() * 100}%`;\r\n      const color = Math.random() > 0.5 ? 
colors.first : colors.second;\r\n      const delay = Math.random() * 2;\r\n      const scale = Math.random() * 1 + 0.3;\r\n      const lifespan = Math.random() * 10 + 5;\r\n      const id = `${starX}-${starY}-${Date.now()}`;\r\n      return { id, x: starX, y: starY, color, delay, scale, lifespan };\r\n    };\r\n\r\n    const initializeStars = () => {\r\n      const newSparkles = Array.from({ length: sparklesCount }, generateStar);\r\n      setSparkles(newSparkles);\r\n    };\r\n\r\n    const updateStars = () => {\r\n      setSparkles((currentSparkles) =>\r\n        currentSparkles.map((star) => {\r\n          if (star.lifespan <= 0) {\r\n            return generateStar();\r\n          }\r\n          return { ...star, lifespan: star.lifespan - 0.1 };\r\n        }),\r\n      );\r\n    };\r\n\r\n    initializeStars();\r\n    const interval = setInterval(updateStars, 100);\r\n\r\n    return () => clearInterval(interval);\r\n  }, [colors.first, colors.second, sparklesCount]);\r\n\r\n  return (\r\n    <div\r\n      className={cn(\"text-6xl font-bold\", className)}\r\n      {...props}\r\n      style={\r\n        {\r\n          \"--sparkles-first-color\": `${colors.first}`,\r\n          \"--sparkles-second-color\": `${colors.second}`,\r\n        } as CSSProperties\r\n      }\r\n    >\r\n      <span className=\"relative inline-block\">\r\n        {sparkles.map((sparkle) => (\r\n          <Sparkle key={sparkle.id} {...sparkle} />\r\n        ))}\r\n        <strong>{children}</strong>\r\n      </span>\r\n    </div>\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/spinner.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\"\nimport { HugeiconsIcon } from \"@hugeicons/react\"\nimport { Loading03Icon } from \"@hugeicons/core-free-icons\"\n\nfunction Spinner({ className }: { className?: string }) {\n  return (\n    <HugeiconsIcon icon={Loading03Icon} strokeWidth={2} role=\"status\" aria-label=\"Loading\" className={cn(\"size-4 animate-spin\", className)} />\n  )\n}\n\nexport { Spinner }\n"
  },
  {
    "path": "studio/frontend/src/components/ui/switch.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Switch as SwitchPrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Switch({\r\n  className,\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof SwitchPrimitive.Root> & {\r\n  size?: \"sm\" | \"default\";\r\n}) {\r\n  return (\r\n    <SwitchPrimitive.Root\r\n      data-slot=\"switch\"\r\n      data-size={size}\r\n      className={cn(\r\n        \"data-checked:bg-primary data-unchecked:bg-input focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 dark:data-unchecked:bg-input/80 shrink-0 rounded-full border border-transparent focus-visible:ring-[3px] aria-invalid:ring-[3px] data-[size=default]:h-[18.4px] data-[size=default]:w-[32px] data-[size=sm]:h-[14px] data-[size=sm]:w-[24px] peer group/switch relative inline-flex items-center transition-all outline-none after:absolute after:-inset-x-3 after:-inset-y-2 data-disabled:cursor-not-allowed data-disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <SwitchPrimitive.Thumb\r\n        data-slot=\"switch-thumb\"\r\n        className=\"bg-background dark:data-unchecked:bg-foreground dark:data-checked:bg-primary-foreground rounded-full group-data-[size=default]/switch:size-4 group-data-[size=sm]/switch:size-3 group-data-[size=default]/switch:data-checked:translate-x-[calc(100%-2px)] group-data-[size=sm]/switch:data-checked:translate-x-[calc(100%-2px)] group-data-[size=default]/switch:data-unchecked:translate-x-0 group-data-[size=sm]/switch:data-unchecked:translate-x-0 pointer-events-none block ring-0 transition-transform\"\r\n      />\r\n    </SwitchPrimitive.Root>\r\n  );\r\n}\r\n\r\nexport { Switch };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/table.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Table({ className, ...props }: React.ComponentProps<\"table\">) {\r\n  return (\r\n    <div\r\n      data-slot=\"table-container\"\r\n      className=\"relative w-full overflow-x-auto\"\r\n    >\r\n      <table\r\n        data-slot=\"table\"\r\n        className={cn(\"w-full caption-bottom text-sm\", className)}\r\n        {...props}\r\n      />\r\n    </div>\r\n  );\r\n}\r\n\r\nfunction TableHeader({ className, ...props }: React.ComponentProps<\"thead\">) {\r\n  return (\r\n    <thead\r\n      data-slot=\"table-header\"\r\n      className={cn(\"[&_tr]:border-b\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableBody({ className, ...props }: React.ComponentProps<\"tbody\">) {\r\n  return (\r\n    <tbody\r\n      data-slot=\"table-body\"\r\n      className={cn(\"[&_tr:last-child]:border-0\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableFooter({ className, ...props }: React.ComponentProps<\"tfoot\">) {\r\n  return (\r\n    <tfoot\r\n      data-slot=\"table-footer\"\r\n      className={cn(\r\n        \"bg-muted/50 border-t font-medium [&>tr]:last:border-b-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableRow({ className, ...props }: React.ComponentProps<\"tr\">) {\r\n  return (\r\n    <tr\r\n      data-slot=\"table-row\"\r\n      className={cn(\r\n        \"hover:bg-muted/50 data-[state=selected]:bg-muted border-b transition-colors\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableHead({ className, ...props }: React.ComponentProps<\"th\">) {\r\n  return (\r\n    <th\r\n      data-slot=\"table-head\"\r\n      className={cn(\r\n        \"text-foreground h-12 px-3 text-left align-middle font-medium whitespace-nowrap [&:has([role=checkbox])]:pr-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableCell({ className, ...props }: React.ComponentProps<\"td\">) {\r\n  return (\r\n    <td\r\n      data-slot=\"table-cell\"\r\n      className={cn(\r\n        \"p-3 align-middle whitespace-nowrap [&:has([role=checkbox])]:pr-0\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nfunction TableCaption({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<\"caption\">) {\r\n  return (\r\n    <caption\r\n      data-slot=\"table-caption\"\r\n      className={cn(\"text-muted-foreground mt-4 text-sm\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport {\r\n  Table,\r\n  TableHeader,\r\n  TableBody,\r\n  TableFooter,\r\n  TableHead,\r\n  TableRow,\r\n  TableCell,\r\n  TableCaption,\r\n};\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/tabs.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport { motion } from \"motion/react\";\r\nimport { Tabs as TabsPrimitive } from \"radix-ui\";\r\nimport * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nconst TabsContext = React.createContext<{ value?: string; id: string }>({\r\n  id: \"\",\r\n});\r\n\r\nexport function Tabs({\r\n  className,\r\n  orientation = \"horizontal\",\r\n  value,\r\n  defaultValue,\r\n  onValueChange,\r\n  ...props\r\n}: React.ComponentProps<typeof TabsPrimitive.Root>): React.ReactElement {\r\n  const [internal, setInternal] = React.useState(defaultValue ?? \"\");\r\n  const current = value ?? internal;\r\n  const id = React.useId();\r\n\r\n  return (\r\n    <TabsContext.Provider value={{ value: current, id }}>\r\n      <TabsPrimitive.Root\r\n        data-slot=\"tabs\"\r\n        data-orientation={orientation}\r\n        value={current}\r\n        onValueChange={(v) => {\r\n          setInternal(v);\r\n          onValueChange?.(v);\r\n        }}\r\n        className={cn(\r\n          \"gap-2 group/tabs flex data-[orientation=horizontal]:flex-col\",\r\n          className,\r\n        )}\r\n        {...props}\r\n      />\r\n    </TabsContext.Provider>\r\n  );\r\n}\r\n\r\nexport const tabsListVariants = cva(\r\n  \"rounded-4xl p-[3px]  group-data-horizontal/tabs:h-9 group-data-vertical/tabs:rounded-2xl data-[variant=line]:rounded-none group/tabs-list text-muted-foreground inline-flex w-fit items-center justify-center group-data-[orientation=vertical]/tabs:h-fit group-data-[orientation=vertical]/tabs:flex-col\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"bg-muted\",\r\n        line: \"gap-1 bg-transparent\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n    },\r\n  },\r\n);\r\n\r\nexport function TabsList({\r\n  className,\r\n  variant = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof TabsPrimitive.List> &\r\n  VariantProps<typeof tabsListVariants>): React.ReactElement {\r\n  return (\r\n    <TabsPrimitive.List\r\n      data-slot=\"tabs-list\"\r\n      data-variant={variant}\r\n      className={cn(tabsListVariants({ variant }), className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport function TabsTrigger({\r\n  className,\r\n  value,\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof TabsPrimitive.Trigger>): React.ReactElement {\r\n  const ctx = React.useContext(TabsContext);\r\n  const isActive = ctx.value === value;\r\n\r\n  return (\r\n    <TabsPrimitive.Trigger\r\n      data-slot=\"tabs-trigger\"\r\n      value={value}\r\n      className={cn(\r\n        \"gap-1.5 rounded-xl corner-squircle border border-transparent px-2 py-1 text-sm font-medium group-data-vertical/tabs:px-2.5 group-data-vertical/tabs:py-1.5 [&_svg:not([class*='size-'])]:size-4 focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:outline-ring text-foreground/60 hover:text-foreground dark:text-muted-foreground dark:hover:text-foreground relative inline-flex h-[calc(100%-1px)] flex-1 items-center justify-center whitespace-nowrap transition-colors group-data-[orientation=vertical]/tabs:w-full group-data-[orientation=vertical]/tabs:justify-start focus-visible:ring-[3px] focus-visible:outline-1 
disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n        \"group-data-[variant=line]/tabs-list:bg-transparent group-data-[variant=line]/tabs-list:data-active:bg-transparent dark:group-data-[variant=line]/tabs-list:data-active:border-transparent dark:group-data-[variant=line]/tabs-list:data-active:bg-transparent\",\r\n        \"data-active:text-foreground dark:data-active:text-foreground\",\r\n        \"after:bg-foreground after:absolute after:opacity-0 after:transition-opacity group-data-[orientation=horizontal]/tabs:after:inset-x-0 group-data-[orientation=horizontal]/tabs:after:bottom-[-5px] group-data-[orientation=horizontal]/tabs:after:h-0.5 group-data-[orientation=vertical]/tabs:after:inset-y-0 group-data-[orientation=vertical]/tabs:after:-right-1 group-data-[orientation=vertical]/tabs:after:w-0.5 group-data-[variant=line]/tabs-list:data-active:after:opacity-100\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {isActive && (\r\n        <motion.span\r\n          layoutId={`tab-bg-${ctx.id}`}\r\n          className=\"absolute inset-0 rounded-xl bg-background dark:bg-input/30 dark:border dark:border-input\"\r\n          transition={{\r\n            type: \"spring\",\r\n            stiffness: 500,\r\n            damping: 35,\r\n            mass: 0.5,\r\n          }}\r\n        />\r\n      )}\r\n      <span className=\"relative z-10\">{children}</span>\r\n    </TabsPrimitive.Trigger>\r\n  );\r\n}\r\n\r\nexport function TabsContent({\r\n  className,\r\n  ...props\r\n}: React.ComponentProps<typeof TabsPrimitive.Content>): React.ReactElement {\r\n  return (\r\n    <TabsPrimitive.Content\r\n      data-slot=\"tabs-content\"\r\n      className={cn(\"text-sm flex-1 outline-none\", className)}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/terminal.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\"\nimport {\n  Children,\n  cloneElement,\n  isValidElement,\n  useEffect,\n  useRef,\n  useState,\n} from \"react\"\nimport type { ElementType, ReactElement, ReactNode } from \"react\"\n\ntype TerminalProps = {\n  children: ReactNode\n  className?: string\n  sequence?: boolean\n  startOnView?: boolean\n}\n\ntype InternalLineProps = {\n  __isActive?: boolean\n  __onDone?: () => void\n  __sequence?: boolean\n}\n\nfunction useStartOnView(enabled: boolean): {\n  ref: React.RefObject<HTMLDivElement | null>\n  started: boolean\n} {\n  const ref = useRef<HTMLDivElement | null>(null)\n  const [isInView, setIsInView] = useState(false)\n  const started = !enabled || isInView\n\n  useEffect(() => {\n    if (!enabled) {\n      return\n    }\n\n    const node = ref.current\n    if (!node) {\n      return\n    }\n\n    const observer = new IntersectionObserver(\n      ([entry]) => {\n        if (entry?.isIntersecting) {\n          setIsInView(true)\n          observer.disconnect()\n        }\n      },\n      { threshold: 0.2 }\n    )\n\n    observer.observe(node)\n    return () => observer.disconnect()\n  }, [enabled])\n\n  return { ref, started }\n}\n\nexport function Terminal({\n  children,\n  className,\n  sequence = true,\n  startOnView = true,\n}: TerminalProps): ReactElement {\n  const { ref, started } = useStartOnView(startOnView)\n  const childElements = Children.toArray(children).filter(isValidElement)\n  const [activeIndex, setActiveIndex] = useState(0)\n  const visibleIndex = sequence\n    ? started\n      ? activeIndex\n      : -1\n    : Number.MAX_SAFE_INTEGER\n\n  function handleLineDone(index: number): void {\n    if (!sequence) {\n      return\n    }\n\n    setActiveIndex((prev) => {\n      if (prev !== index) {\n        return prev\n      }\n      return Math.min(index + 1, childElements.length)\n    })\n  }\n\n  return (\n    <div\n      ref={ref}\n      className={cn(\n        \"w-full rounded-2xl border border-border bg-card px-6 py-5 font-mono text-sm text-foreground shadow-2xl\",\n        className\n      )}\n    >\n      {childElements.map((child, index) =>\n        cloneElement(child, {\n          __sequence: sequence,\n          __isActive: !sequence || visibleIndex >= index,\n          __onDone: () => handleLineDone(index),\n          key: child.key ?? index,\n        } as InternalLineProps)\n      )}\n    </div>\n  )\n}\n\ntype AnimatedSpanProps = InternalLineProps & {\n  children: ReactNode\n  className?: string\n  delay?: number\n  startOnView?: boolean\n}\n\nexport function AnimatedSpan({\n  children,\n  className,\n  delay = 0,\n  startOnView = false,\n  __isActive,\n  __sequence,\n  __onDone,\n}: AnimatedSpanProps): ReactElement {\n  const { ref, started } = useStartOnView(startOnView)\n  const [visible, setVisible] = useState(false)\n  const doneRef = useRef(false)\n  const onDoneRef = useRef(__onDone)\n  const shouldStart = __sequence ? 
__isActive : started\n\n  useEffect(() => {\n    onDoneRef.current = __onDone\n  }, [__onDone])\n\n  useEffect(() => {\n    if (!shouldStart || doneRef.current) {\n      return\n    }\n\n    const timeout = window.setTimeout(() => {\n      setVisible(true)\n      doneRef.current = true\n      onDoneRef.current?.()\n    }, delay)\n\n    return () => window.clearTimeout(timeout)\n  }, [delay, shouldStart])\n\n  return (\n    <div\n      ref={ref}\n      className={cn(\n        \"min-h-5 transition-opacity duration-300\",\n        visible ? \"opacity-100\" : \"opacity-0\",\n        className\n      )}\n    >\n      {children}\n    </div>\n  )\n}\n\ntype TypingAnimationProps = InternalLineProps & {\n  children: string\n  className?: string\n  duration?: number\n  delay?: number\n  as?: ElementType\n  startOnView?: boolean\n}\n\nexport function TypingAnimation({\n  children,\n  className,\n  duration = 60,\n  delay = 0,\n  as: Component = \"span\",\n  startOnView = true,\n  __isActive,\n  __sequence,\n  __onDone,\n}: TypingAnimationProps): ReactElement {\n  const { ref, started } = useStartOnView(startOnView)\n  const [typed, setTyped] = useState(\"\")\n  const doneRef = useRef(false)\n  const onDoneRef = useRef(__onDone)\n  const shouldStart = __sequence ? __isActive : started\n\n  useEffect(() => {\n    onDoneRef.current = __onDone\n  }, [__onDone])\n\n  useEffect(() => {\n    if (!shouldStart || doneRef.current) {\n      return\n    }\n\n    let index = 0\n    let intervalId: number | null = null\n    const startTimer = window.setTimeout(() => {\n      intervalId = window.setInterval(() => {\n        index += 1\n        setTyped(children.slice(0, index))\n\n        if (index >= children.length) {\n          if (intervalId) {\n            window.clearInterval(intervalId)\n          }\n          doneRef.current = true\n          onDoneRef.current?.()\n        }\n      }, duration)\n    }, delay)\n\n    return () => {\n      window.clearTimeout(startTimer)\n      if (intervalId) {\n        window.clearInterval(intervalId)\n      }\n    }\n  }, [children, delay, duration, shouldStart])\n\n  return (\n    <div ref={ref} className=\"min-h-5\">\n      <Component className={cn(\"whitespace-pre-wrap\", className)}>{typed}</Component>\n    </div>\n  )\n}\n"
  },
  {
    "path": "studio/frontend/src/components/ui/textarea.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nfunction Textarea({ className, ...props }: React.ComponentProps<\"textarea\">) {\r\n  return (\r\n    <textarea\r\n      data-slot=\"textarea\"\r\n      className={cn(\r\n        \"border-input bg-input/30 focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:aria-invalid:border-destructive/50 resize-none rounded-xl border px-3 py-3 text-base transition-colors focus-visible:ring-[3px] aria-invalid:ring-[3px] md:text-sm placeholder:text-muted-foreground flex field-sizing-content min-h-16 w-full outline-none disabled:cursor-not-allowed disabled:opacity-50\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n\r\nexport { Textarea };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/toggle-group.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\nimport type { VariantProps } from \"class-variance-authority\";\r\nimport { ToggleGroup as ToggleGroupPrimitive } from \"radix-ui\";\r\nimport * as React from \"react\";\r\n\r\nimport { toggleVariants } from \"@/components/ui/toggle\";\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nconst ToggleGroupContext = React.createContext<\r\n  VariantProps<typeof toggleVariants> & {\r\n    spacing?: number;\r\n    orientation?: \"horizontal\" | \"vertical\";\r\n  }\r\n>({\r\n  size: \"default\",\r\n  variant: \"default\",\r\n  spacing: 0,\r\n  orientation: \"horizontal\",\r\n});\r\n\r\nfunction ToggleGroup({\r\n  className,\r\n  variant,\r\n  size,\r\n  spacing = 0,\r\n  orientation = \"horizontal\",\r\n  children,\r\n  ...props\r\n}: React.ComponentProps<typeof ToggleGroupPrimitive.Root> &\r\n  VariantProps<typeof toggleVariants> & {\r\n    spacing?: number;\r\n    orientation?: \"horizontal\" | \"vertical\";\r\n  }) {\r\n  return (\r\n    <ToggleGroupPrimitive.Root\r\n      data-slot=\"toggle-group\"\r\n      data-variant={variant}\r\n      data-size={size}\r\n      data-spacing={spacing}\r\n      data-orientation={orientation}\r\n      style={{ \"--gap\": spacing } as React.CSSProperties}\r\n      className={cn(\r\n        \"data-[spacing=0]:data-[variant=outline]:rounded-4xl group/toggle-group flex w-fit flex-row items-center gap-[--spacing(var(--gap))] data-[orientation=vertical]:flex-col data-[orientation=vertical]:items-stretch\",\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      <ToggleGroupContext.Provider\r\n        value={{ variant, size, spacing, orientation }}\r\n      >\r\n        {children}\r\n      </ToggleGroupContext.Provider>\r\n    </ToggleGroupPrimitive.Root>\r\n  );\r\n}\r\n\r\nfunction ToggleGroupItem({\r\n  className,\r\n  children,\r\n  variant = \"default\",\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof ToggleGroupPrimitive.Item> &\r\n  VariantProps<typeof toggleVariants>) {\r\n  const context = React.useContext(ToggleGroupContext);\r\n\r\n  return (\r\n    <ToggleGroupPrimitive.Item\r\n      data-slot=\"toggle-group-item\"\r\n      data-variant={context.variant || variant}\r\n      data-size={context.size || size}\r\n      data-spacing={context.spacing}\r\n      className={cn(\r\n        \"data-[state=on]:bg-muted group-data-[spacing=0]/toggle-group:rounded-none group-data-[spacing=0]/toggle-group:px-3 group-data-[spacing=0]/toggle-group:shadow-none group-data-horizontal/toggle-group:data-[spacing=0]:first:rounded-l-4xl group-data-vertical/toggle-group:data-[spacing=0]:first:rounded-t-xl group-data-horizontal/toggle-group:data-[spacing=0]:last:rounded-r-4xl group-data-vertical/toggle-group:data-[spacing=0]:last:rounded-b-xl shrink-0 focus:z-10 focus-visible:z-10 group-data-horizontal/toggle-group:data-[spacing=0]:data-[variant=outline]:border-l-0 group-data-vertical/toggle-group:data-[spacing=0]:data-[variant=outline]:border-t-0 group-data-horizontal/toggle-group:data-[spacing=0]:data-[variant=outline]:first:border-l group-data-vertical/toggle-group:data-[spacing=0]:data-[variant=outline]:first:border-t\",\r\n        toggleVariants({\r\n          variant: context.variant || variant,\r\n          size: context.size || size,\r\n        }),\r\n        className,\r\n      )}\r\n      {...props}\r\n    >\r\n      {children}\r\n    
</ToggleGroupPrimitive.Item>\r\n  );\r\n}\r\n\r\nexport { ToggleGroup, ToggleGroupItem };\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/toggle.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"use client\";\r\n\r\n/* eslint-disable react-refresh/only-export-components */\r\n\r\nimport { type VariantProps, cva } from \"class-variance-authority\";\r\nimport { Toggle as TogglePrimitive } from \"radix-ui\";\r\nimport type * as React from \"react\";\r\n\r\nimport { cn } from \"@/lib/utils\";\r\n\r\nexport const toggleVariants = cva(\r\n  \"hover:text-foreground aria-pressed:bg-muted focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive gap-1 rounded-4xl text-sm font-medium transition-colors [&_svg:not([class*='size-'])]:size-4 group/toggle hover:bg-muted inline-flex items-center justify-center whitespace-nowrap outline-none focus-visible:ring-[3px] disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0\",\r\n  {\r\n    variants: {\r\n      variant: {\r\n        default: \"bg-transparent\",\r\n        outline: \"border-input hover:bg-muted border bg-transparent\",\r\n      },\r\n      size: {\r\n        default: \"h-9 min-w-9 rounded-[min(var(--radius-2xl),12px)] px-2.5\",\r\n        sm: \"h-8 min-w-8 px-3\",\r\n        lg: \"h-10 min-w-10 px-2.5\",\r\n      },\r\n    },\r\n    defaultVariants: {\r\n      variant: \"default\",\r\n      size: \"default\",\r\n    },\r\n  },\r\n);\r\n\r\nexport function Toggle({\r\n  className,\r\n  variant = \"default\",\r\n  size = \"default\",\r\n  ...props\r\n}: React.ComponentProps<typeof TogglePrimitive.Root> &\r\n  VariantProps<typeof toggleVariants>): React.ReactElement {\r\n  return (\r\n    <TogglePrimitive.Root\r\n      data-slot=\"toggle\"\r\n      className={cn(toggleVariants({ variant, size, className }))}\r\n      {...props}\r\n    />\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/components/ui/tooltip.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Tooltip as TooltipPrimitive } from \"radix-ui\";\nimport { createContext, useCallback, useContext, useState } from \"react\";\nimport type * as React from \"react\";\n\nimport { cn } from \"@/lib/utils\";\n\ntype ToggleFn = () => void;\nconst TooltipToggleCtx = createContext<ToggleFn | null>(null);\n\nfunction TooltipProvider({\n  delayDuration = 400,\n  ...props\n}: React.ComponentProps<typeof TooltipPrimitive.Provider>) {\n  return (\n    <TooltipPrimitive.Provider\n      data-slot=\"tooltip-provider\"\n      delayDuration={delayDuration}\n      {...props}\n    />\n  );\n}\n\nfunction Tooltip({\n  open: controlledOpen,\n  onOpenChange: controlledOnOpenChange,\n  ...props\n}: React.ComponentProps<typeof TooltipPrimitive.Root>) {\n  const isControlled = controlledOpen !== undefined;\n  const [clickOpen, setClickOpen] = useState(false);\n\n  const onOpenChange = useCallback(\n    (nextOpen: boolean) => {\n      if (!nextOpen) setClickOpen(false);\n      controlledOnOpenChange?.(nextOpen);\n    },\n    [controlledOnOpenChange],\n  );\n\n  const toggle = useCallback(() => {\n    setClickOpen((prev) => !prev);\n  }, []);\n\n  return (\n    <TooltipProvider>\n      <TooltipToggleCtx.Provider value={toggle}>\n        <TooltipPrimitive.Root\n          data-slot=\"tooltip\"\n          open={isControlled ? controlledOpen : clickOpen || undefined}\n          onOpenChange={onOpenChange}\n          {...props}\n        />\n      </TooltipToggleCtx.Provider>\n    </TooltipProvider>\n  );\n}\n\nfunction TooltipTrigger({\n  onClick,\n  ...props\n}: React.ComponentProps<typeof TooltipPrimitive.Trigger>) {\n  const toggle = useContext(TooltipToggleCtx);\n\n  const handleClick = useCallback(\n    (e: React.MouseEvent<HTMLButtonElement>) => {\n      e.preventDefault();\n      toggle?.();\n      onClick?.(e);\n    },\n    [toggle, onClick],\n  );\n\n  return (\n    <TooltipPrimitive.Trigger\n      data-slot=\"tooltip-trigger\"\n      onClick={handleClick}\n      {...props}\n    />\n  );\n}\n\nfunction TooltipContent({\n  className,\n  sideOffset = 0,\n  children,\n  ...props\n}: React.ComponentProps<typeof TooltipPrimitive.Content>) {\n  return (\n    <TooltipPrimitive.Portal>\n      <TooltipPrimitive.Content\n        data-slot=\"tooltip-content\"\n        sideOffset={sideOffset}\n        className={cn(\n          \"data-open:animate-in data-open:fade-in-0 data-open:zoom-in-95 data-[state=delayed-open]:animate-in data-[state=delayed-open]:fade-in-0 data-[state=delayed-open]:zoom-in-95 data-closed:animate-out data-closed:fade-out-0 data-closed:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 rounded-2xl corner-squircle px-3 py-1.5 text-xs **:data-[slot=kbd]:rounded-4xl bg-foreground text-background border border-foreground/40 shadow-lg z-[999999] w-fit max-w-xs origin-(--radix-tooltip-content-transform-origin)\",\n          className,\n        )}\n        {...props}\n      >\n        {children}\n        <TooltipPrimitive.Arrow className=\"size-2.5 translate-y-[calc(-50%_-_2px)] rotate-45 rounded-[2px] data-[side=left]:translate-x-[-1.5px] data-[side=right]:translate-x-[1.5px] bg-foreground fill-foreground z-[999999] translate-y-[calc(-50%_-_2px)]\" />\n      </TooltipPrimitive.Content>\n    </TooltipPrimitive.Portal>\n  
);\n}\n\nexport { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger };\n"
  },
  {
    "path": "studio/frontend/src/config/env.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\n\nexport const env = {\n  MODE: import.meta.env.MODE,\n  DEV: import.meta.env.DEV,\n  PROD: import.meta.env.PROD,\n  BASE_URL: import.meta.env.BASE_URL,\n} as const;\n\n// ── Platform / device type ──────────────────────────────────\n\nexport type DeviceType = \"mac\" | \"windows\" | \"linux\" | string;\n\ninterface PlatformState {\n  deviceType: DeviceType;\n  chatOnly: boolean;\n  fetched: boolean;\n  isChatOnly: () => boolean;\n}\n\nexport const usePlatformStore = create<PlatformState>()((_, get) => ({\n  deviceType: \"linux\",\n  chatOnly: false,\n  fetched: false,\n  isChatOnly: () => get().chatOnly,\n}));\n\nexport async function fetchDeviceType(): Promise<DeviceType> {\n  const { fetched } = usePlatformStore.getState();\n  if (fetched) return usePlatformStore.getState().deviceType;\n\n  try {\n    const res = await fetch(\"/api/health\");\n    if (res.ok) {\n      const data = (await res.json()) as { device_type?: string; chat_only?: boolean };\n      const deviceType = data.device_type ?? \"linux\";\n      const chatOnly = data.chat_only ?? deviceType === \"mac\";\n      usePlatformStore.setState({ deviceType, chatOnly, fetched: true });\n      return deviceType;\n    }\n  } catch (err) {\n    console.warn(\"[platform] Failed to fetch device type, will retry\", err);\n  }\n\n  return usePlatformStore.getState().deviceType;\n}\n"
  },
  {
    "path": "studio/frontend/src/config/training.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ModelType, StepConfig } from \"@/types/training\";\nimport type { PipelineType } from \"@huggingface/hub\";\n\nexport const STEPS: StepConfig[] = [\n  {\n    number: 1,\n    title: \"Model Type\",\n    subtitle: \"Select type\",\n    description: \"Choose the type of model you want to fine-tune\",\n  },\n  {\n    number: 2,\n    title: \"Model\",\n    subtitle: \"Select model\",\n    description: \"Choose a base model and training method\",\n  },\n  {\n    number: 3,\n    title: \"Dataset\",\n    subtitle: \"Add dataset\",\n    description: \"Select or upload a training dataset\",\n  },\n  {\n    number: 4,\n    title: \"Parameters\",\n    subtitle: \"Configure\",\n    description: \"Fine-tune your training hyperparameters\",\n  },\n  {\n    number: 5,\n    title: \"Summary\",\n    subtitle: \"Review\",\n    description: \"Review your configuration before starting\",\n  },\n];\n\nexport const MODEL_TYPES: ReadonlyArray<{\n  value: ModelType;\n  label: string;\n  description: string;\n}> = [\n  {\n    value: \"text\",\n    label: \"Text\",\n    description: \"Language models\",\n  },\n    {\n      value: \"vision\",\n      label: \"Vision\",\n      description: \"Image understanding models\",\n    },\n    {\n      value: \"audio\",\n      label: \"Audio\",\n      description: \"Audio and speech models\",\n    },\n    {\n      value: \"embeddings\",\n      label: \"Embeddings\",\n      description: \"Text embedding models\",\n    },\n  ];\n\nexport const CONTEXT_LENGTHS = [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144];\n\nexport const TARGET_MODULES = [\n  \"q_proj\",\n  \"k_proj\",\n  \"v_proj\",\n  \"o_proj\",\n  \"gate_proj\",\n  \"up_proj\",\n  \"down_proj\",\n];\n\nexport const OPTIMIZER_OPTIONS: ReadonlyArray<{ value: string; label: string }> = [\n  { value: \"adamw_8bit\", label: \"AdamW 8-bit\" },\n  { value: \"paged_adamw_8bit\", label: \"Paged AdamW 8-bit\" },\n  { value: \"adamw_bnb_8bit\", label: \"AdamW BNB 8-bit\" },\n  { value: \"paged_adamw_32bit\", label: \"Paged AdamW 32-bit\" },\n  { value: \"adamw_torch\", label: \"AdamW (PyTorch)\" },\n  { value: \"adamw_torch_fused\", label: \"AdamW (PyTorch Fused)\" },\n];\n\nexport const LR_SCHEDULER_OPTIONS: ReadonlyArray<{ value: string; label: string }> = [\n  { value: \"linear\", label: \"Linear\" },\n  { value: \"cosine\", label: \"Cosine\" },\n];\n\nexport const DEFAULT_HYPERPARAMS = {\n  epochs: 3,\n  contextLength: 2048,\n  learningRate: 2e-4,\n  optimizerType: \"adamw_8bit\",\n  lrSchedulerType: \"linear\",\n  loraRank: 16,\n  loraAlpha: 32,\n  loraDropout: 0.05,\n  loraVariant: \"lora\" as const,\n  batchSize: 4,\n  gradientAccumulation: 8,\n  weightDecay: 0.01,\n  warmupSteps: 5,\n  maxSteps: 60,\n  saveSteps: 0,\n  evalSteps: 0.00,\n  packing: false,\n  trainOnCompletions: false,\n  gradientCheckpointing: \"unsloth\" as const,\n  randomSeed: 3407,\n  enableWandb: false,\n  wandbToken: \"\",\n  wandbProject: \"llm-finetuning\",\n  enableTensorboard: false,\n  tensorboardDir: \"runs\",\n  logFrequency: 10,\n  trustRemoteCode: false,\n  finetuneVisionLayers: true,\n  finetuneLanguageLayers: true,\n  finetuneAttentionModules: true,\n  finetuneMLPModules: true,\n  targetModules: TARGET_MODULES,\n};\n\nexport const MODEL_TYPE_TO_HF_TASK: Record<ModelType, PipelineType> = {\n  text: \"text-generation\",\n  vision: 
\"image-text-to-text\",\n  audio: \"text-to-speech\",\n  embeddings: \"feature-extraction\",\n};\n\n\nexport const PRIORITY_TRAINING_MODELS: readonly string[] = [\n  \"unsloth/Qwen3.5-2B\",\n  \"unsloth/Qwen3.5-9B\",\n  \"unsloth/gpt-oss-20b\",\n  \"unsloth/NVIDIA-Nemotron-3-Nano-4B\",\n  \"unsloth/Qwen3-0.6B\",\n  \"unsloth/gemma-3-4b-it\",\n  \"unsloth/embeddinggemma-300m\",\n  \"unsloth/orpheus-3b-0.1-ft\",\n  \"unsloth/Llama-3.1-8B-Instruct\",\n  \"unsloth/Llama-3.2-3B-Instruct\",\n];\n\n/** Pin priority models to the top of a list of model IDs, preserving their defined order. */\nexport function applyPriorityOrdering(ids: string[]): string[] {\n  const idSet = new Set(ids);\n  const pinned = PRIORITY_TRAINING_MODELS.filter((id) => idSet.has(id));\n  const pinnedSet = new Set(pinned);\n  const rest = ids.filter((id) => !pinnedSet.has(id));\n  return [...pinned, ...rest];\n}\n"
  },
  {
    "path": "studio/frontend/src/features/auth/api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  clearAuthTokens,\n  getAuthToken,\n  getRefreshToken,\n  mustChangePassword,\n  storeAuthTokens,\n} from \"./session\";\n\ntype RefreshResponse = {\n  access_token: string;\n  refresh_token: string;\n  must_change_password: boolean;\n};\n\nlet isRedirecting = false;\n\nasync function isPasswordChangeRequiredResponse(response: Response): Promise<boolean> {\n  if (response.status !== 403) return false;\n\n  try {\n    const payload = (await response.clone().json()) as { detail?: string };\n    return payload.detail === \"Password change required\";\n  } catch {\n    return false;\n  }\n}\n\nasync function redirectToAuth(): Promise<void> {\n  if (isRedirecting) return;\n  isRedirecting = true;\n\n  let target = \"/login\";\n  try {\n    const res = await fetch(\"/api/auth/status\");\n    if (res.ok) {\n      const data = (await res.json()) as { requires_password_change: boolean };\n      if (data.requires_password_change || mustChangePassword()) target = \"/change-password\";\n    }\n  } catch {\n    // Fall through to /login on error\n  }\n\n  window.location.href = target;\n}\n\nexport async function refreshSession(): Promise<boolean> {\n  const refreshToken = getRefreshToken();\n  if (!refreshToken) return false;\n\n  try {\n    const response = await fetch(\"/api/auth/refresh\", {\n      method: \"POST\",\n      headers: { \"Content-Type\": \"application/json\" },\n      body: JSON.stringify({ refresh_token: refreshToken }),\n    });\n\n    if (!response.ok) {\n      clearAuthTokens();\n      return false;\n    }\n\n    const payload = (await response.json()) as RefreshResponse;\n    storeAuthTokens(\n      payload.access_token,\n      payload.refresh_token,\n      payload.must_change_password,\n    );\n    return true;\n  } catch {\n    return false;\n  }\n}\n\nexport async function authFetch(\n  input: RequestInfo | URL,\n  init?: RequestInit,\n): Promise<Response> {\n  const headers = new Headers(init?.headers);\n  const accessToken = getAuthToken();\n  if (accessToken) {\n    headers.set(\"Authorization\", `Bearer ${accessToken}`);\n  }\n\n  let response: Response;\n  try {\n    response = await fetch(input, { ...init, headers });\n  } catch (err) {\n    if (err instanceof TypeError) {\n      throw new Error(\"Studio isn't running -- please relaunch it.\");\n    }\n    throw err;\n  }\n  if (await isPasswordChangeRequiredResponse(response)) {\n    void redirectToAuth();\n    return response;\n  }\n  if (response.status !== 401) return response;\n\n  const refreshed = await refreshSession();\n  if (!refreshed) {\n    clearAuthTokens();\n    void redirectToAuth();\n    return response;\n  }\n\n  if (mustChangePassword()) {\n    void redirectToAuth();\n    return response;\n  }\n\n  const retryHeaders = new Headers(init?.headers);\n  const newToken = getAuthToken();\n  if (newToken) {\n    retryHeaders.set(\"Authorization\", `Bearer ${newToken}`);\n  } else {\n    clearAuthTokens();\n  }\n\n  return fetch(input, { ...init, headers: retryHeaders });\n}\n\nexport function logout(): void {\n  clearAuthTokens();\n}\n"
  },
  {
    "path": "studio/frontend/src/features/auth/change-password-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { LightRays } from \"@/components/ui/light-rays\";\nimport { Card } from \"@/components/ui/card\";\nimport { AuthForm } from \"./components/auth-form\";\n\nexport function ChangePasswordPage() {\n  return (\n    <div className=\"relative flex min-h-screen items-center justify-center overflow-hidden bg-background px-4 py-8 sm:px-6 sm:py-10 md:px-10\">\n      <LightRays\n        count={6}\n        color=\"rgba(34, 197, 94, 0.25)\"\n        blur={34}\n        speed={15}\n        length=\"70vh\"\n        style={{ opacity: 0.4 }}\n      />\n      <Card className=\"relative z-10 w-full max-w-sm px-5 py-6 shadow-border ring-1 ring-border sm:px-6 sm:py-8\">\n        <AuthForm mode=\"change-password\" />\n      </Card>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/auth/components/auth-form.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Input } from \"@/components/ui/input\";\nimport { Label } from \"@/components/ui/label\";\nimport { Link, useNavigate } from \"@tanstack/react-router\";\nimport { Eye, EyeOff } from \"lucide-react\";\nimport { useEffect, useState } from \"react\";\nimport type { ReactElement } from \"react\";\nimport type { SyntheticEvent } from \"react\";\nimport { refreshSession } from \"../api\";\n\n// Bootstrap credentials injected into index.html by the backend\n// (only present while default admin must_change_password is true)\ndeclare global {\n  interface Window {\n    __UNSLOTH_BOOTSTRAP__?: { username: string; password: string };\n  }\n}\n\nimport {\n  clearAuthTokens,\n  getAuthToken,\n  getPostAuthRoute,\n  hasAuthToken,\n  hasRefreshToken,\n  mustChangePassword,\n  resetOnboardingDone,\n  setMustChangePassword,\n  storeAuthTokens,\n} from \"../session\";\n\ntype AuthMode = \"login\" | \"change-password\";\n\ntype AuthStatusResponse = {\n  initialized: boolean;\n  requires_password_change: boolean;\n};\n\ntype TokenResponse = {\n  access_token: string;\n  refresh_token: string;\n  must_change_password: boolean;\n};\n\nasync function loginWithPassword(\n  username: string,\n  password: string,\n): Promise<TokenResponse> {\n  const response = await fetch(\"/api/auth/login\", {\n    method: \"POST\",\n    headers: {\n      \"Content-Type\": \"application/json\",\n    },\n    body: JSON.stringify({\n      username: username.trim(),\n      password,\n    }),\n  });\n\n  if (!response.ok) {\n    const errorPayload = (await response.json().catch(() => null)) as { detail?: string } | null;\n    throw new Error(errorPayload?.detail ?? \"Login failed.\");\n  }\n\n  return (await response.json()) as TokenResponse;\n}\n\ntype AuthFormProps = {\n  mode: AuthMode;\n};\n\nconst HIDDEN_LOGIN_USERNAME = \"unsloth\";\n\nexport function AuthForm({ mode }: AuthFormProps): ReactElement | null {\n  const navigate = useNavigate();\n  const isLoginMode = mode === \"login\";\n  const [showPassword, setShowPassword] = useState(false);\n  const username = HIDDEN_LOGIN_USERNAME;\n  const [password, setPassword] = useState(\"\");\n  const [newPassword, setNewPassword] = useState(\"\");\n  const [confirmPassword, setConfirmPassword] = useState(\"\");\n  const [loading, setLoading] = useState(false);\n  const [statusLoading, setStatusLoading] = useState(true);\n  const [initialized, setInitialized] = useState<boolean | null>(null);\n  const [requiresPasswordChange, setRequiresPasswordChange] = useState(false);\n  const [error, setError] = useState<string | null>(null);\n\n  useEffect(() => {\n    let canceled = false;\n\n    async function initializeAuthForm(): Promise<void> {\n      // Always check the server first — localStorage flags can be stale\n      // (e.g. tokens from a previous install attempt).  
The server's\n      // /api/auth/status is the source of truth for requires_password_change.\n      try {\n        const response = await fetch(\"/api/auth/status\");\n        if (!response.ok) throw new Error(\"Failed to load auth status.\");\n        const result = (await response.json()) as AuthStatusResponse;\n        if (!canceled) {\n          setInitialized(result.initialized);\n          setRequiresPasswordChange(result.requires_password_change);\n\n          // Redirect between login ↔ change-password based on server state\n          if (mode === \"login\" && result.requires_password_change) {\n            navigate({ to: \"/change-password\" });\n            return;\n          }\n          if (mode === \"change-password\" && !result.requires_password_change && !mustChangePassword()) {\n            navigate({ to: \"/login\" });\n            return;\n          }\n\n          // On login page, if user already has a valid session and no\n          // password change is required, skip straight to the app.\n          if (isLoginMode && !result.requires_password_change) {\n            if (hasRefreshToken()) {\n              const refreshed = await refreshSession();\n              if (refreshed) {\n                if (!canceled) setStatusLoading(false);\n                navigate({ to: getPostAuthRoute() });\n                return;\n              }\n            }\n            if (hasAuthToken()) {\n              if (!canceled) setStatusLoading(false);\n              navigate({ to: getPostAuthRoute() });\n              return;\n            }\n          }\n        }\n      } catch (err: unknown) {\n        if (!canceled) {\n          setError(err instanceof Error ? err.message : \"Failed to load.\");\n        }\n      } finally {\n        if (!canceled) setStatusLoading(false);\n      }\n    }\n\n    void initializeAuthForm();\n\n    return () => {\n      canceled = true;\n    };\n  }, [navigate]);\n\n  // Seed password from bootstrap credentials injected into HTML\n  useEffect(() => {\n    const bootstrap = window.__UNSLOTH_BOOTSTRAP__;\n    if (bootstrap) {\n      if (!isLoginMode && !password) {\n        setPassword(bootstrap.password);\n      }\n    }\n  }, []);\n\n  const blockedByState =\n    initialized === false ||\n    (mode === \"login\" && requiresPasswordChange) ||\n    (mode === \"change-password\" && !requiresPasswordChange && !mustChangePassword());\n\n  let helperText: string | null = null;\n  if (initialized === false) {\n    helperText = \"Auth is still bootstrapping the default admin account.\";\n  } else if (isLoginMode && requiresPasswordChange) {\n    helperText = \"Sign in once with the seeded credentials to change the password.\";\n  } else if (!isLoginMode && !requiresPasswordChange && !mustChangePassword()) {\n    helperText = \"Password already updated. Use the login screen.\";\n  }\n  const title = isLoginMode ? \"Welcome back\" : \"Setup your account\";\n  const subtitle = isLoginMode  \n    ? \"Sign in with your password.\"\n    : \"Choose a new password\";\n  const submitLabel = isLoginMode ? \"Login\" : \"Change password\";\n  const showSwitchLink = !isLoginMode;\n  const switchText = \"Password already setup? 
\";\n  const switchLinkTo = \"/login\";\n  const switchLinkText = \"Back to login\";\n  const currentPassword = password || window.__UNSLOTH_BOOTSTRAP__?.password || \"\";\n  const invalidChangePasswordForm =\n    !isLoginMode &&\n    (newPassword.length < 8 || newPassword !== confirmPassword || currentPassword === newPassword);\n  const showPasswordMismatchWarning =\n    !isLoginMode &&\n    newPassword.length > 0 &&\n    confirmPassword.length > 0 &&\n    newPassword !== confirmPassword;\n\n  async function handleSubmit(event: SyntheticEvent<HTMLFormElement>) {\n    event.preventDefault();\n    setError(null);\n\n    if (!isLoginMode) {\n      if (!currentPassword) {\n        setError(\"Unable to initialize setup. Reload the page and try again.\");\n        return;\n      }\n      if (newPassword.length < 8) {\n        setError(\"New password must be at least 8 characters.\");\n        return;\n      }\n      if (newPassword !== confirmPassword) {\n        setError(\"Passwords do not match.\");\n        return;\n      }\n      if (currentPassword === newPassword) {\n        setError(\"New password must be different from your current password.\");\n        return;\n      }\n    }\n\n    setLoading(true);\n    try {\n      let token: TokenResponse;\n\n      if (isLoginMode) {\n        token = await loginWithPassword(username, password);\n      } else {\n        let accessToken = getAuthToken();\n\n        if (hasRefreshToken()) {\n          const refreshed = await refreshSession();\n          accessToken = getAuthToken();\n          if (!refreshed) {\n            clearAuthTokens();\n            accessToken = null;\n          }\n        }\n\n        if (!accessToken) {\n          const bootstrapToken = await loginWithPassword(username, currentPassword);\n          storeAuthTokens(\n            bootstrapToken.access_token,\n            bootstrapToken.refresh_token,\n            bootstrapToken.must_change_password,\n          );\n          setMustChangePassword(bootstrapToken.must_change_password);\n          accessToken = bootstrapToken.access_token;\n        }\n\n        const response = await fetch(\"/api/auth/change-password\", {\n          method: \"POST\",\n          headers: {\n            \"Content-Type\": \"application/json\",\n            Authorization: `Bearer ${accessToken}`,\n          },\n          body: JSON.stringify({\n            current_password: currentPassword,\n            new_password: newPassword,\n          }),\n        });\n\n        if (!response.ok) {\n          let message = \"Password update failed.\";\n          const errorPayload = (await response\n            .json()\n            .catch(() => null)) as { detail?: string } | null;\n          if (errorPayload?.detail) message = errorPayload.detail;\n          throw new Error(message);\n        }\n\n        token = (await response.json()) as TokenResponse;\n      }\n\n      if (!isLoginMode) {\n        resetOnboardingDone();\n        setRequiresPasswordChange(false);\n        setMustChangePassword(false);\n      } else {\n        setMustChangePassword(token.must_change_password);\n      }\n      storeAuthTokens(\n        token.access_token,\n        token.refresh_token,\n        token.must_change_password,\n      );\n      navigate({ to: getPostAuthRoute() });\n    } catch (err: unknown) {\n      setError(err instanceof Error ? 
err.message : \"Auth failed.\");\n    } finally {\n      setLoading(false);\n    }\n  }\n\n  if (statusLoading && initialized === null && error === null) return null;\n\n  return (\n    <div className=\"w-full max-w-sm space-y-6\">\n      <div className=\"space-y-1.5 text-center\">\n        <img\n          src=\"/Sloth emojis/large sloth wave.png\"\n          alt=\"Unsloth waving mascot\"\n          className=\"mx-auto mb-2 h-20 w-20 object-contain\"\n        />\n        <h2 className=\"text-2xl font-semibold text-foreground\">{title}</h2>\n        <p className=\"text-muted-foreground\">{subtitle}</p>\n      </div>\n      <form className=\"space-y-5\" onSubmit={handleSubmit}>\n        {isLoginMode && (\n          <div className=\"space-y-2\">\n            <Label htmlFor=\"password\">Password</Label>\n            <div className=\"relative\">\n              <Input\n                id=\"password\"\n                type={showPassword ? \"text\" : \"password\"}\n                className=\"pr-10\"\n                autoComplete=\"current-password\"\n                value={password}\n                onChange={(event) => setPassword(event.target.value)}\n                minLength={8}\n                required\n              />\n              <Button\n                type=\"button\"\n                variant=\"ghost\"\n                size=\"icon\"\n                className=\"absolute right-0 top-0 h-full px-3 text-muted-foreground hover:bg-transparent\"\n                onClick={() => setShowPassword((prev) => !prev)}\n              >\n                {showPassword ? (\n                  <EyeOff className=\"h-4 w-4\" />\n                ) : (\n                  <Eye className=\"h-4 w-4\" />\n                )}\n              </Button>\n            </div>\n          </div>\n        )}\n\n        {!isLoginMode && (\n          <>\n            <div className=\"space-y-2\">\n              <Label htmlFor=\"new-password\">New password</Label>\n              <div className=\"relative\">\n                <Input\n                  id=\"new-password\"\n                  type={showPassword ? \"text\" : \"password\"}\n                  className=\"pr-10\"\n                  autoComplete=\"new-password\"\n                  value={newPassword}\n                  onChange={(event) => setNewPassword(event.target.value)}\n                  minLength={8}\n                  required\n                />\n                <Button\n                  type=\"button\"\n                  variant=\"ghost\"\n                  size=\"icon\"\n                  className=\"absolute right-0 top-0 h-full px-3 text-muted-foreground hover:bg-transparent\"\n                  onClick={() => setShowPassword((prev) => !prev)}\n                >\n                  {showPassword ? 
(\n                    <EyeOff className=\"h-4 w-4\" />\n                  ) : (\n                    <Eye className=\"h-4 w-4\" />\n                  )}\n                </Button>\n              </div>\n            </div>\n            <div className=\"space-y-2\">\n              <Label htmlFor=\"confirm-password\">Confirm password</Label>\n              <Input\n                id=\"confirm-password\"\n                type=\"password\"\n                autoComplete=\"new-password\"\n                value={confirmPassword}\n                onChange={(event) => setConfirmPassword(event.target.value)}\n                minLength={8}\n                required\n              />\n            </div>\n            <p\n              className={`min-h-4 text-xs ${\n                showPasswordMismatchWarning ? \"text-destructive\" : \"text-muted-foreground\"\n              }`}\n              aria-live=\"polite\"\n            >\n              {showPasswordMismatchWarning\n                ? \"Please ensure passwords match.\"\n                : \"Must be at least 8 characters.\"}\n            </p>\n          </>\n        )}\n\n        {helperText && (\n          <p className=\"text-center text-sm text-amber-600\">{helperText}</p>\n        )}\n        {error && <p className=\"text-center text-sm text-destructive\">{error}</p>}\n\n        <Button\n          type=\"submit\"\n          className=\"w-full\"\n          disabled={\n            loading ||\n            statusLoading ||\n            blockedByState ||\n            (isLoginMode && password.length < 8) ||\n            invalidChangePasswordForm\n          }\n        >\n          {loading ? \"Please wait...\" : submitLabel}\n        </Button>\n      </form>\n\n      {showSwitchLink && (\n        <p className=\"text-center text-sm text-muted-foreground\">\n          {switchText}\n          <Link to={switchLinkTo} className=\"text-primary hover:underline\">\n            {switchLinkText}\n          </Link>\n        </p>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/auth/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { LoginPage } from \"./login-page\";\nexport { ChangePasswordPage } from \"./change-password-page\";\nexport { authFetch, refreshSession } from \"./api\";\nexport {\n  getPostAuthRoute,\n  hasAuthToken,\n  hasRefreshToken,\n  isOnboardingDone,\n  markOnboardingDone,\n  mustChangePassword,\n  resetOnboardingDone,\n  setMustChangePassword,\n} from \"./session\";\n"
  },
  {
    "path": "studio/frontend/src/features/auth/login-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { LightRays } from \"@/components/ui/light-rays\";\nimport { Card } from \"@/components/ui/card\";\nimport { AuthForm } from \"./components/auth-form\";\n\nexport function LoginPage() {\n  return (\n    <div className=\"relative flex min-h-screen items-center justify-center overflow-hidden bg-background px-4 py-8 sm:px-6 sm:py-10 md:px-10\">\n      <LightRays\n        count={6}\n        color=\"rgba(34, 197, 94, 0.25)\"\n        blur={34}\n        speed={15}\n        length=\"70vh\"\n        style={{ opacity: 0.4 }}\n      />\n      <Card className=\"relative z-10 w-full max-w-sm px-5 py-6 shadow-border ring-1 ring-border sm:px-6 sm:py-8\">\n        <AuthForm mode=\"login\" />\n      </Card>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/auth/session.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { usePlatformStore } from \"@/config/env\";\n\nexport const AUTH_TOKEN_KEY = \"unsloth_auth_token\";\nexport const AUTH_REFRESH_TOKEN_KEY = \"unsloth_auth_refresh_token\";\nexport const ONBOARDING_DONE_KEY = \"unsloth_onboarding_done\";\nexport const AUTH_MUST_CHANGE_PASSWORD_KEY = \"unsloth_auth_must_change_password\";\n\ntype PostAuthRoute = \"/onboarding\" | \"/studio\" | \"/change-password\" | \"/chat\";\n\nfunction canUseStorage(): boolean {\n  return typeof window !== \"undefined\";\n}\n\nexport function hasAuthToken(): boolean {\n  if (!canUseStorage()) return false;\n  return Boolean(localStorage.getItem(AUTH_TOKEN_KEY));\n}\n\nexport function hasRefreshToken(): boolean {\n  if (!canUseStorage()) return false;\n  return Boolean(localStorage.getItem(AUTH_REFRESH_TOKEN_KEY));\n}\n\nexport function getAuthToken(): string | null {\n  if (!canUseStorage()) return null;\n  return localStorage.getItem(AUTH_TOKEN_KEY);\n}\n\nexport function getRefreshToken(): string | null {\n  if (!canUseStorage()) return null;\n  return localStorage.getItem(AUTH_REFRESH_TOKEN_KEY);\n}\n\nexport function storeAuthTokens(\n  accessToken: string,\n  refreshToken: string,\n  mustChangePassword = false,\n): void {\n  if (!canUseStorage()) return;\n  localStorage.setItem(AUTH_TOKEN_KEY, accessToken);\n  localStorage.setItem(AUTH_REFRESH_TOKEN_KEY, refreshToken);\n  localStorage.setItem(AUTH_MUST_CHANGE_PASSWORD_KEY, String(mustChangePassword));\n}\n\nexport function clearAuthTokens(): void {\n  if (!canUseStorage()) return;\n  localStorage.removeItem(AUTH_TOKEN_KEY);\n  localStorage.removeItem(AUTH_REFRESH_TOKEN_KEY);\n  localStorage.removeItem(AUTH_MUST_CHANGE_PASSWORD_KEY);\n}\n\nexport function mustChangePassword(): boolean {\n  if (!canUseStorage()) return false;\n  return localStorage.getItem(AUTH_MUST_CHANGE_PASSWORD_KEY) === \"true\";\n}\n\nexport function setMustChangePassword(required: boolean): void {\n  if (!canUseStorage()) return;\n  localStorage.setItem(AUTH_MUST_CHANGE_PASSWORD_KEY, String(required));\n}\n\nexport function isOnboardingDone(): boolean {\n  if (!canUseStorage()) return false;\n  return localStorage.getItem(ONBOARDING_DONE_KEY) === \"true\";\n}\n\nexport function markOnboardingDone(): void {\n  if (!canUseStorage()) return;\n  localStorage.setItem(ONBOARDING_DONE_KEY, \"true\");\n}\n\nexport function resetOnboardingDone(): void {\n  if (!canUseStorage()) return;\n  localStorage.removeItem(ONBOARDING_DONE_KEY);\n}\n\nexport function getPostAuthRoute(): PostAuthRoute {\n  if (mustChangePassword()) return \"/change-password\";\n  if (usePlatformStore.getState().isChatOnly()) return \"/chat\";\n  return isOnboardingDone() ? \"/studio\" : \"/onboarding\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/api/chat-adapter.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ChatModelAdapter } from \"@assistant-ui/react\";\nimport type { MessageTiming, ToolCallMessagePart } from \"@assistant-ui/core\";\nimport { toast } from \"sonner\";\nimport {\n  generateAudio,\n  listCachedGguf,\n  listCachedModels,\n  listGgufVariants,\n  loadModel,\n  streamChatCompletions,\n} from \"./chat-api\";\nimport { db } from \"../db\";\nimport { useChatRuntimeStore } from \"../stores/chat-runtime-store\";\nimport type { ChatModelSummary } from \"../types/runtime\";\nimport {\n  hasClosedThinkTag,\n  parseAssistantContent,\n} from \"../utils/parse-assistant-content\";\n\ntype RunMessages = Parameters<ChatModelAdapter[\"run\"]>[0][\"messages\"];\ntype RunMessage = RunMessages[number];\n\n/** Tracks which user messages were sent with an audio file (messageId → filename). */\nexport const sentAudioNames = new Map<string, string>();\n\n/** Parse \"Title: ...\\nURL: ...\\nSnippet: ...\" blocks into source content parts. */\nfunction parseSourcesFromResult(raw: string): { type: \"source\"; sourceType: \"url\"; id: string; url: string; title: string }[] {\n  if (!raw) return [];\n  const blocks = raw.split(/\\n---\\n/).filter(Boolean);\n  const sources: { type: \"source\"; sourceType: \"url\"; id: string; url: string; title: string }[] = [];\n  for (const block of blocks) {\n    const titleMatch = block.match(/Title:\\s*(.+)/);\n    const urlMatch = block.match(/URL:\\s*(.+)/);\n    if (titleMatch && urlMatch) {\n      const url = urlMatch[1].trim();\n      sources.push({\n        type: \"source\" as const,\n        sourceType: \"url\" as const,\n        id: url,\n        url,\n        title: titleMatch[1].trim(),\n      });\n    }\n  }\n  return sources;\n}\n\nfunction estimateTokenCount(text: string): number | undefined {\n  const trimmed = text.trim();\n  if (!trimmed) {\n    return undefined;\n  }\n  return Math.max(1, Math.round(trimmed.length / 4));\n}\n\nfunction buildTiming(\n  streamStartTime: number,\n  totalChunks: number,\n  firstTokenTime?: number,\n  totalStreamTime?: number,\n  tokenCount?: number,\n  toolCallCount = 0,\n): MessageTiming {\n  return {\n    streamStartTime,\n    firstTokenTime,\n    totalStreamTime,\n    tokenCount,\n    tokensPerSecond:\n      typeof totalStreamTime === \"number\" &&\n      totalStreamTime > 0 &&\n      typeof tokenCount === \"number\"\n        ? tokenCount / (totalStreamTime / 1000)\n        : undefined,\n    totalChunks,\n    toolCallCount,\n  };\n}\n\nfunction collectTextParts(message: RunMessage): string[] {\n  const textParts = message.content\n    .filter((part) => part.type === \"text\")\n    .map((part) => part.text);\n\n  if (\"attachments\" in message && (message.attachments?.length ?? 0) > 0) {\n    for (const attachment of message.attachments ?? []) {\n      for (const part of attachment.content ?? 
[]) {\n        if (part.type === \"text\") {\n          textParts.push(part.text);\n        }\n      }\n    }\n  }\n\n  return textParts;\n}\n\nfunction toOpenAIMessage(message: RunMessage): {\n  role: \"system\" | \"user\" | \"assistant\";\n  content: string;\n} | null {\n  if (\n    message.role !== \"system\" &&\n    message.role !== \"user\" &&\n    message.role !== \"assistant\"\n  ) {\n    return null;\n  }\n\n  let content = collectTextParts(message).join(\"\\n\");\n  // Strip inline audio base64 from prior assistant messages to avoid\n  // inflating token counts (e.g. audio-player responses with embedded WAV).\n  if (message.role === \"assistant\") {\n    content = content.replace(\n      /data:audio\\/[a-z0-9.+-]+;base64,[A-Za-z0-9+/=]+/g,\n      \"[audio]\",\n    );\n  }\n\n  return { role: message.role, content };\n}\n\nfunction extractImageBase64(input: string): string | undefined {\n  if (!input) {\n    return undefined;\n  }\n  if (input.startsWith(\"data:\")) {\n    const commaIndex = input.indexOf(\",\");\n    return commaIndex >= 0 ? input.slice(commaIndex + 1) : undefined;\n  }\n  return input;\n}\n\nfunction findLatestUserImageBase64(messages: RunMessages): string | undefined {\n  for (let i = messages.length - 1; i >= 0; i -= 1) {\n    const message = messages[i];\n    if (!message || message.role !== \"user\") {\n      continue;\n    }\n\n    // Image in message.content (e.g. compare view appends content with image parts)\n    for (const part of message.content ?? []) {\n      if (part.type === \"image\" && \"image\" in part) {\n        const encoded = extractImageBase64(part.image);\n        if (encoded) return encoded;\n      }\n    }\n\n    // Image in message.attachments (e.g. chat composer)\n    if (\"attachments\" in message && (message.attachments?.length ?? 0) > 0) {\n      for (const attachment of message.attachments ?? []) {\n        for (const part of attachment.content ?? []) {\n          if (part.type !== \"image\") {\n            continue;\n          }\n          const encoded = extractImageBase64(part.image);\n          if (encoded) {\n            return encoded;\n          }\n        }\n      }\n    }\n  }\n\n  return undefined;\n}\n\nfunction findLatestUserAudioBase64(messages: RunMessages): string | undefined {\n  // Check message content parts (from compare view's CompareMessagePart with type: \"audio\")\n  for (let i = messages.length - 1; i >= 0; i -= 1) {\n    const message = messages[i];\n    if (!message || message.role !== \"user\") continue;\n\n    for (const part of message.content ?? []) {\n      if (part.type === \"audio\" && \"audio\" in part) {\n        const audioPart = (part as unknown as { type: \"audio\"; audio: string | { data: string; format: string } }).audio;\n        const raw = typeof audioPart === \"string\" ? audioPart : audioPart?.data;\n        if (raw) return raw.startsWith(\"data:\") ? raw.split(\",\")[1] : raw;\n      }\n    }\n  }\n\n  // Check the runtime store (from main composer's audio upload)\n  const pendingAudio = useChatRuntimeStore.getState().pendingAudioBase64;\n  return pendingAudio ?? 
undefined;\n}\n\nasync function resolveUseAdapter(\n  threadId: string | undefined,\n): Promise<boolean | undefined> {\n  if (!threadId) {\n    return undefined;\n  }\n  try {\n    const thread = await db.threads.get(threadId);\n    if (!thread?.pairId) {\n      return undefined;\n    }\n    // model1/model2 threads don't use the adapter toggle — each side\n    // loads its own model via /api/inference/load before generation.\n    if (thread.modelType === \"model1\" || thread.modelType === \"model2\") {\n      return undefined;\n    }\n    return thread.modelType === \"lora\";\n  } catch {\n    return undefined;\n  }\n}\n\n/** Wait for an in-progress model load to finish (polls store every 500ms). */\nfunction waitForModelReady(abortSignal?: AbortSignal): Promise<void> {\n  return new Promise((resolve, reject) => {\n    const check = () => {\n      if (abortSignal?.aborted) { reject(new Error(\"Aborted\")); return; }\n      if (!useChatRuntimeStore.getState().modelLoading) { resolve(); return; }\n      setTimeout(check, 500);\n    };\n    check();\n  });\n}\n\n/**\n * Auto-load the smallest downloaded model when the user tries to chat\n * without selecting one. Prefers GGUF (picks smallest cached variant),\n * falls back to smallest cached safetensors model.\n */\nasync function autoLoadSmallestModel(): Promise<boolean> {\n  const toastId = toast(\"Loading a model…\", {\n    description: \"Auto-selecting the smallest downloaded model.\",\n    duration: 5000,\n    closeButton: true,\n  });\n  try {\n    const [ggufRepos, modelRepos] = await Promise.all([\n      listCachedGguf().catch(() => []),\n      listCachedModels().catch(() => []),\n    ]);\n\n    // Try GGUF first: pick the repo with the smallest total size,\n    // then pick its smallest downloaded variant.\n    if (ggufRepos.length > 0) {\n      const sorted = [...ggufRepos].sort((a, b) => a.size_bytes - b.size_bytes);\n      for (const repo of sorted) {\n        try {\n          const variants = await listGgufVariants(repo.repo_id);\n          const downloaded = variants.variants\n            .filter((v) => v.downloaded)\n            .sort((a, b) => a.size_bytes - b.size_bytes);\n          if (downloaded.length > 0) {\n            const variant = downloaded[0];\n            const loadResp = await loadModel({\n              model_path: repo.repo_id,\n              hf_token: null,\n              max_seq_length: 4096,\n              load_in_4bit: true,\n              is_lora: false,\n              gguf_variant: variant.quant,\n              trust_remote_code: false,\n            });\n            useChatRuntimeStore.getState().setCheckpoint(repo.repo_id, variant.quant);\n            const store = useChatRuntimeStore.getState();\n            store.setParams({ ...store.params, maxTokens: loadResp.context_length ?? 131072 });\n            // Add model to store so the selector shows the name\n            const autoModel: ChatModelSummary = {\n              id: repo.repo_id,\n              name: loadResp.display_name ?? repo.repo_id,\n              isVision: loadResp.is_vision ?? false,\n              isLora: loadResp.is_lora ?? false,\n              isGguf: loadResp.is_gguf ?? false,\n              isAudio: loadResp.is_audio ?? false,\n              audioType: loadResp.audio_type ?? null,\n              hasAudioInput: loadResp.has_audio_input ?? 
false,\n            };\n            const existingModels = store.models;\n            if (!existingModels.some((m) => m.id === repo.repo_id)) {\n              store.setModels([...existingModels, autoModel]);\n            }\n            useChatRuntimeStore.setState({\n              ggufContextLength: loadResp.context_length ?? 131072,\n              supportsReasoning: loadResp.supports_reasoning ?? false,\n              reasoningEnabled: loadResp.supports_reasoning ?? false,\n              supportsTools: loadResp.supports_tools ?? false,\n              toolsEnabled: false,\n              codeToolsEnabled: false,\n              defaultChatTemplate: loadResp.chat_template ?? null,\n              chatTemplateOverride: null,\n            });\n            toast.success(`Loaded ${repo.repo_id} (${variant.quant})`, { id: toastId });\n            return true;\n          }\n        } catch {\n          continue;\n        }\n      }\n    }\n\n    // Fall back to safetensors models\n    if (modelRepos.length > 0) {\n      const sorted = [...modelRepos].sort((a, b) => a.size_bytes - b.size_bytes);\n      for (const repo of sorted) {\n        try {\n          const sfLoadResp = await loadModel({\n            model_path: repo.repo_id,\n            hf_token: null,\n            max_seq_length: 4096,\n            load_in_4bit: true,\n            is_lora: false,\n            gguf_variant: null,\n            trust_remote_code: false,\n          });\n          useChatRuntimeStore.getState().setCheckpoint(repo.repo_id);\n          const store = useChatRuntimeStore.getState();\n          store.setParams({ ...store.params, maxTokens: 4096 });\n          const sfModel: ChatModelSummary = {\n            id: repo.repo_id,\n            name: sfLoadResp.display_name ?? repo.repo_id,\n            isVision: sfLoadResp.is_vision ?? false,\n            isLora: sfLoadResp.is_lora ?? false,\n            isGguf: sfLoadResp.is_gguf ?? false,\n          };\n          if (!store.models.some((m) => m.id === repo.repo_id)) {\n            store.setModels([...store.models, sfModel]);\n          }\n          toast.success(`Loaded ${repo.repo_id}`, { id: toastId });\n          return true;\n        } catch {\n          continue;\n        }\n      }\n    }\n\n    // No cached models found — try downloading a small default GGUF\n    toast(\"Downloading a small model…\", {\n      id: toastId,\n      description: \"No downloaded models found. Fetching Qwen3.5-4B (UD-Q4_K_XL).\",\n      duration: 30000,\n    });\n    try {\n      const loadResp = await loadModel({\n        model_path: \"unsloth/Qwen3.5-4B-GGUF\",\n        hf_token: null,\n        max_seq_length: 4096,\n        load_in_4bit: true,\n        is_lora: false,\n        gguf_variant: \"UD-Q4_K_XL\",\n        trust_remote_code: false,\n      });\n      useChatRuntimeStore.getState().setCheckpoint(\"unsloth/Qwen3.5-4B-GGUF\", \"UD-Q4_K_XL\");\n      const store = useChatRuntimeStore.getState();\n      store.setParams({ ...store.params, maxTokens: loadResp.context_length ?? 131072 });\n      const defaultModel: ChatModelSummary = {\n        id: \"unsloth/Qwen3.5-4B-GGUF\",\n        name: loadResp.display_name ?? \"Qwen3.5-4B-GGUF\",\n        isVision: loadResp.is_vision ?? false,\n        isLora: false,\n        isGguf: true,\n      };\n      if (!store.models.some((m) => m.id === \"unsloth/Qwen3.5-4B-GGUF\")) {\n        store.setModels([...store.models, defaultModel]);\n      }\n      useChatRuntimeStore.setState({\n        ggufContextLength: loadResp.context_length ?? 
131072,\n        supportsReasoning: loadResp.supports_reasoning ?? false,\n        reasoningEnabled: loadResp.supports_reasoning ?? false,\n        supportsTools: loadResp.supports_tools ?? false,\n        toolsEnabled: false,\n        defaultChatTemplate: loadResp.chat_template ?? null,\n        chatTemplateOverride: null,\n      });\n      toast.success(\"Loaded Qwen3.5-4B (UD-Q4_K_XL)\", { id: toastId });\n      return true;\n    } catch {\n      toast.dismiss(toastId);\n      return false;\n    }\n  } catch {\n    toast.dismiss(toastId);\n    return false;\n  }\n}\n\nexport function createOpenAIStreamAdapter(): ChatModelAdapter {\n  return {\n    async *run({ messages, abortSignal, unstable_threadId }) {\n      const runtime = useChatRuntimeStore.getState();\n      const { params } = runtime;\n\n      // Wait for in-progress model load to finish before inferring\n      if (runtime.modelLoading) {\n        toast.info(\"Waiting for model to finish loading…\");\n        await waitForModelReady(abortSignal);\n      }\n\n      if (!useChatRuntimeStore.getState().params.checkpoint) {\n        // Auto-load the smallest downloaded model\n        const loaded = await autoLoadSmallestModel();\n        if (!loaded) {\n          toast.error(\"No model loaded\", {\n            description: \"Pick a model in the top bar, then retry.\",\n          });\n          throw new Error(\"Load a model first.\");\n        }\n      }\n\n      const {\n        supportsTools,\n        toolsEnabled,\n        codeToolsEnabled,\n      } = runtime;\n\n      const outboundMessages = messages\n        .map(toOpenAIMessage)\n        .filter((message): message is NonNullable<typeof message> =>\n          Boolean(message),\n        );\n\n      if (params.systemPrompt.trim()) {\n        outboundMessages.unshift({\n          role: \"system\",\n          content: params.systemPrompt.trim(),\n        });\n      }\n      const imageBase64 = findLatestUserImageBase64(messages);\n      const audioBase64 = findLatestUserAudioBase64(messages);\n      // Clear pending audio from store after extracting (consumed on send)\n      if (audioBase64) {\n        const audioName = runtime.pendingAudioName;\n        if (audioName) {\n          const lastUserMsg = [...messages].reverse().find((m) => m.role === \"user\");\n          if (lastUserMsg) sentAudioNames.set(lastUserMsg.id, audioName);\n        }\n        runtime.clearPendingAudio();\n      }\n      const useAdapter = await resolveUseAdapter(unstable_threadId);\n\n      // ── Audio model path (non-streaming) ─────────────────────\n      const activeModel = runtime.models.find(\n        (m) => m.id === params.checkpoint,\n      );\n      if (activeModel?.isAudio && !activeModel?.hasAudioInput) {\n        const threadKey = unstable_threadId || \"__default\";\n        runtime.setThreadRunning(threadKey, true);\n        try {\n          yield {\n            content: [{ type: \"text\" as const, text: \"Generating audio...\" }],\n          };\n\n          const result = await generateAudio(\n            {\n              model: params.checkpoint,\n              messages: outboundMessages,\n              stream: false,\n              temperature: params.temperature,\n              top_p: params.topP,\n              max_tokens: params.maxTokens,\n              top_k: params.topK,\n              min_p: params.minP,\n              repetition_penalty: params.repetitionPenalty,\n              presence_penalty: params.presencePenalty,\n              ...(useAdapter === undefined ? 
{} : { use_adapter: useAdapter }),\n            },\n            abortSignal,\n          );\n\n          const audioUrl = `data:audio/wav;base64,${result.audio.data}`;\n          yield {\n            content: [\n              {\n                type: \"text\" as const,\n                text: `<audio-player src=\"${audioUrl}\" />`,\n              },\n            ],\n          };\n        } catch (err) {\n          if (!abortSignal.aborted) {\n            toast.error(\"Audio generation failed\", {\n              description:\n                err instanceof Error ? err.message : \"Unknown error\",\n            });\n          }\n          throw err;\n        } finally {\n          runtime.setThreadRunning(threadKey, false);\n        }\n        return;\n      }\n\n      const threadKey = unstable_threadId || \"__default\";\n      let waitingFirstChunk = true;\n      let firstTokenSettled = false;\n      const streamStartTime = Date.now();\n      let firstTokenTime: number | undefined;\n      let totalChunks = 0;\n      let resolveFirstToken: (() => void) | null = null;\n      let rejectFirstToken: ((err: unknown) => void) | null = null;\n      const firstTokenPromise = new Promise<void>((resolve, reject) => {\n        resolveFirstToken = resolve;\n        rejectFirstToken = reject;\n      });\n      // Avoid unhandled rejections if toast.promise never attached.\n      void firstTokenPromise.catch(() => {});\n\n      function settleFirstTokenOk(): void {\n        if (firstTokenSettled) return;\n        firstTokenSettled = true;\n        resolveFirstToken?.();\n      }\n\n      function settleFirstTokenErr(err: unknown): void {\n        if (firstTokenSettled) return;\n        firstTokenSettled = true;\n        rejectFirstToken?.(err);\n      }\n\n      const warmupDelayMs = 450;\n      const warmupTimer = setTimeout(() => {\n        if (!waitingFirstChunk) return;\n        if (abortSignal.aborted) return;\n        runtime.setGeneratingStatus(\"waiting\");\n      }, warmupDelayMs);\n      runtime.setThreadRunning(threadKey, true);\n      let cumulativeText = \"\";\n      let reasoningStartAt: number | null = null;\n      let reasoningDuration = 0;\n      // Tool call content parts — accumulated and yielded cumulatively.\n      // result is set directly on the tool-call part when tool_end arrives.\n      const toolCallParts: ToolCallMessagePart[] = [];\n\n      try {\n        const { supportsReasoning, reasoningEnabled } = runtime;\n        const stream = streamChatCompletions(\n          {\n            model: params.checkpoint,\n            messages: outboundMessages,\n            stream: true,\n            temperature: params.temperature,\n            top_p: params.topP,\n            max_tokens: params.maxTokens,\n            top_k: params.topK,\n            min_p: params.minP,\n            repetition_penalty: params.repetitionPenalty,\n            presence_penalty: params.presencePenalty,\n            image_base64: imageBase64,\n            audio_base64: audioBase64,\n            ...(useAdapter === undefined ? {} : { use_adapter: useAdapter }),\n            ...(supportsReasoning ? { enable_thinking: reasoningEnabled } : {}),\n            ...(supportsTools && (toolsEnabled || codeToolsEnabled)\n              ? {\n                  enable_tools: true,\n                  enabled_tools: [\n                    ...(toolsEnabled ? [\"web_search\"] : []),\n                    ...(codeToolsEnabled ? 
[\"python\", \"terminal\"] : []),\n                  ],\n                  auto_heal_tool_calls: useChatRuntimeStore.getState().autoHealToolCalls,\n                  max_tool_calls_per_message: useChatRuntimeStore.getState().maxToolCallsPerMessage,\n                  tool_call_timeout: (() => {\n                    const mins = useChatRuntimeStore.getState().toolCallTimeout;\n                    return mins >= 9999 ? 9999 : mins * 60;\n                  })(),\n                  session_id: unstable_threadId || undefined,\n                }\n              : {}),\n          },\n          abortSignal,\n        );\n\n        for await (const chunk of stream) {\n          // Handle tool status events\n          const toolStatusText = (chunk as unknown as { _toolStatus?: string })._toolStatus;\n          if (toolStatusText !== undefined) {\n            runtime.setToolStatus(toolStatusText || null);\n            continue;\n          }\n\n          // Emit tool-call content parts for assistant-ui.\n          // On tool_start: add a new tool-call part (renders in \"running\" state).\n          // On tool_end: set result on the existing part (transitions to \"complete\").\n          const toolEvent = (chunk as unknown as { _toolEvent?: Record<string, unknown> })._toolEvent;\n          if (toolEvent !== undefined) {\n            if (toolEvent.type === \"tool_start\") {\n              const id = (toolEvent.tool_call_id as string) || `${toolEvent.tool_name}_${Date.now()}`;\n              const toolArgs = (toolEvent.arguments ?? {}) as ToolCallMessagePart[\"args\"];\n              toolCallParts.push({\n                type: \"tool-call\" as const,\n                toolCallId: id,\n                toolName: toolEvent.tool_name as string,\n                argsText: JSON.stringify(toolArgs),\n                args: toolArgs,\n              });\n            } else if (toolEvent.type === \"tool_end\") {\n              const id = (toolEvent.tool_call_id as string) ||\n                toolCallParts[toolCallParts.length - 1]?.toolCallId || \"\";\n              const idx = toolCallParts.findIndex((p) => p.toolCallId === id);\n              if (idx !== -1) {\n                toolCallParts[idx] = { ...toolCallParts[idx], result: toolEvent.result as string };\n              }\n            }\n            // Yield cumulative state so tool UI updates (tools first, text after)\n            const textParts = parseAssistantContent(cumulativeText);\n            yield {\n              content: [...toolCallParts, ...textParts],\n              metadata: {\n                timing: buildTiming(streamStartTime, totalChunks, firstTokenTime),\n                custom: { reasoningDuration },\n              },\n            };\n            continue;\n          }\n\n          totalChunks += 1;\n          const delta = chunk.choices?.[0]?.delta?.content;\n          if (!delta) {\n            continue;\n          }\n          if (waitingFirstChunk) {\n            waitingFirstChunk = false;\n            firstTokenTime = Date.now() - streamStartTime;\n            settleFirstTokenOk();\n            runtime.setGeneratingStatus(null);\n          }\n\n          cumulativeText += delta;\n          const parts = parseAssistantContent(cumulativeText);\n\n          if (parts.some((part) => part.type === \"reasoning\") && !reasoningStartAt) {\n            reasoningStartAt = Date.now();\n          }\n          if (hasClosedThinkTag(cumulativeText) && reasoningStartAt && !reasoningDuration) {\n            reasoningDuration = Math.round((Date.now() - 
reasoningStartAt) / 1000);\n          }\n\n          if (parts.length > 0 || toolCallParts.length > 0) {\n            yield {\n              content: [...toolCallParts, ...parts],\n              metadata: {\n                timing: buildTiming(\n                  streamStartTime,\n                  totalChunks,\n                  firstTokenTime,\n                ),\n                custom: { reasoningDuration },\n              },\n            };\n          }\n        }\n        settleFirstTokenOk();\n\n        // Extract source parts from completed web_search tool calls\n        const sourceParts = toolCallParts.flatMap((tc) => {\n          if (tc.toolName !== \"web_search\" || !tc.result) return [];\n          return parseSourcesFromResult(typeof tc.result === \"string\" ? tc.result : \"\");\n        });\n\n        yield {\n          content: [\n            ...toolCallParts,\n            ...parseAssistantContent(cumulativeText),\n            ...sourceParts,\n          ],\n          metadata: {\n            timing: buildTiming(\n              streamStartTime,\n              totalChunks,\n              firstTokenTime,\n              Date.now() - streamStartTime,\n              estimateTokenCount(cumulativeText),\n              toolCallParts.length,\n            ),\n            custom: { reasoningDuration },\n          },\n        };\n      } catch (err) {\n        settleFirstTokenErr(err instanceof Error ? err : new Error(\"Generation failed\"));\n        if (!abortSignal.aborted) {\n          toast.error(\"Generation failed\", {\n            description: err instanceof Error ? err.message : \"Unknown error\",\n          });\n        }\n        throw err;\n      } finally {\n        runtime.setGeneratingStatus(null);\n        runtime.setToolStatus(null);\n        clearTimeout(warmupTimer);\n        if (waitingFirstChunk) {\n          if (!firstTokenSettled) {\n            if (abortSignal.aborted) {\n              settleFirstTokenErr(new Error(\"Cancelled\"));\n            } else {\n              settleFirstTokenErr(new Error(\"No tokens received\"));\n            }\n          } else {\n            settleFirstTokenOk();\n          }\n        }\n        runtime.setThreadRunning(threadKey, false);\n      }\n    },\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/api/chat-api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\nimport type {\n  AudioGenerationResponse,\n  GgufVariantsResponse,\n  InferenceStatusResponse,\n  ListLorasResponse,\n  ListModelsResponse,\n  LoadModelRequest,\n  LoadModelResponse,\n  OpenAIChatChunk,\n  OpenAIChatCompletionsRequest,\n  UnloadModelRequest,\n  ValidateModelResponse,\n} from \"../types/api\";\n\nfunction parseErrorText(status: number, body: unknown): string {\n  if (\n    body &&\n    typeof body === \"object\" &&\n    \"detail\" in body &&\n    typeof body.detail === \"string\"\n  ) {\n    return body.detail;\n  }\n  if (\n    body &&\n    typeof body === \"object\" &&\n    \"message\" in body &&\n    typeof body.message === \"string\"\n  ) {\n    return body.message;\n  }\n  return `Request failed (${status})`;\n}\n\nasync function parseJsonOrThrow<T>(response: Response): Promise<T> {\n  const body = await response.json().catch(() => null);\n  if (!response.ok) {\n    throw new Error(parseErrorText(response.status, body));\n  }\n  return body as T;\n}\n\nexport async function listModels(): Promise<ListModelsResponse> {\n  const response = await authFetch(\"/api/models/list\");\n  return parseJsonOrThrow<ListModelsResponse>(response);\n}\n\nexport async function listLoras(outputsDir?: string): Promise<ListLorasResponse> {\n  const query = outputsDir\n    ? `?${new URLSearchParams({ outputs_dir: outputsDir }).toString()}`\n    : \"\";\n  const response = await authFetch(`/api/models/loras${query}`);\n  return parseJsonOrThrow<ListLorasResponse>(response);\n}\n\nexport async function getInferenceStatus(): Promise<InferenceStatusResponse> {\n  const response = await authFetch(\"/api/inference/status\");\n  return parseJsonOrThrow<InferenceStatusResponse>(response);\n}\n\nexport async function loadModel(\n  payload: LoadModelRequest,\n): Promise<LoadModelResponse> {\n  const response = await authFetch(\"/api/inference/load\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(payload),\n  });\n  return parseJsonOrThrow<LoadModelResponse>(response);\n}\n\nexport async function validateModel(\n  payload: LoadModelRequest,\n): Promise<ValidateModelResponse> {\n  const response = await authFetch(\"/api/inference/validate\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({\n      model_path: payload.model_path,\n      hf_token: payload.hf_token,\n      gguf_variant: payload.gguf_variant ?? 
null,\n    }),\n  });\n  return parseJsonOrThrow<ValidateModelResponse>(response);\n}\n\nexport async function unloadModel(payload: UnloadModelRequest): Promise<void> {\n  const response = await authFetch(\"/api/inference/unload\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(payload),\n  });\n  await parseJsonOrThrow<unknown>(response);\n}\n\nexport interface CachedGgufRepo {\n  repo_id: string;\n  size_bytes: number;\n  cache_path: string;\n}\n\nexport async function getGgufDownloadProgress(\n  repoId: string,\n  variant: string,\n  expectedBytes: number,\n): Promise<{ downloaded_bytes: number; expected_bytes: number; progress: number }> {\n  const params = new URLSearchParams({\n    repo_id: repoId,\n    variant,\n    expected_bytes: String(expectedBytes),\n  });\n  const response = await authFetch(`/api/models/gguf-download-progress?${params}`);\n  return parseJsonOrThrow(response);\n}\n\nexport async function getDownloadProgress(\n  repoId: string,\n): Promise<{ downloaded_bytes: number; expected_bytes: number; progress: number }> {\n  const params = new URLSearchParams({ repo_id: repoId });\n  const response = await authFetch(`/api/models/download-progress?${params}`);\n  return parseJsonOrThrow(response);\n}\n\nexport async function listCachedGguf(): Promise<CachedGgufRepo[]> {\n  const response = await authFetch(\"/api/models/cached-gguf\");\n  const data = await parseJsonOrThrow<{ cached: CachedGgufRepo[] }>(response);\n  return data.cached;\n}\n\nexport interface CachedModelRepo {\n  repo_id: string;\n  size_bytes: number;\n}\n\nexport async function listCachedModels(): Promise<CachedModelRepo[]> {\n  const response = await authFetch(\"/api/models/cached-models\");\n  const data = await parseJsonOrThrow<{ cached: CachedModelRepo[] }>(response);\n  return data.cached;\n}\n\nexport async function deleteCachedModel(repoId: string, variant?: string): Promise<void> {\n  const payload: Record<string, string> = { repo_id: repoId };\n  if (variant) payload.variant = variant;\n  const response = await authFetch(\"/api/models/delete-cached\", {\n    method: \"DELETE\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(payload),\n  });\n  await parseJsonOrThrow<unknown>(response);\n}\n\nexport async function listGgufVariants(\n  repoId: string,\n  hfToken?: string,\n): Promise<GgufVariantsResponse> {\n  const params = new URLSearchParams({ repo_id: repoId });\n  if (hfToken) params.set(\"hf_token\", hfToken);\n  const response = await authFetch(`/api/models/gguf-variants?${params}`);\n  return parseJsonOrThrow<GgufVariantsResponse>(response);\n}\n\nfunction parseSseEvent(rawEvent: string): string[] {\n  const dataLines: string[] = [];\n  for (const line of rawEvent.split(/\\r?\\n/)) {\n    if (line.startsWith(\"data:\")) {\n      dataLines.push(line.slice(5).trimStart());\n    }\n  }\n  return dataLines;\n}\n\nexport async function* streamChatCompletions(\n  payload: OpenAIChatCompletionsRequest,\n  signal: AbortSignal,\n): AsyncGenerator<OpenAIChatChunk> {\n  const response = await authFetch(\"/v1/chat/completions\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(payload),\n    signal,\n  });\n\n  if (!response.ok) {\n    const body = await response.json().catch(() => null);\n    throw new Error(parseErrorText(response.status, body));\n  }\n\n  if (!response.body) {\n    throw new Error(\"Stream response missing body\");\n  }\n\n  
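// Incrementally decode the SSE byte stream: accumulate chunks in a buffer,\n  // split on blank-line event separators, and parse each event's data: lines.\n  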
const reader = response.body.getReader();\n  const decoder = new TextDecoder();\n  let buffer = \"\";\n\n  while (true) {\n    const { done, value } = await reader.read();\n    if (done) {\n      break;\n    }\n\n    buffer += decoder.decode(value, { stream: true });\n\n    let separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    while (separatorIndex >= 0) {\n      const rawEvent = buffer.slice(0, separatorIndex);\n      const separatorLength = buffer[separatorIndex] === \"\\r\" ? 4 : 2;\n      buffer = buffer.slice(separatorIndex + separatorLength);\n\n      const dataLines = parseSseEvent(rawEvent);\n      if (dataLines.length === 0) {\n        separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n        continue;\n      }\n\n      const dataText = dataLines.join(\"\\n\");\n      if (dataText === \"[DONE]\") {\n        return;\n      }\n\n      const parsed = JSON.parse(dataText) as\n        | OpenAIChatChunk\n        | { type?: string; content?: string; error?: { message?: string } };\n      if (\"error\" in parsed && parsed.error) {\n        throw new Error(parsed.error.message || \"Stream error\");\n      }\n      // Tool status events are custom SSE payloads, not OpenAI chunks\n      if (\"type\" in parsed && parsed.type === \"tool_status\") {\n        yield { _toolStatus: parsed.content ?? \"\" } as unknown as OpenAIChatChunk;\n        separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n        continue;\n      }\n      // Tool start/end events carry full input/output for the tool outputs panel\n      if (\"type\" in parsed && (parsed.type === \"tool_start\" || parsed.type === \"tool_end\")) {\n        yield { _toolEvent: parsed } as unknown as OpenAIChatChunk;\n        separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n        continue;\n      }\n      yield parsed as OpenAIChatChunk;\n      separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    }\n  }\n}\n\nexport async function generateAudio(\n  payload: OpenAIChatCompletionsRequest,\n  signal: AbortSignal,\n): Promise<AudioGenerationResponse> {\n  const response = await authFetch(\"/api/inference/chat/completions\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({ ...payload, stream: false }),\n    signal,\n  });\n\n  if (!response.ok) {\n    const body = await response.json().catch(() => null);\n    throw new Error(parseErrorText(response.status, body));\n  }\n\n  return (await response.json()) as AudioGenerationResponse;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/chat-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  type LoraModelOption,\n  type ModelOption,\n  ModelSelector,\n} from \"@/components/assistant-ui/model-selector\";\nimport { Thread } from \"@/components/assistant-ui/thread\";\nimport { Button } from \"@/components/ui/button\";\nimport { SidebarProvider, SidebarTrigger, useSidebar } from \"@/components/ui/sidebar\";\nimport {\n  Sheet,\n  SheetContent,\n  SheetDescription,\n  SheetHeader,\n  SheetTitle,\n} from \"@/components/ui/sheet\";\nimport { Tooltip, TooltipContent, TooltipTrigger } from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  ColumnInsertIcon,\n  PencilEdit02Icon,\n  Settings04Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  type CSSProperties,\n  type ReactElement,\n  type ReactNode,\n  memo,\n  useCallback,\n  useEffect,\n  useMemo,\n  useRef,\n  useState,\n} from \"react\";\nimport { toast } from \"sonner\";\nimport { GuidedTour, useGuidedTourController } from \"@/features/tour\";\nimport { ChatSettingsPanel } from \"./chat-settings-sheet\";\nimport { ModelLoadInlineStatus } from \"./components/model-load-status\";\nimport { db } from \"./db\";\nimport { useChatModelRuntime } from \"./hooks/use-chat-model-runtime\";\nimport {\n  clearTrainingCompareHandoff,\n  getTrainingCompareHandoff,\n} from \"./lib/training-compare-handoff\";\nimport { ChatRuntimeProvider } from \"./runtime-provider\";\nimport { useChatRuntimeStore } from \"./stores/chat-runtime-store\";\nimport {\n  type CompareHandle,\n  CompareHandlesProvider,\n  RegisterCompareHandle,\n  SharedComposer,\n} from \"./shared-composer\";\nimport { ThreadSidebar } from \"./thread-sidebar\";\nimport type { ChatView, MessageRecord } from \"./types\";\nimport { buildChatTourSteps } from \"./tour\";\n\ntype LoraCandidate = {\n  id: string;\n  baseModel: string;\n  updatedAt?: number;\n};\n\nfunction normalizeModelRef(value: string | null | undefined): string {\n  return value?.trim().toLowerCase() ?? \"\";\n}\n\nfunction pickBestLoraForBase(\n  loras: LoraCandidate[],\n  baseModel: string | null,\n): LoraCandidate | null {\n  if (loras.length === 0) return null;\n  const sorted = [...loras].sort(\n    (a, b) => (b.updatedAt ?? -1) - (a.updatedAt ?? -1),\n  );\n  const normalizedBase = normalizeModelRef(baseModel);\n  if (!normalizedBase) return sorted[0];\n\n  const exact = sorted.find(\n    (lora) => normalizeModelRef(lora.baseModel) === normalizedBase,\n  );\n  if (exact) return exact;\n\n  const partial = sorted.find((lora) => {\n    const normalizedLoraBase = normalizeModelRef(lora.baseModel);\n    if (!normalizedLoraBase) return false;\n    return (\n      normalizedLoraBase.includes(normalizedBase) ||\n      normalizedBase.includes(normalizedLoraBase)\n    );\n  });\n  return partial ?? sorted[0];\n}\n\nfunction messageHasImage(message: MessageRecord): boolean {\n  const contentParts = Array.isArray(message.content) ? message.content : [];\n  if (contentParts.some((part) => part.type === \"image\")) {\n    return true;\n  }\n  const attachments = Array.isArray(message.attachments) ? message.attachments : [];\n  for (const attachment of attachments) {\n    const parts = Array.isArray(attachment.content) ? 
attachment.content : [];\n    for (const part of parts as Array<{ type?: string }>) {\n      if (part?.type === \"image\") {\n        return true;\n      }\n    }\n  }\n  return false;\n}\n\nconst SingleContent = memo(function SingleContent({\n  threadId,\n  newThreadNonce,\n}: { threadId?: string; newThreadNonce?: string }): ReactElement {\n  return (\n    <ChatRuntimeProvider\n      modelType=\"base\"\n      initialThreadId={threadId}\n      newThreadNonce={newThreadNonce}\n    >\n      <div className=\"min-h-0 flex-1\">\n        <Thread />\n      </div>\n    </ChatRuntimeProvider>\n  );\n});\n\ntype CompareModelSelection = {\n  id: string;\n  isLora: boolean;\n  ggufVariant?: string;\n};\n\n/**\n * Detect if this is a LoRA base-vs-fine-tuned compare.\n * Returns true when the loaded checkpoint is a LoRA — in that case\n * we use the fast simultaneous base/lora adapter-toggle path.\n */\nfunction useIsLoraCompare(): boolean {\n  return useChatRuntimeStore((s) => {\n    const cp = s.params.checkpoint;\n    return cp ? s.loras.some((l) => l.id === cp) : false;\n  });\n}\n\nconst CompareContent = memo(function CompareContent({\n  pairId,\n  models,\n  loraModels,\n}: { pairId: string; models: ModelOption[]; loraModels: LoraModelOption[] }): ReactElement {\n  const isLoraCompare = useIsLoraCompare();\n\n  return isLoraCompare\n    ? <LoraCompareContent pairId={pairId} />\n    : <GeneralCompareContent pairId={pairId} models={models} loraModels={loraModels} />;\n});\n\n/** Fast path: same model, adapter on/off, simultaneous generation. */\nconst LoraCompareContent = memo(function LoraCompareContent({\n  pairId,\n}: { pairId: string }): ReactElement {\n  const handlesRef = useRef<Record<string, CompareHandle>>({});\n  const [baseThreadId, setBaseThreadId] = useState<string>();\n  const [loraThreadId, setLoraThreadId] = useState<string>();\n\n  useEffect(() => {\n    let isActive = true;\n    db.threads\n      .where(\"pairId\")\n      .equals(pairId)\n      .toArray()\n      .then((threads) => {\n        if (!isActive) return;\n        setBaseThreadId(threads.find((t) => t.modelType === \"base\")?.id);\n        setLoraThreadId(threads.find((t) => t.modelType === \"lora\")?.id);\n      });\n    return () => { isActive = false; };\n  }, [pairId]);\n\n  return (\n    <CompareHandlesProvider handlesRef={handlesRef}>\n      <div className=\"flex min-h-0 flex-1 flex-col\">\n        <div\n          data-tour=\"chat-compare-view\"\n          className=\"grid min-h-0 flex-1 grid-cols-1 px-0 md:grid-cols-2\"\n        >\n          <div className=\"flex min-h-0 flex-col\">\n            <div className=\"px-3 py-1.5\">\n              <span className=\"text-[10px] font-semibold uppercase tracking-wider text-muted-foreground\">\n                Base Model\n              </span>\n            </div>\n            <div className=\"min-h-0 flex-1\">\n              <ChatRuntimeProvider modelType=\"base\" pairId={pairId} initialThreadId={baseThreadId}>\n                <RegisterCompareHandle name=\"base\" />\n                <Thread hideComposer={true} hideWelcome={true} />\n              </ChatRuntimeProvider>\n            </div>\n          </div>\n          <div className=\"flex min-h-0 flex-col border-t border-border/60 md:border-t-0 md:border-l\">\n            <div className=\"px-3 py-1.5 text-start md:text-end\">\n              <span className=\"text-[10px] font-semibold uppercase tracking-wider text-primary\">\n                Fine-tuned (LoRA)\n              </span>\n            </div>\n            <div 
className=\"min-h-0 flex-1\">\n              <ChatRuntimeProvider modelType=\"lora\" pairId={pairId} initialThreadId={loraThreadId}>\n                <RegisterCompareHandle name=\"lora\" />\n                <Thread hideComposer={true} hideWelcome={true} />\n              </ChatRuntimeProvider>\n            </div>\n          </div>\n        </div>\n        <div className=\"mx-auto w-full max-w-4xl px-4 py-4\">\n          <SharedComposer handlesRef={handlesRef} />\n        </div>\n      </div>\n    </CompareHandlesProvider>\n  );\n});\n\n/** General path: any two models, sequential load → generate. */\nconst GeneralCompareContent = memo(function GeneralCompareContent({\n  pairId,\n  models,\n  loraModels,\n}: { pairId: string; models: ModelOption[]; loraModels: LoraModelOption[] }): ReactElement {\n  const handlesRef = useRef<Record<string, CompareHandle>>({});\n  const [model1ThreadId, setModel1ThreadId] = useState<string>();\n  const [model2ThreadId, setModel2ThreadId] = useState<string>();\n\n  const globalCheckpoint = useChatRuntimeStore((s) => s.params.checkpoint);\n  const globalGgufVariant = useChatRuntimeStore((s) => s.activeGgufVariant);\n  const [model1, setModel1] = useState<CompareModelSelection>({\n    id: globalCheckpoint || \"\",\n    isLora: false,\n    ggufVariant: globalGgufVariant ?? undefined,\n  });\n  const [model2, setModel2] = useState<CompareModelSelection>({ id: \"\", isLora: false });\n\n  useEffect(() => {\n    let isActive = true;\n    db.threads\n      .where(\"pairId\")\n      .equals(pairId)\n      .toArray()\n      .then((threads) => {\n        if (!isActive) return;\n        setModel1ThreadId(\n          threads.find((t) => t.modelType === \"model1\" || t.modelType === \"base\")?.id,\n        );\n        setModel2ThreadId(\n          threads.find((t) => t.modelType === \"model2\" || t.modelType === \"lora\")?.id,\n        );\n      });\n    return () => { isActive = false; };\n  }, [pairId]);\n\n  return (\n    <CompareHandlesProvider handlesRef={handlesRef}>\n      <div className=\"flex min-h-0 flex-1 flex-col\">\n        <div\n          data-tour=\"chat-compare-view\"\n          className=\"grid min-h-0 flex-1 grid-cols-1 px-0 md:grid-cols-2\"\n        >\n          <div className=\"flex min-h-0 flex-col\">\n            <div className=\"flex items-center gap-2 px-3 py-1.5\">\n              <span className=\"text-[10px] font-semibold uppercase tracking-wider text-muted-foreground\">\n                Model 1\n              </span>\n              <ModelSelector\n                models={models}\n                loraModels={loraModels}\n                value={model1.id}\n                onValueChange={(id, meta) => setModel1({ id, isLora: meta.isLora, ggufVariant: meta.ggufVariant })}\n                variant=\"ghost\"\n                size=\"sm\"\n                className=\"max-w-[50%]\"\n              />\n            </div>\n            <div className=\"min-h-0 flex-1\">\n              <ChatRuntimeProvider\n                modelType=\"model1\"\n                pairId={pairId}\n                initialThreadId={model1ThreadId}\n              >\n                <RegisterCompareHandle name=\"model1\" />\n                <Thread hideComposer={true} hideWelcome={true} />\n              </ChatRuntimeProvider>\n            </div>\n          </div>\n          <div className=\"flex min-h-0 flex-col border-t border-border/60 md:border-t-0 md:border-l\">\n            <div className=\"flex items-center gap-2 px-3 py-1.5 md:justify-end\">\n              <span 
className=\"text-[10px] font-semibold uppercase tracking-wider text-primary\">\n                Model 2\n              </span>\n              <ModelSelector\n                models={models}\n                loraModels={loraModels}\n                value={model2.id}\n                onValueChange={(id, meta) => setModel2({ id, isLora: meta.isLora, ggufVariant: meta.ggufVariant })}\n                variant=\"ghost\"\n                size=\"sm\"\n                className=\"max-w-[50%]\"\n              />\n            </div>\n            <div className=\"min-h-0 flex-1\">\n              <ChatRuntimeProvider\n                modelType=\"model2\"\n                pairId={pairId}\n                initialThreadId={model2ThreadId}\n              >\n                <RegisterCompareHandle name=\"model2\" />\n                <Thread hideComposer={true} hideWelcome={true} />\n              </ChatRuntimeProvider>\n            </div>\n          </div>\n        </div>\n        <div className=\"mx-auto w-full max-w-4xl px-4 py-4\">\n          <SharedComposer handlesRef={handlesRef} model1={model1} model2={model2} />\n        </div>\n      </div>\n    </CompareHandlesProvider>\n  );\n});\n\nfunction InlineSidebar({\n  children,\n  side = \"left\",\n}: {\n  children: ReactNode;\n  side?: \"left\" | \"right\";\n}) {\n  const { state, isMobile, openMobile, setOpenMobile } = useSidebar();\n  const collapsed = state === \"collapsed\";\n\n  if (isMobile) {\n    return (\n      <Sheet open={openMobile} onOpenChange={setOpenMobile}>\n        <SheetContent side={side} className=\"w-[18rem] p-0\">\n          <SheetHeader className=\"sr-only\">\n            <SheetTitle>Chat sidebar</SheetTitle>\n            <SheetDescription>Chat threads and actions</SheetDescription>\n          </SheetHeader>\n          <div className=\"h-full overflow-auto\">{children}</div>\n        </SheetContent>\n      </Sheet>\n    );\n  }\n\n  return (\n    <div\n      className=\"group shrink-0 h-full pb-3.5\"\n      data-state={state}\n      data-collapsible={collapsed ? \"offcanvas\" : \"\"}\n      data-side={side}\n    >\n      <aside\n        data-sidebar=\"sidebar\"\n        className={cn(\n          \"bg-muted/70 text-sidebar-foreground h-full overflow-hidden rounded-2xl corner-squircle transition-[width] duration-200 ease-linear\",\n          !collapsed &&\n          side === \"right\" && \"border-l border-sidebar-border/70\",\n          collapsed ? \"w-0\" : \"w-(--sidebar-width)\",\n        )}\n      >\n        <div className=\"flex h-full w-(--sidebar-width) flex-col\">\n          {children}\n        </div>\n      </aside>\n    </div>\n  );\n}\n\nfunction TopBarActions({\n  onNewThread,\n  onNewCompare,\n  showCompare,\n}: { onNewThread: () => void; onNewCompare: () => void; showCompare: boolean }) {\n  const { state } = useSidebar();\n  if (state !== \"collapsed\") {\n    return null;\n  }\n  return (\n    <>\n      <Tooltip>\n        <TooltipTrigger asChild={true}>\n          <Button variant=\"ghost\" size=\"icon-sm\" onClick={onNewThread}>\n            <HugeiconsIcon icon={PencilEdit02Icon} strokeWidth={2} />\n          </Button>\n        </TooltipTrigger>\n        <TooltipContent side=\"bottom\">New Chat</TooltipContent>\n      </Tooltip>\n      {showCompare ? 
(\n        <Tooltip>\n          <TooltipTrigger asChild={true}>\n            <Button variant=\"ghost\" size=\"icon-sm\" onClick={onNewCompare}>\n              <HugeiconsIcon icon={ColumnInsertIcon} strokeWidth={2} />\n            </Button>\n          </TooltipTrigger>\n          <TooltipContent side=\"bottom\">Compare</TooltipContent>\n        </Tooltip>\n      ) : null}\n    </>\n  );\n}\n\nexport function ChatPage(): ReactElement {\n  const [view, setView] = useState<ChatView>({\n    mode: \"single\",\n    newThreadNonce: crypto.randomUUID(),\n  });\n  const [settingsOpen, setSettingsOpen] = useState(false);\n  const [modelSelectorOpen, setModelSelectorOpen] = useState(false);\n  const [modelSelectorLocked, setModelSelectorLocked] = useState(false);\n  const [sidebarOpen, setSidebarOpen] = useState(true);\n  const [viewBeforeCompare, setViewBeforeCompare] = useState<ChatView | null>(\n    null,\n  );\n  const inferenceParams = useChatRuntimeStore((state) => state.params);\n  const setInferenceParams = useChatRuntimeStore((state) => state.setParams);\n  const activeGgufVariant = useChatRuntimeStore((state) => state.activeGgufVariant);\n  const autoTitle = useChatRuntimeStore((state) => state.autoTitle);\n  const setAutoTitle = useChatRuntimeStore((state) => state.setAutoTitle);\n  const modelsFromStore = useChatRuntimeStore((state) => state.models);\n  const lorasFromStore = useChatRuntimeStore((state) => state.loras);\n  const modelsError = useChatRuntimeStore((state) => state.modelsError);\n  const activeThreadId = useChatRuntimeStore((state) => state.activeThreadId);\n  const {\n    refresh,\n    selectModel,\n    ejectModel,\n    cancelLoading,\n    loadingModel,\n    loadProgress,\n    loadToastDismissed,\n  } =\n    useChatModelRuntime();\n  const refreshRef = useRef(refresh);\n  const selectModelRef = useRef(selectModel);\n\n  useEffect(() => {\n    refreshRef.current = refresh;\n    selectModelRef.current = selectModel;\n  }, [refresh, selectModel]);\n  const canCompare = useMemo(() => {\n    return Boolean(inferenceParams.checkpoint);\n  }, [inferenceParams.checkpoint]);\n\n  const handleCheckpointChange = useCallback(\n    (value: string, meta?: { isLora: boolean; ggufVariant?: string; isDownloaded?: boolean; expectedBytes?: number }) => {\n      const store = useChatRuntimeStore.getState();\n      const currentCheckpoint = store.params.checkpoint;\n      const currentVariant = store.activeGgufVariant;\n      if (!value || (value === currentCheckpoint && (meta?.ggufVariant ?? null) === (currentVariant ?? null))) return;\n      void (async () => {\n        let showImageCompatibilityWarning = false;\n        if (view.mode === \"single\" && activeThreadId) {\n          const thread = await db.threads.get(activeThreadId);\n          if (thread?.modelId && thread.modelId !== value) {\n            const messages = await db.messages\n              .where(\"threadId\")\n              .equals(activeThreadId)\n              .toArray();\n            if (messages.length > 0) {\n              const hasImage = messages.some(messageHasImage);\n              const targetModel = modelsFromStore.find((model) => model.id === value);\n              showImageCompatibilityWarning =\n                hasImage && targetModel?.isVision === false;\n            }\n          }\n        }\n\n        if (showImageCompatibilityWarning) {\n          toast.warning(\"Selected model may not handle earlier images\", {\n            description:\n              \"This chat already includes images. 
Text-only models can ignore them or fail on follow-up replies.\",\n            duration: 6000,\n          });\n        }\n        await selectModel({\n          id: value,\n          isLora: meta?.isLora,\n          ggufVariant: meta?.ggufVariant,\n          isDownloaded: meta?.isDownloaded,\n          expectedBytes: meta?.expectedBytes,\n        });\n      })();\n    },\n    [activeThreadId, modelsFromStore, selectModel, view],\n  );\n  const handleEject = useCallback(() => {\n    void ejectModel();\n  }, [ejectModel]);\n  const handleNewThread = useCallback(\n    () => {\n      useChatRuntimeStore.getState().setActiveThreadId(null);\n      setView({ mode: \"single\", newThreadNonce: crypto.randomUUID() });\n    },\n    [],\n  );\n  const handleNewCompare = useCallback(\n    () => setView({ mode: \"compare\", pairId: crypto.randomUUID() }),\n    [],\n  );\n\n  const openModelSelector = useCallback(() => {\n    setModelSelectorLocked(true);\n    setModelSelectorOpen(true);\n  }, []);\n\n  const closeModelSelector = useCallback(() => {\n    setModelSelectorLocked(false);\n    setModelSelectorOpen(false);\n  }, []);\n\n  const handleModelSelectorOpenChange = useCallback(\n    (open: boolean) => {\n      if (!open && modelSelectorLocked) return;\n      setModelSelectorOpen(open);\n    },\n    [modelSelectorLocked],\n  );\n  const openSettings = useCallback(() => setSettingsOpen(true), []);\n  const closeSettings = useCallback(() => setSettingsOpen(false), []);\n  const openSidebar = useCallback(() => setSidebarOpen(true), []);\n\n  const enterCompare = useCallback(() => {\n    setViewBeforeCompare((prev) => prev ?? view);\n    setView({ mode: \"compare\", pairId: crypto.randomUUID() });\n  }, [view]);\n\n  const exitCompare = useCallback(() => {\n    if (!viewBeforeCompare) return;\n    setView(viewBeforeCompare);\n    setViewBeforeCompare(null);\n  }, [viewBeforeCompare]);\n\n  const handleThreadSelect = useCallback(\n    (nextView: ChatView) => {\n      setView(nextView);\n    },\n    [],\n  );\n\n  const models = useMemo<ModelOption[]>(\n    () =>\n      modelsFromStore.map((model) => ({\n        id: model.id,\n        name: model.name,\n        description: model.description,\n      })),\n    [modelsFromStore],\n  );\n\n  const loraModels = useMemo<LoraModelOption[]>(\n    () =>\n      lorasFromStore.map((lora) => ({\n        id: lora.id,\n        name: lora.name,\n        baseModel: lora.baseModel,\n        updatedAt: lora.updatedAt,\n        source: lora.source,\n        exportType: lora.exportType,\n      })),\n    [lorasFromStore],\n  );\n\n  useEffect(() => {\n    if (getTrainingCompareHandoff()) return;\n    void refresh();\n  }, [refresh]);\n\n  useEffect(() => {\n    const handoff = getTrainingCompareHandoff();\n    if (!handoff) return;\n    console.info(\"[chat-handoff] received\", handoff);\n    function clearHandoff(): void {\n      clearTrainingCompareHandoff();\n    }\n\n    let canceled = false;\n    void (async () => {\n      try {\n        console.info(\"[chat-handoff] refreshing models+loras\");\n        await refreshRef.current();\n        if (canceled) return;\n\n        const state = useChatRuntimeStore.getState();\n        const targetLora = pickBestLoraForBase(state.loras, handoff.baseModel);\n        if (targetLora) {\n          console.info(\"[chat-handoff] loading lora\", {\n            id: targetLora.id,\n            baseModel: targetLora.baseModel,\n          });\n          await selectModelRef.current({ id: targetLora.id, isLora: true });\n          if 
(canceled) return;\n          setView({ mode: \"compare\", pairId: crypto.randomUUID() });\n          clearHandoff();\n          console.info(\"[chat-handoff] loaded lora + opened compare\");\n          return;\n        }\n\n        if (\n          handoff.baseModel &&\n          state.models.some((model) => model.id === handoff.baseModel)\n        ) {\n          console.info(\"[chat-handoff] no lora match, loading base\", {\n            id: handoff.baseModel,\n          });\n          await selectModelRef.current({ id: handoff.baseModel, isLora: false });\n          if (canceled) return;\n        } else {\n          console.warn(\"[chat-handoff] no lora/base match found\", {\n            requestedBaseModel: handoff.baseModel,\n            loraCount: state.loras.length,\n            modelCount: state.models.length,\n          });\n        }\n        clearHandoff();\n        console.info(\"[chat-handoff] completed\");\n      } catch (error) {\n        console.error(\"[chat-handoff] failed\", error);\n        clearHandoff();\n      }\n    })();\n\n    return () => {\n      canceled = true;\n    };\n  }, []);\n\n  const tourSteps = useMemo(\n    () =>\n      buildChatTourSteps({\n        canCompare,\n        openModelSelector,\n        closeModelSelector,\n        openSettings,\n        closeSettings,\n        openSidebar,\n        enterCompare,\n        exitCompare,\n      }),\n    [\n      canCompare,\n      closeModelSelector,\n      closeSettings,\n      enterCompare,\n      exitCompare,\n      openModelSelector,\n      openSettings,\n      openSidebar,\n    ],\n  );\n\n  const tour = useGuidedTourController({\n    id: \"chat\",\n    steps: tourSteps,\n  });\n\n  useEffect(() => {\n    if (tour.open) return;\n    if (!modelSelectorLocked) return;\n    const timeoutId = window.setTimeout(() => {\n      setModelSelectorLocked(false);\n      setModelSelectorOpen(false);\n    }, 0);\n    return () => window.clearTimeout(timeoutId);\n  }, [modelSelectorLocked, tour.open]);\n\n  return (\n    <div className=\"h-[calc(100dvh-4rem)] bg-background overflow-hidden\">\n      <GuidedTour {...tour.tourProps} />\n      <SidebarProvider\n        defaultOpen={true}\n        open={sidebarOpen}\n        onOpenChange={setSidebarOpen}\n        className=\"!min-h-0 h-full w-full max-w-7xl mx-auto px-2 sm:px-4\"\n        style={\n          {\n            \"--sidebar-width\": \"14rem\",\n            \"--sidebar-width-icon\": \"3rem\",\n          } as CSSProperties\n        }\n      >\n        <InlineSidebar>\n          <ThreadSidebar\n            view={view}\n            onSelect={handleThreadSelect}\n            onNewThread={handleNewThread}\n            onNewCompare={handleNewCompare}\n            showCompare={canCompare}\n          />\n        </InlineSidebar>\n\n        <div className=\"flex min-h-0 min-w-0 flex-1 flex-col\">\n          <div className=\"flex h-11 shrink-0 items-center px-1.5 sm:px-2\">\n            <div className=\"flex items-center gap-1\">\n              <SidebarTrigger />\n              <TopBarActions\n                onNewThread={handleNewThread}\n                onNewCompare={handleNewCompare}\n                showCompare={canCompare}\n              />\n              <ModelSelector\n                models={models}\n                loraModels={loraModels}\n                value={inferenceParams.checkpoint}\n                activeGgufVariant={activeGgufVariant}\n                onValueChange={handleCheckpointChange}\n                onEject={handleEject}\n                
variant=\"ghost\"\n                open={modelSelectorOpen}\n                onOpenChange={handleModelSelectorOpenChange}\n                triggerDataTour=\"chat-model-selector\"\n                contentDataTour=\"chat-model-selector-popover\"\n                className=\"max-w-[62vw] sm:max-w-none\"\n              />\n              {loadingModel && loadToastDismissed ? (\n                <ModelLoadInlineStatus\n                  label={\n                    loadProgress?.phase === \"starting\"\n                      ? \"Starting model…\"\n                      : loadingModel.isDownloaded\n                        ? \"Loading model…\"\n                        : \"Downloading model…\"\n                  }\n                  title={loadingModel.isDownloaded\n                    ? `Loading ${loadingModel.displayName} from cache.`\n                    : `Loading ${loadingModel.displayName}. This may include downloading.`}\n                  progressPercent={loadProgress?.percent}\n                  progressLabel={loadProgress?.label}\n                  onStop={cancelLoading}\n                />\n              ) : null}\n            </div>\n            {modelsError && (\n              <div className=\"ml-2 text-xs text-destructive truncate max-w-[28rem]\">\n                {modelsError}\n              </div>\n            )}\n            <div className=\"flex-1\" />\n            <button\n              type=\"button\"\n              onClick={() => setSettingsOpen((o) => !o)}\n              className=\"flex h-9 w-9 items-center justify-center rounded-md text-muted-foreground transition-colors hover:bg-accent hover:text-foreground\"\n              title=\"Inference settings\"\n              data-tour=\"chat-settings\"\n            >\n              <HugeiconsIcon icon={Settings04Icon} className=\"size-5\" />\n            </button>\n          </div>\n\n          {view.mode === \"single\" ? (\n            <SingleContent\n              key={view.threadId ?? view.newThreadNonce ?? \"new\"}\n              threadId={view.threadId}\n              newThreadNonce={view.newThreadNonce}\n            />\n          ) : (\n            <CompareContent key={view.pairId} pairId={view.pairId} models={models} loraModels={loraModels} />\n          )}\n        </div>\n\n        <ChatSettingsPanel\n          open={settingsOpen}\n          onOpenChange={setSettingsOpen}\n          params={inferenceParams}\n          onParamsChange={setInferenceParams}\n          autoTitle={autoTitle}\n          onAutoTitleChange={setAutoTitle}\n          onReloadModel={() => {\n            const state = useChatRuntimeStore.getState();\n            if (state.params.checkpoint) {\n              selectModel({\n                id: state.params.checkpoint,\n                ggufVariant: state.activeGgufVariant ?? undefined,\n                forceReload: true,\n                isDownloaded: true,\n                loadingDescription: \"Reloading with updated chat template.\",\n              });\n            }\n          }}\n        />\n      </SidebarProvider>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/chat-settings-sheet.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Slider } from \"@/components/ui/slider\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport {\n  ArrowDown01Icon,\n  CodeIcon,\n  Delete02Icon,\n  FloppyDiskIcon,\n  PencilEdit01Icon,\n  Settings02Icon,\n  SlidersHorizontalIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { AnimatePresence, motion } from \"motion/react\";\nimport {\n  Sheet,\n  SheetContent,\n  SheetDescription,\n  SheetHeader,\n  SheetTitle,\n} from \"@/components/ui/sheet\";\nimport { useIsMobile } from \"@/hooks/use-mobile\";\nimport type { ReactNode } from \"react\";\nimport { useState } from \"react\";\nimport {\n  DEFAULT_INFERENCE_PARAMS,\n  type InferenceParams,\n} from \"./types/runtime\";\nimport { useChatRuntimeStore } from \"./stores/chat-runtime-store\";\nimport { Switch } from \"@/components/ui/switch\";\n\nexport const defaultInferenceParams = DEFAULT_INFERENCE_PARAMS;\nexport type { InferenceParams } from \"./types/runtime\";\n\nexport interface Preset {\n  name: string;\n  params: InferenceParams;\n}\n\nconst BUILTIN_PRESETS: Preset[] = [\n  { name: \"Default\", params: { ...defaultInferenceParams } },\n  {\n    name: \"Creative\",\n    params: {\n      ...defaultInferenceParams,\n      temperature: 1.5,\n      topP: 1.0,\n      topK: 0,\n      minP: 0.1,\n      repetitionPenalty: 1.0,\n    },\n  },\n  {\n    name: \"Precise\",\n    params: {\n      ...defaultInferenceParams,\n      temperature: 0.1,\n      topP: 0.95,\n      topK: 80,\n      minP: 0.01,\n      repetitionPenalty: 1.0,\n    },\n  },\n];\n\nfunction ParamSlider({\n  label,\n  value,\n  min,\n  max,\n  step,\n  onChange,\n  displayValue,\n}: {\n  label: string;\n  value: number;\n  min: number;\n  max: number;\n  step: number;\n  onChange: (v: number) => void;\n  displayValue?: string;\n}) {\n  return (\n    <div className=\"space-y-2\">\n      <div className=\"flex items-center justify-between\">\n        <span className=\"text-xs font-medium\">{label}</span>\n        <span className=\"text-xs tabular-nums text-muted-foreground\">\n          {displayValue ?? value}\n        </span>\n      </div>\n      <Slider\n        min={min}\n        max={max}\n        step={step}\n        value={[value]}\n        onValueChange={([v]) => onChange(v)}\n      />\n    </div>\n  );\n}\n\nfunction CollapsibleSection({\n  icon,\n  label,\n  children,\n  defaultOpen = false,\n}: {\n  icon: Parameters<typeof HugeiconsIcon>[0][\"icon\"];\n  label: string;\n  children?: ReactNode;\n  defaultOpen?: boolean;\n}) {\n  const [open, setOpen] = useState(defaultOpen);\n\n  return (\n    <div>\n      <button\n        type=\"button\"\n        onClick={() => setOpen(!open)}\n        className=\"flex w-full items-center corner-squircle gap-2.5 rounded-md px-2 py-2 text-sm transition-colors hover:bg-accent\"\n      >\n        <HugeiconsIcon icon={icon} className=\"size-4 text-muted-foreground\" />\n        <span className=\"flex-1 text-left font-medium\">{label}</span>\n        <motion.div\n          animate={{ rotate: open ? 
180 : 0 }}\n          transition={{ duration: 0.15 }}\n        >\n          <HugeiconsIcon\n            icon={ArrowDown01Icon}\n            className=\"size-3.5 text-muted-foreground\"\n          />\n        </motion.div>\n      </button>\n      <AnimatePresence initial={false}>\n        {open && (\n          <motion.div\n            initial={{ height: 0, opacity: 0 }}\n            animate={{ height: \"auto\", opacity: 1 }}\n            exit={{ height: 0, opacity: 0 }}\n            transition={{ duration: 0.2, ease: \"easeInOut\" }}\n            className=\"overflow-hidden\"\n          >\n            <div className=\"px-2 pb-3 pt-1\">{children}</div>\n          </motion.div>\n        )}\n      </AnimatePresence>\n    </div>\n  );\n}\n\ninterface ChatSettingsPanelProps {\n  open: boolean;\n  onOpenChange?: (open: boolean) => void;\n  params: InferenceParams;\n  onParamsChange: (params: InferenceParams) => void;\n  autoTitle: boolean;\n  onAutoTitleChange: (enabled: boolean) => void;\n  onReloadModel?: () => void;\n}\n\nexport function ChatSettingsPanel({\n  open,\n  onOpenChange,\n  params,\n  onParamsChange,\n  autoTitle,\n  onAutoTitleChange,\n  onReloadModel,\n}: ChatSettingsPanelProps) {\n  const isMobile = useIsMobile();\n  const isGguf = useChatRuntimeStore((s) => s.activeGgufVariant) != null;\n  const ggufContextLength = useChatRuntimeStore((s) => s.ggufContextLength);\n  const kvCacheDtype = useChatRuntimeStore((s) => s.kvCacheDtype);\n  const setKvCacheDtype = useChatRuntimeStore((s) => s.setKvCacheDtype);\n  const [presets, setPresets] = useState<Preset[]>(BUILTIN_PRESETS);\n  const [activePreset, setActivePreset] = useState(\"Default\");\n  const isBuiltinPreset = BUILTIN_PRESETS.some((p) => p.name === activePreset);\n\n  function set<K extends keyof InferenceParams>(key: K) {\n    return (v: InferenceParams[K]) => onParamsChange({ ...params, [key]: v });\n  }\n\n  function applyPreset(name: string) {\n    const p = presets.find((pr) => pr.name === name);\n    if (p) {\n      onParamsChange({\n        ...p.params,\n        systemPrompt: params.systemPrompt,\n        checkpoint: params.checkpoint,\n        trustRemoteCode: params.trustRemoteCode,\n      });\n      setActivePreset(name);\n    }\n  }\n\n  function savePreset() {\n    const name = prompt(\"Preset name:\");\n    if (!name?.trim()) {\n      return;\n    }\n    const trimmed = name.trim();\n    setPresets((prev) => [\n      ...prev.filter((p) => p.name !== trimmed),\n      { name: trimmed, params: { ...params } },\n    ]);\n    setActivePreset(trimmed);\n  }\n\n  function deletePreset(name: string) {\n    if (BUILTIN_PRESETS.some((p) => p.name === name)) {\n      return;\n    }\n    setPresets((prev) => prev.filter((p) => p.name !== name));\n    if (activePreset === name) {\n      setActivePreset(\"Default\");\n    }\n  }\n\n  const settingsContent = (\n    <>\n      <div className=\"flex items-center gap-2 px-4 py-3\">\n        <HugeiconsIcon\n          icon={PencilEdit01Icon}\n          className=\"size-4 text-muted-foreground/70\"\n        />\n        <span className=\"flex-1 text-base font-semibold tracking-tight\">\n          Configuration\n        </span>\n      </div>\n\n      <div className=\"flex-1 overflow-y-auto px-1.5\">\n        {/* mt-4 matches the Playground sidebar gap (SidebarHeader py-3 + SidebarGroup pt-1) */}\n        <div className=\"mt-4 px-2 pb-3\">\n            <div className=\"flex items-center gap-2\">\n              <Select value={activePreset} onValueChange={applyPreset}>\n                
<SelectTrigger className=\"h-8 flex-1 corner-squircle text-xs\">\n                  <SelectValue />\n                </SelectTrigger>\n                <SelectContent>\n                  {presets.map((p) => (\n                    <SelectItem key={p.name} value={p.name}>\n                      {p.name}\n                    </SelectItem>\n                  ))}\n                </SelectContent>\n              </Select>\n              <button\n                type=\"button\"\n                onClick={savePreset}\n                className=\"flex h-8 items-center gap-1.5 rounded-md border px-2.5 text-xs text-muted-foreground transition-colors hover:bg-accent\"\n                title=\"Save preset\"\n              >\n                <HugeiconsIcon icon={FloppyDiskIcon} className=\"size-3.5\" />\n                Save\n              </button>\n              <button\n                type=\"button\"\n                onClick={() => deletePreset(activePreset)}\n                disabled={isBuiltinPreset}\n                className=\"flex h-8 items-center gap-1.5 rounded-md border px-2.5 text-xs text-muted-foreground transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50\"\n                title={\n                  isBuiltinPreset\n                    ? \"Built-in presets cannot be deleted\"\n                    : \"Delete selected preset\"\n                }\n              >\n                <HugeiconsIcon icon={Delete02Icon} className=\"size-3.5\" />\n                Delete\n              </button>\n            </div>\n          </div>\n\n          <div className=\"px-2 pb-4\">\n            <label\n              htmlFor=\"system-prompt\"\n              className=\"mb-1.5 block text-xs font-medium\"\n            >\n              System Prompt\n            </label>\n            <Textarea\n              id=\"system-prompt\"\n              value={params.systemPrompt}\n              onChange={(e) => set(\"systemPrompt\")(e.target.value)}\n              placeholder=\"You are a helpful assistant...\"\n              className=\"min-h-20 text-xs corner-squircle\"\n              rows={3}\n            />\n          </div>\n\n          <CollapsibleSection\n            icon={SlidersHorizontalIcon}\n            label=\"Sampling\"\n            defaultOpen={true}\n          >\n            <div className=\"flex flex-col gap-5\">\n              <ParamSlider\n                label=\"Temperature\"\n                value={params.temperature}\n                min={0}\n                max={2}\n                step={0.1}\n                onChange={set(\"temperature\")}\n              />\n              <ParamSlider\n                label=\"Top P\"\n                value={params.topP}\n                min={0}\n                max={1}\n                step={0.05}\n                onChange={set(\"topP\")}\n                displayValue={params.topP === 1 ? \"Off\" : undefined}\n              />\n              <ParamSlider\n                label=\"Top K\"\n                value={params.topK}\n                min={0}\n                max={100}\n                step={1}\n                onChange={set(\"topK\")}\n                displayValue={params.topK === 0 ? 
\"Off\" : undefined}\n              />\n              <ParamSlider\n                label=\"Min P\"\n                value={params.minP}\n                min={0}\n                max={1}\n                step={0.01}\n                onChange={set(\"minP\")}\n              />\n              <ParamSlider\n                label=\"Repetition Penalty\"\n                value={params.repetitionPenalty}\n                min={1}\n                max={2}\n                step={0.05}\n                onChange={set(\"repetitionPenalty\")}\n                displayValue={params.repetitionPenalty === 1 ? \"Off\" : undefined}\n              />\n              <ParamSlider\n                label=\"Presence Penalty\"\n                value={params.presencePenalty}\n                min={0}\n                max={2}\n                step={0.1}\n                onChange={set(\"presencePenalty\")}\n                displayValue={params.presencePenalty === 0 ? \"Off\" : undefined}\n              />\n              {!isGguf && (\n                <ParamSlider\n                  label=\"Max Seq Length\"\n                  value={params.maxSeqLength}\n                  min={128}\n                  max={32768}\n                  step={128}\n                  onChange={set(\"maxSeqLength\")}\n                />\n              )}\n              <ParamSlider\n                label=\"Max Tokens\"\n                value={params.maxTokens}\n                min={64}\n                max={isGguf && ggufContextLength ? ggufContextLength : 32768}\n                step={64}\n                onChange={set(\"maxTokens\")}\n                displayValue={\n                  isGguf && ggufContextLength && params.maxTokens >= ggufContextLength\n                    ? \"Max\"\n                    : undefined\n                }\n              />\n            </div>\n          </CollapsibleSection>\n\n          <CollapsibleSection icon={Settings02Icon} label=\"Settings\" defaultOpen={true}>\n            <div className=\"flex flex-col gap-3 py-1\">\n              <div className=\"flex items-center justify-between gap-3\">\n                <div className=\"min-w-0\">\n                  <div className=\"text-xs font-medium\">Auto title</div>\n                  <div className=\"text-[11px] text-muted-foreground\">\n                    Generate short title after reply.\n                  </div>\n                </div>\n                <Switch\n                  checked={autoTitle}\n                  onCheckedChange={onAutoTitleChange}\n                />\n              </div>\n              <div className=\"flex items-center justify-between gap-3\">\n                <div className=\"min-w-0\">\n                  <div className=\"text-xs font-medium\">Trust remote code</div>\n                  <div className=\"text-[11px] text-muted-foreground\">\n                    Allow models with custom code (e.g. Nemotron). Only enable for repos you trust.\n                  </div>\n                </div>\n                <Switch\n                  checked={params.trustRemoteCode ?? false}\n                  onCheckedChange={set(\"trustRemoteCode\")}\n                />\n              </div>\n              {isGguf && (\n                <div className=\"flex items-center justify-between gap-3\">\n                  <div className=\"min-w-0\">\n                    <div className=\"text-xs font-medium\">KV Cache Dtype</div>\n                    <div className=\"text-[11px] text-muted-foreground\">\n                      Quantize KV cache to reduce VRAM. 
Reload to apply.\n                    </div>\n                  </div>\n                  <Select\n                    value={kvCacheDtype ?? \"f16\"}\n                    onValueChange={(v) => {\n                      setKvCacheDtype(v === \"f16\" ? null : v);\n                      onReloadModel?.();\n                    }}\n                  >\n                    <SelectTrigger className=\"h-7 w-[90px] text-xs\">\n                      <SelectValue />\n                    </SelectTrigger>\n                    <SelectContent>\n                      <SelectItem value=\"f16\">f16</SelectItem>\n                      <SelectItem value=\"bf16\">bf16</SelectItem>\n                      <SelectItem value=\"q8_0\">q8_0</SelectItem>\n                      <SelectItem value=\"q5_1\">q5_1</SelectItem>\n                      <SelectItem value=\"q4_1\">q4_1</SelectItem>\n                    </SelectContent>\n                  </Select>\n                </div>\n              )}\n              <AutoHealToolCallsToggle />\n              <MaxToolCallsSlider />\n              <ToolCallTimeoutSlider />\n            </div>\n          </CollapsibleSection>\n\n          <ChatTemplateSection onReloadModel={onReloadModel} />\n        </div>\n      </>\n  );\n\n  if (isMobile) {\n    return (\n      <Sheet open={open} onOpenChange={onOpenChange}>\n        <SheetContent side=\"right\" className=\"w-[18rem] p-0\">\n          <SheetHeader className=\"sr-only\">\n            <SheetTitle>Configuration</SheetTitle>\n            <SheetDescription>Chat inference settings</SheetDescription>\n          </SheetHeader>\n          <div className=\"flex h-full flex-col\">{settingsContent}</div>\n        </SheetContent>\n      </Sheet>\n    );\n  }\n\n  return (\n    <aside\n      className={`shrink-0 self-start h-[calc(100%-0.875rem)] overflow-hidden bg-muted/70 rounded-2xl corner-squircle transition-[width] duration-200 ease-linear ${open ? \"w-[17rem] border-l border-sidebar-border/70\" : \"w-0\"}`}\n    >\n      <div className=\"flex h-full w-[17rem] flex-col\">{settingsContent}</div>\n    </aside>\n  );\n}\n\nfunction MaxToolCallsSlider() {\n  const maxToolCalls = useChatRuntimeStore((s) => s.maxToolCallsPerMessage);\n  const setMaxToolCalls = useChatRuntimeStore((s) => s.setMaxToolCallsPerMessage);\n\n  // Slider range 0-41; 41 maps to 9999 (\"Max\")\n  const sliderValue = maxToolCalls >= 9999 ? 41 : Math.min(maxToolCalls, 40);\n\n  return (\n    <ParamSlider\n      label=\"Max Tool Calls Per Message\"\n      value={sliderValue}\n      min={0}\n      max={41}\n      step={1}\n      onChange={(v) => setMaxToolCalls(v >= 41 ? 9999 : v)}\n      displayValue={sliderValue >= 41 ? \"Max\" : sliderValue === 0 ? \"Off\" : undefined}\n    />\n  );\n}\n\nfunction ToolCallTimeoutSlider() {\n  const timeout = useChatRuntimeStore((s) => s.toolCallTimeout);\n  const setTimeout_ = useChatRuntimeStore((s) => s.setToolCallTimeout);\n\n  // Slider 1-31; 31 maps to 9999 (\"Max\")\n  const sliderValue = timeout >= 9999 ? 31 : Math.min(Math.max(timeout, 1), 30);\n\n  const displayValue =\n    sliderValue >= 31\n      ? \"Max\"\n      : sliderValue === 1\n        ? \"1 minute\"\n        : `${sliderValue} minutes`;\n\n  return (\n    <ParamSlider\n      label=\"Max Tool Call Duration\"\n      value={sliderValue}\n      min={1}\n      max={31}\n      step={1}\n      onChange={(v) => setTimeout_(v >= 31 ? 
9999 : v)}\n      displayValue={displayValue}\n    />\n  );\n}\n\nfunction AutoHealToolCallsToggle() {\n  const autoHealToolCalls = useChatRuntimeStore((s) => s.autoHealToolCalls);\n  const setAutoHealToolCalls = useChatRuntimeStore((s) => s.setAutoHealToolCalls);\n\n  return (\n    <div className=\"flex items-center justify-between gap-3\">\n      <div className=\"min-w-0\">\n        <div className=\"text-xs font-medium\">Auto Heal Tool Calls 🦥</div>\n        <div className=\"text-[11px] text-muted-foreground\">\n          Fix malformed tool calls from the model automatically.\n        </div>\n      </div>\n      <Switch\n        checked={autoHealToolCalls}\n        onCheckedChange={setAutoHealToolCalls}\n      />\n    </div>\n  );\n}\n\nfunction ChatTemplateSection({\n  onReloadModel,\n}: {\n  onReloadModel?: () => void;\n}) {\n  const defaultTemplate = useChatRuntimeStore((s) => s.defaultChatTemplate);\n  const override = useChatRuntimeStore((s) => s.chatTemplateOverride);\n  const setOverride = useChatRuntimeStore((s) => s.setChatTemplateOverride);\n\n  if (!defaultTemplate) return null;\n\n  const displayValue = override ?? defaultTemplate;\n  const isModified = override !== null;\n\n  return (\n    <CollapsibleSection icon={CodeIcon} label=\"Chat Template\">\n      <div className=\"flex flex-col gap-2 py-1\">\n        <Textarea\n          value={displayValue}\n          onChange={(e) => setOverride(e.target.value)}\n          className=\"min-h-32 font-mono text-[10px] leading-relaxed corner-squircle\"\n          rows={6}\n          spellCheck={false}\n        />\n        <div className=\"flex flex-wrap gap-1.5\">\n          {isModified && (\n            <>\n              <button\n                type=\"button\"\n                onClick={() => {\n                  onReloadModel?.();\n                }}\n                className=\"rounded-md bg-primary px-2.5 py-1 text-[11px] font-medium text-primary-foreground transition-colors hover:bg-primary/90\"\n              >\n                Apply & Reload\n              </button>\n              <button\n                type=\"button\"\n                onClick={() => setOverride(null)}\n                className=\"rounded-md border px-2.5 py-1 text-[11px] font-medium text-muted-foreground transition-colors hover:bg-accent\"\n              >\n                Revert changes\n              </button>\n            </>\n          )}\n        </div>\n      </div>\n    </CollapsibleSection>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/components/model-load-status.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Progress } from \"@/components/ui/progress\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { Button } from \"@/components/ui/button\";\n\ntype ModelLoadDescriptionProps = {\n  title?: string | null;\n  message?: string | null;\n  progressPercent?: number | null;\n  progressLabel?: string | null;\n  onStop?: () => void;\n};\n\nfunction clampProgress(value: number): number {\n  return Math.max(0, Math.min(100, value));\n}\n\nexport function ModelLoadDescription({\n  title,\n  message,\n  progressPercent,\n  progressLabel,\n  onStop,\n}: ModelLoadDescriptionProps) {\n  const hasProgress = typeof progressPercent === \"number\";\n\n  return (\n    <div className=\"relative flex min-h-12 w-full items-stretch gap-2\">\n      <div className=\"flex h-full shrink-0 items-center self-center\">\n        <Spinner className=\"size-4 text-foreground\" />\n      </div>\n      <div className=\"min-w-0 flex-1 pr-5\">\n        {title ? <p className=\"text-foreground leading-5 font-semibold\">{title}</p> : null}\n        {hasProgress ? (\n          <div className=\"w-full pt-1\">\n            <div className=\"flex items-center justify-between text-[10px] font-medium tracking-[0.08em] text-muted-foreground/80\">\n              <span>{progressLabel}</span>\n              <span>{Math.round(clampProgress(progressPercent))}%</span>\n            </div>\n            <Progress value={clampProgress(progressPercent)} className=\"h-1 bg-foreground/[0.08]\" />\n          </div>\n        ) : message ? (\n          <p className=\"pt-1 text-xs leading-relaxed text-muted-foreground\">{message}</p>\n        ) : null}\n      </div>\n      {onStop ? (\n        <Button\n          type=\"button\"\n          size=\"xs\"\n          variant=\"ghost\"\n          aria-label=\"Stop model loading\"\n          className=\"h-auto self-stretch shrink-0 !rounded-none !border-0 bg-transparent px-1 text-[10px] text-muted-foreground hover:bg-transparent hover:text-destructive focus-visible:text-destructive\"\n          onClick={onStop}\n        >\n          Cancel\n        </Button>\n      ) : null}\n    </div>\n  );\n}\n\ntype ModelLoadInlineStatusProps = {\n  label: string;\n  title: string;\n  progressPercent?: number | null;\n  progressLabel?: string | null;\n  onStop?: () => void;\n};\n\nexport function ModelLoadInlineStatus({\n  label,\n  title,\n  progressPercent,\n  progressLabel,\n  onStop,\n}: ModelLoadInlineStatusProps) {\n  const hasProgress = typeof progressPercent === \"number\";\n\n  return (\n    <div className=\"flex min-w-[20rem] items-center gap-2.5 text-muted-foreground\" title={title}>\n      <div className=\"flex items-center gap-1.5 shrink-0\">\n        <Spinner className=\"size-3.5 shrink-0\" />\n        <span className=\"text-xs\">{label}</span>\n      </div>\n      {hasProgress ? (\n        <div className=\"flex min-w-0 flex-[1.35] items-center gap-2.5\">\n          <div className=\"min-w-[7rem] flex-1\">\n            <Progress value={clampProgress(progressPercent)} className=\"h-1 bg-foreground/[0.08]\" />\n          </div>\n          <div className=\"flex shrink-0 items-center gap-1 text-[10px] font-medium tracking-[0.08em] text-muted-foreground/80\">\n            <span>{progressLabel}</span>\n            <span>{Math.round(clampProgress(progressPercent))}%</span>\n          </div>\n        </div>\n      ) : null}\n      {onStop ? 
(\n        <Button\n          type=\"button\"\n          size=\"xs\"\n          variant=\"outline\"\n          className=\"shrink-0 text-[11px]\"\n          onClick={onStop}\n        >\n          Stop\n        </Button>\n      ) : null}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/db.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport Dexie, { type EntityTable, liveQuery } from \"dexie\";\nimport { useEffect, useState } from \"react\";\nimport type { MessageRecord, ThreadRecord } from \"./types\";\n\nconst db = new Dexie(\"unsloth-chat\") as Dexie & {\n  threads: EntityTable<ThreadRecord, \"id\">;\n  messages: EntityTable<MessageRecord, \"id\">;\n};\n\ndb.version(1).stores({\n  threads: \"id, modelType, pairId, archived, createdAt\",\n  messages: \"id, threadId, createdAt\",\n});\n\ndb.version(2)\n  .stores({\n    threads: \"id, modelType, pairId, archived, createdAt\",\n    messages: \"id, threadId, createdAt\",\n  })\n  .upgrade((tx) => tx.table(\"messages\").clear());\n\ndb.version(3)\n  .stores({\n    threads: \"id, modelType, pairId, archived, createdAt\",\n    messages: \"id, threadId, createdAt\",\n  })\n  .upgrade((tx) =>\n    tx\n      .table(\"threads\")\n      .toCollection()\n      .modify((thread) => {\n        if (!thread.modelId) thread.modelId = \"\";\n      }),\n  );\n\nexport { db };\n\nexport function useLiveQuery<T>(\n  querier: () => Promise<T>,\n  deps: unknown[] = [],\n): T | undefined {\n  const [value, setValue] = useState<T>();\n  useEffect(() => {\n    const sub = liveQuery(querier).subscribe({\n      next: setValue,\n      error: (err) => console.error(\"useLiveQuery:\", err),\n    });\n    return () => sub.unsubscribe();\n    // eslint-disable-next-line react-hooks/exhaustive-deps\n  }, [querier, ...deps]);\n  return value;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createElement, useCallback, useRef, useState } from \"react\";\nimport { toast } from \"sonner\";\nimport { ModelLoadDescription } from \"../components/model-load-status\";\nimport {\n  getDownloadProgress,\n  getGgufDownloadProgress,\n  getInferenceStatus,\n  listLoras,\n  listModels,\n  loadModel,\n  unloadModel,\n  validateModel,\n} from \"../api/chat-api\";\nimport { useChatRuntimeStore } from \"../stores/chat-runtime-store\";\nimport type { LoadModelResponse } from \"../types/api\";\nimport type {\n  ChatLoraSummary,\n  ChatModelSummary,\n  InferenceParams,\n} from \"../types/runtime\";\n\ntype SelectedModelInput = {\n  id: string;\n  isLora?: boolean;\n  ggufVariant?: string;\n  loadingDescription?: string;\n  isDownloaded?: boolean;\n  expectedBytes?: number;\n  forceReload?: boolean;\n};\n\nconst MODEL_LOAD_TOAST_CLASSNAMES = {\n  toast: \"items-start gap-2.5\",\n  content: \"gap-0.5 flex-1 min-w-0\",\n  title: \"leading-5\",\n  description: \"mt-0 w-full\",\n} as const;\n\nconst LORA_SUFFIX_RE = /_(\\d{9,})$/;\n\nfunction parseTrailingEpoch(input: string): number | undefined {\n  const match = input.match(LORA_SUFFIX_RE);\n  if (!match) {\n    return undefined;\n  }\n  const parsed = Number.parseInt(match[1], 10);\n  return Number.isFinite(parsed) ? parsed : undefined;\n}\n\nfunction stripTrailingEpoch(input: string): string {\n  const cleaned = input.replace(LORA_SUFFIX_RE, \"\").replace(/[_-]+$/, \"\").trim();\n  return cleaned || input;\n}\n\nfunction describeModel(model: {\n  is_lora?: boolean;\n  is_vision?: boolean;\n  is_gguf?: boolean;\n  is_audio?: boolean;\n  has_audio_input?: boolean;\n}): string | undefined {\n  const tags: string[] = [];\n  if (model.is_gguf) tags.push(\"GGUF\");\n  if (model.is_lora) tags.push(\"LoRA\");\n  if (model.is_vision) tags.push(\"Vision\");\n  if (model.is_audio) tags.push(\"Audio\");\n  if (model.has_audio_input) tags.push(\"Audio Input\");\n  if (!model.is_lora && !model.is_vision && !model.is_gguf && !model.is_audio && !model.has_audio_input)\n    tags.push(\"Base\");\n  return tags.join(\" · \");\n}\n\nfunction toChatModelSummary(model: {\n  id: string;\n  name?: string | null;\n  is_lora?: boolean;\n  is_vision?: boolean;\n  is_gguf?: boolean;\n  is_audio?: boolean;\n  audio_type?: string | null;\n  has_audio_input?: boolean;\n}): ChatModelSummary {\n  return {\n    id: model.id,\n    name: model.name || model.id,\n    description: describeModel(model),\n    isLora: Boolean(model.is_lora),\n    isVision: Boolean(model.is_vision),\n    isGguf: Boolean(model.is_gguf),\n    isAudio: Boolean(model.is_audio),\n    audioType: model.audio_type ?? null,\n    hasAudioInput: Boolean(model.has_audio_input),\n  };\n}\n\nfunction toLoraSummary(lora: {\n  display_name: string;\n  adapter_path: string;\n  base_model?: string | null;\n  source?: \"training\" | \"exported\" | null;\n  export_type?: \"lora\" | \"merged\" | \"gguf\" | null;\n}): ChatLoraSummary {\n  const idTail = lora.adapter_path.split(\"/\").filter(Boolean).at(-1) ?? \"\";\n  const updatedAt =\n    parseTrailingEpoch(lora.display_name) ?? parseTrailingEpoch(idTail);\n\n  return {\n    id: lora.adapter_path,\n    name: stripTrailingEpoch(lora.display_name),\n    baseModel: lora.base_model || \"Unknown base model\",\n    updatedAt,\n    source: lora.source ?? undefined,\n    exportType: lora.export_type ?? 
undefined,\n  };\n}\n\nfunction toFiniteNumber(value: unknown): number | undefined {\n  if (typeof value !== \"number\" || !Number.isFinite(value)) {\n    return undefined;\n  }\n  return value;\n}\n\nfunction mergeRecommendedInference(\n  current: InferenceParams,\n  response: LoadModelResponse,\n  modelId: string,\n): InferenceParams {\n  const inference = response.inference;\n  // GGUF: use actual context length from GGUF metadata, fallback to 131072\n  // Non-GGUF: 4096\n  const defaultMaxTokens = response.is_gguf\n    ? (response.context_length ?? 131072)\n    : 4096;\n  return {\n    ...current,\n    checkpoint: modelId,\n    maxTokens: defaultMaxTokens,\n    temperature:\n      toFiniteNumber(inference?.temperature) ?? current.temperature,\n    topP: toFiniteNumber(inference?.top_p) ?? current.topP,\n    topK: toFiniteNumber(inference?.top_k) ?? current.topK,\n    minP: toFiniteNumber(inference?.min_p) ?? current.minP,\n    presencePenalty:\n      toFiniteNumber(inference?.presence_penalty) ?? current.presencePenalty,\n    trustRemoteCode:\n      typeof inference?.trust_remote_code === \"boolean\"\n        ? inference.trust_remote_code\n        : current.trustRemoteCode,\n  };\n}\n\nexport function useChatModelRuntime() {\n  const params = useChatRuntimeStore((state) => state.params);\n  const models = useChatRuntimeStore((state) => state.models);\n  const loras = useChatRuntimeStore((state) => state.loras);\n  const setModels = useChatRuntimeStore((state) => state.setModels);\n  const setLoras = useChatRuntimeStore((state) => state.setLoras);\n  const setParams = useChatRuntimeStore((state) => state.setParams);\n  const setModelsError = useChatRuntimeStore((state) => state.setModelsError);\n  const setCheckpoint = useChatRuntimeStore((state) => state.setCheckpoint);\n  const clearCheckpoint = useChatRuntimeStore((state) => state.clearCheckpoint);\n\n  const [loadingModel, setLoadingModel] = useState<{\n    id: string;\n    displayName: string;\n    isDownloaded?: boolean;\n  } | null>(null);\n  const [loadToastDismissed, setLoadToastDismissed] = useState(false);\n  const [loadProgress, setLoadProgress] = useState<{\n    percent: number | null;\n    label: string | null;\n    phase: \"downloading\" | \"starting\";\n  } | null>(null);\n  const loadAbortRef = useRef<AbortController | null>(null);\n  const loadingModelRef = useRef<typeof loadingModel>(null);\n  const loadToastIdRef = useRef<string | number | null>(null);\n  const loadToastDismissedRef = useRef(false);\n\n  const setLoadToastDismissedState = useCallback((dismissed: boolean) => {\n    loadToastDismissedRef.current = dismissed;\n    setLoadToastDismissed(dismissed);\n  }, []);\n\n  const resetLoadingUi = useCallback(() => {\n    setLoadingModel(null);\n    setLoadProgress(null);\n    loadingModelRef.current = null;\n    loadAbortRef.current = null;\n    loadToastIdRef.current = null;\n    setLoadToastDismissedState(false);\n    useChatRuntimeStore.getState().setModelLoading(false);\n  }, [setLoadToastDismissedState]);\n\n  const renderLoadDescription = useCallback(\n    (\n      title: string,\n      message: string,\n      progressPercent?: number | null,\n      progressLabel?: string | null,\n      onStop?: () => void,\n    ) =>\n      createElement(ModelLoadDescription, {\n        title,\n        message,\n        progressPercent,\n        progressLabel,\n        onStop,\n      }),\n    [],\n  );\n\n  const refresh = useCallback(async () => {\n    setModelsError(null);\n    try {\n      const [listRes, statusRes, 
lorasRes] = await Promise.all([\n        listModels(),\n        getInferenceStatus(),\n        listLoras(),\n      ]);\n\n      setModels(listRes.models.map(toChatModelSummary));\n      setLoras(lorasRes.loras.map(toLoraSummary));\n\n      if (statusRes.active_model) {\n        setCheckpoint(statusRes.active_model, statusRes.gguf_variant);\n\n        // Apply inference defaults on reconnect (page refresh with model already loaded)\n        if (statusRes.inference) {\n          const currentParams = useChatRuntimeStore.getState().params;\n          setParams(\n            mergeRecommendedInference(currentParams, statusRes as any, statusRes.active_model),\n          );\n        }\n\n        // Restore reasoning/tools support flags and context length\n        const supportsReasoning = statusRes.supports_reasoning ?? false;\n        const supportsTools = statusRes.supports_tools ?? false;\n        useChatRuntimeStore.setState({\n          supportsReasoning,\n          supportsTools,\n          ggufContextLength: statusRes.is_gguf ? (statusRes.context_length ?? null) : null,\n        });\n\n        // Set reasoning default for Qwen3.5 small models\n        if (supportsReasoning) {\n          let reasoningDefault = true;\n          const mid = statusRes.active_model.toLowerCase();\n          if (mid.includes(\"qwen3.5\")) {\n            const sizeMatch = mid.match(/(\\d+\\.?\\d*)\\s*b/);\n            if (sizeMatch && parseFloat(sizeMatch[1]) < 9) {\n              reasoningDefault = false;\n            }\n          }\n          useChatRuntimeStore.getState().setReasoningEnabled(reasoningDefault);\n        }\n      }\n    } catch (error) {\n      const message =\n        error instanceof Error ? error.message : \"Failed to load models\";\n      setModelsError(message);\n      toast.error(\"Failed to refresh models\", {\n        description: message,\n      });\n    }\n  }, [setCheckpoint, setLoras, setModels, setModelsError, setParams]);\n\n  const cancelLoading = useCallback(() => {\n    const model = loadingModelRef.current;\n    if (!model) return;\n    loadAbortRef.current?.abort();\n    loadAbortRef.current = null;\n    loadingModelRef.current = null;\n    const tid = loadToastIdRef.current;\n    loadToastIdRef.current = null;\n    setLoadingModel(null);\n    setLoadProgress(null);\n    setLoadToastDismissedState(false);\n    clearCheckpoint();\n    if (tid != null) toast.dismiss(tid);\n    toast.info(\"Stopped loading model\", {\n      description: \"The current download may still finish in the background.\",\n    });\n    // Fire-and-forget: tell backend to stop, don't block UI\n    unloadModel({ model_path: model.id }).catch(() => {});\n  }, [clearCheckpoint, setLoadToastDismissedState]);\n\n  const selectModel = useCallback(\n    async (selection: string | SelectedModelInput) => {\n      const modelId = typeof selection === \"string\" ? selection : selection.id;\n      const ggufVariant =\n        typeof selection === \"string\" ? undefined : selection.ggufVariant;\n      const forceReload =\n        typeof selection === \"string\" ? false : selection.forceReload ?? false;\n      const currentVariant = useChatRuntimeStore.getState().activeGgufVariant;\n      if (!forceReload && (!modelId || (params.checkpoint === modelId && (ggufVariant ?? null) === (currentVariant ?? 
null)))) {\n        return;\n      }\n      // Prevent duplicate loads if already loading this model\n      if (loadingModelRef.current?.id === modelId) return;\n\n      const explicitIsLora =\n        typeof selection === \"string\" ? undefined : selection.isLora;\n      const extraLoadingDescription =\n        typeof selection === \"string\" ? undefined : selection.loadingDescription;\n      const isDownloaded =\n        typeof selection === \"string\" ? false : selection.isDownloaded ?? false;\n      const model = models.find((entry) => entry.id === modelId);\n      const lora = loras.find((entry) => entry.id === modelId);\n      const isLora =\n        explicitIsLora ?? model?.isLora ?? (lora ? true : false);\n      const displayName = model?.name || lora?.name || modelId;\n      const currentCheckpoint =\n        useChatRuntimeStore.getState().params.checkpoint;\n      const previousCheckpoint = currentCheckpoint;\n      const previousVariant =\n        useChatRuntimeStore.getState().activeGgufVariant ?? null;\n      const previousModel = previousCheckpoint\n        ? models.find((entry) => entry.id === previousCheckpoint)\n        : undefined;\n      const previousLora = previousCheckpoint\n        ? loras.find((entry) => entry.id === previousCheckpoint)\n        : undefined;\n      const previousIsLora =\n        previousModel?.isLora ?? (previousLora ? true : false);\n      const loadingDescription = [\n        currentCheckpoint ? \"Switching models.\" : null,\n        extraLoadingDescription ?? null,\n        isDownloaded ? \"Loading cached model into memory.\" : null,\n      ]\n        .filter(Boolean)\n        .join(\" \");\n      setModelsError(null);\n      setLoadToastDismissedState(false);\n      const loadInfo = { id: modelId, displayName, isDownloaded };\n      setLoadingModel(loadInfo);\n      useChatRuntimeStore.getState().setModelLoading(true);\n      setLoadProgress(\n        isDownloaded\n          ? { percent: null, label: null, phase: \"starting\" }\n          : { percent: 0, label: \"Preparing download\", phase: \"downloading\" },\n      );\n      loadingModelRef.current = loadInfo;\n      const abortCtrl = new AbortController();\n      loadAbortRef.current = abortCtrl;\n      try {\n        async function performLoad(): Promise<void> {\n          if (abortCtrl.signal.aborted) throw new Error(\"Cancelled\");\n          let previousWasUnloaded = false;\n          const currentCheckpoint =\n            useChatRuntimeStore.getState().params.checkpoint;\n          const paramsBeforeLoad = useChatRuntimeStore.getState().params;\n          const maxSeqLength = paramsBeforeLoad.maxSeqLength;\n          try {\n            // Lightweight pre-flight validation: avoid unloading a working model\n            // if the new identifier is clearly invalid (e.g. bad HF id / path).\n            await validateModel({\n              model_path: modelId,\n              hf_token: null,\n              max_seq_length: maxSeqLength,\n              load_in_4bit: true,\n              is_lora: isLora,\n              gguf_variant: ggufVariant ?? 
null,\n            });\n\n            if (currentCheckpoint) {\n              await unloadModel({ model_path: currentCheckpoint });\n              previousWasUnloaded = true;\n            }\n\n            const { chatTemplateOverride, kvCacheDtype } = useChatRuntimeStore.getState();\n            const loadResponse = await loadModel({\n              model_path: modelId,\n              hf_token: null,\n              max_seq_length: maxSeqLength,\n              load_in_4bit: true,\n              is_lora: isLora,\n              gguf_variant: ggufVariant ?? null,\n              trust_remote_code: paramsBeforeLoad.trustRemoteCode ?? false,\n              chat_template_override: chatTemplateOverride,\n              cache_type_kv: kvCacheDtype,\n            });\n\n            // If cancelled while loading, don't update UI to show\n            // the model as active -- it's being unloaded.\n            if (abortCtrl.signal.aborted) throw new Error(\"Cancelled\");\n\n            const currentParams = useChatRuntimeStore.getState().params;\n            setParams(\n              mergeRecommendedInference(currentParams, loadResponse, modelId),\n            );\n            // Qwen3.5 models under 9B disable thinking by default\n            let reasoningDefault = loadResponse.supports_reasoning ?? false;\n            if (reasoningDefault) {\n              const mid = modelId.toLowerCase();\n              if (mid.includes(\"qwen3.5\")) {\n                const sizeMatch = mid.match(/(\\d+\\.?\\d*)\\s*b/);\n                if (sizeMatch && parseFloat(sizeMatch[1]) < 9) {\n                  reasoningDefault = false;\n                }\n              }\n            }\n            useChatRuntimeStore.setState({\n              ggufContextLength: loadResponse.is_gguf\n                ? (loadResponse.context_length ?? 131072)\n                : null,\n              supportsReasoning: loadResponse.supports_reasoning ?? false,\n              reasoningEnabled: reasoningDefault,\n              supportsTools: loadResponse.supports_tools ?? false,\n              toolsEnabled: false,\n              kvCacheDtype: loadResponse.cache_type_kv ?? null,\n              defaultChatTemplate: loadResponse.chat_template ?? null,\n              chatTemplateOverride: null,\n            });\n            // Qwen3/3.5: apply thinking-mode-specific params after load\n            if (modelId.toLowerCase().includes(\"qwen3\") && (loadResponse.supports_reasoning ?? false)) {\n              const store = useChatRuntimeStore.getState();\n              const p = reasoningDefault\n                ? 
{ temperature: 0.6, topP: 0.95, topK: 20, minP: 0.0 }\n                : { temperature: 0.7, topP: 0.8, topK: 20, minP: 0.0 };\n              store.setParams({ ...store.params, ...p });\n            }\n            await refresh();\n          } catch (error) {\n            // Skip rollback if user cancelled -- model is already being unloaded.\n            if (abortCtrl.signal.aborted) throw error;\n            // If we unloaded a previous model and the new load failed, attempt a rollback.\n            if (previousWasUnloaded && previousCheckpoint) {\n              try {\n                await loadModel({\n                  model_path: previousCheckpoint,\n                  hf_token: null,\n                  max_seq_length: maxSeqLength,\n                  load_in_4bit: true,\n                  is_lora: previousIsLora,\n                  gguf_variant: previousVariant,\n                });\n                await refresh();\n              } catch {\n                // If rollback also fails, surface the original error.\n              }\n            }\n            throw error;\n          }\n        }\n\n        const toastTitle = isDownloaded ? \"Starting model…\" : \"Downloading model…\";\n        const toastId = toast(\n          null,\n          {\n            description: renderLoadDescription(\n              toastTitle,\n              loadingDescription,\n              isDownloaded ? null : 0,\n              isDownloaded ? null : \"Preparing download\",\n              cancelLoading,\n            ),\n            duration: Infinity,\n            closeButton: false,\n            classNames: MODEL_LOAD_TOAST_CLASSNAMES,\n            onDismiss: (dismissedToast) => {\n              if (loadToastIdRef.current !== dismissedToast.id) {\n                return;\n              }\n              setLoadToastDismissedState(true);\n            },\n          },\n        );\n        loadToastIdRef.current = toastId;\n\n        // Poll download progress for non-cached models (GGUF and non-GGUF)\n        let progressInterval: ReturnType<typeof setInterval> | null = null;\n        if (!isDownloaded) {\n          const expectedBytes =\n            typeof selection !== \"string\" ? selection.expectedBytes ?? 0 : 0;\n          let hasShownProgress = false;\n\n          const pollProgress = async () => {\n            if (abortCtrl.signal.aborted || !loadingModelRef.current) {\n              if (progressInterval) clearInterval(progressInterval);\n              return;\n            }\n            try {\n              const prog = ggufVariant && expectedBytes > 0\n                ? await getGgufDownloadProgress(modelId, ggufVariant, expectedBytes)\n                : await getDownloadProgress(modelId);\n\n              if (!loadingModelRef.current) return;\n\n              if (prog.progress > 0 && prog.progress < 1) {\n                hasShownProgress = true;\n                const dlGb = prog.downloaded_bytes / (1024 ** 3);\n                const totalGb = prog.expected_bytes / (1024 ** 3);\n                const pct = Math.round(prog.progress * 100);\n                const progressLabel = totalGb > 0\n                  ? 
`${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB`\n                  : `${dlGb.toFixed(1)} GB downloaded`;\n                setLoadProgress({\n                  percent: pct,\n                  label: progressLabel,\n                  phase: \"downloading\",\n                });\n                if (loadToastDismissedRef.current) return;\n                toast(\n                  null,\n                  {\n                    id: toastId,\n                    description: renderLoadDescription(\n                      \"Downloading model…\",\n                      loadingDescription,\n                      pct,\n                      progressLabel,\n                      cancelLoading,\n                    ),\n                    duration: Infinity,\n                    closeButton: false,\n                    classNames: MODEL_LOAD_TOAST_CLASSNAMES,\n                    onDismiss: (dismissedToast) => {\n                      if (loadToastIdRef.current !== dismissedToast.id) return;\n                      setLoadToastDismissedState(true);\n                    },\n                  },\n                );\n              } else if (prog.downloaded_bytes > 0 && prog.expected_bytes === 0 && prog.progress === 0) {\n                hasShownProgress = true;\n                const dlGb = prog.downloaded_bytes / (1024 ** 3);\n                setLoadProgress({\n                  percent: null,\n                  label: `${dlGb.toFixed(1)} GB downloaded`,\n                  phase: \"downloading\",\n                });\n              } else if (prog.progress >= 1 && hasShownProgress) {\n                setLoadProgress({\n                  percent: 100,\n                  label: \"Download complete\",\n                  phase: \"starting\",\n                });\n                if (loadToastDismissedRef.current) {\n                  if (progressInterval) clearInterval(progressInterval);\n                  return;\n                }\n                toast(null, {\n                  id: toastId,\n                  description: renderLoadDescription(\n                    \"Starting model…\",\n                    \"Download complete. Loading the model into memory.\",\n                    100,\n                    \"Download complete\",\n                    cancelLoading,\n                  ),\n                  duration: Infinity,\n                  closeButton: false,\n                  classNames: MODEL_LOAD_TOAST_CLASSNAMES,\n                  onDismiss: (dismissedToast) => {\n                    if (loadToastIdRef.current !== dismissedToast.id) return;\n                    setLoadToastDismissedState(true);\n                  },\n                });\n                if (progressInterval) clearInterval(progressInterval);\n              }\n            } catch {\n              // Ignore polling errors\n            }\n          };\n\n          setTimeout(pollProgress, 500);\n          progressInterval = setInterval(pollProgress, 2000);\n        }\n\n        try {\n          await performLoad();\n          if (loadToastDismissedRef.current) {\n            toast.success(`${displayName} loaded`);\n          } else {\n            toast.success(`${displayName} loaded`, {\n              id: toastId,\n              description: undefined,\n              closeButton: false,\n              duration: 2000,\n            });\n          }\n        } catch (err) {\n          if (!abortCtrl.signal.aborted) {\n            const message =\n              err instanceof Error ? 
err.message : \"Failed to load model\";\n            if (loadToastDismissedRef.current) {\n              toast.error(message);\n            } else {\n              toast.error(message, {\n                id: toastId,\n                description: undefined,\n                closeButton: false,\n                duration: 5000,\n              });\n            }\n          }\n          throw err;\n        } finally {\n          if (progressInterval) clearInterval(progressInterval);\n          resetLoadingUi();\n        }\n      } catch (error) {\n        if (abortCtrl.signal.aborted) return; // User cancelled, nothing to report\n        resetLoadingUi();\n        const message =\n          error instanceof Error ? error.message : \"Failed to load model\";\n        setModelsError(message);\n      }\n    },\n    [\n      cancelLoading,\n      loras,\n      models,\n      params.checkpoint,\n      refresh,\n      renderLoadDescription,\n      resetLoadingUi,\n      setLoadToastDismissedState,\n      setModelsError,\n      setParams,\n    ],\n  );\n\n  const ejectModel = useCallback(async () => {\n    if (!params.checkpoint) {\n      return;\n    }\n    setModelsError(null);\n    try {\n      async function performUnload(): Promise<void> {\n        await unloadModel({ model_path: params.checkpoint });\n        clearCheckpoint();\n        await refresh();\n      }\n\n      await toast.promise(performUnload(), {\n        loading: \"Unloading model\",\n        success: { message: \"Model unloaded\", duration: 1200 },\n        error: (err) =>\n          err instanceof Error ? err.message : \"Failed to unload model\",\n        description: \"Releases VRAM and resets inference state.\",\n      });\n    } catch (error) {\n      const message =\n        error instanceof Error ? error.message : \"Failed to unload model\";\n      setModelsError(message);\n    }\n  }, [clearCheckpoint, params.checkpoint, refresh, setModelsError]);\n\n  return {\n    refresh,\n    selectModel,\n    ejectModel,\n    cancelLoading,\n    loadingModel,\n    loadProgress,\n    loadToastDismissed,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { ChatPage } from \"./chat-page\";\nexport {\n  ChatSettingsPanel,\n  defaultInferenceParams,\n  type InferenceParams,\n  type Preset,\n} from \"./chat-settings-sheet\";\nexport { useChatRuntimeStore } from \"./stores/chat-runtime-store\";\nexport { useChatModelRuntime } from \"./hooks/use-chat-model-runtime\";\nexport { setTrainingCompareHandoff } from \"./lib/training-compare-handoff\";\n"
  },
  {
    "path": "studio/frontend/src/features/chat/runtime-provider.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  AssistantRuntimeProvider,\n  type AttachmentAdapter,\n  type CompleteAttachment,\n  CompositeAttachmentAdapter,\n  ExportedMessageRepository,\n  type ExportedMessageRepositoryItem,\n  type PendingAttachment,\n  RuntimeAdapterProvider,\n  Suggestions,\n  type ThreadHistoryAdapter,\n  type ThreadMessage,\n  WebSpeechDictationAdapter,\n  type unstable_RemoteThreadListAdapter,\n  useAui,\n  useAuiState,\n  useLocalRuntime,\n  unstable_useRemoteThreadListRuntime as useRemoteThreadListRuntime,\n} from \"@assistant-ui/react\";\nimport { createAssistantStream } from \"assistant-stream\";\nimport mammoth from \"mammoth\";\nimport { type ReactElement, type ReactNode, useEffect, useMemo } from \"react\";\nimport { extractText, getDocumentProxy } from \"unpdf\";\nimport { authFetch } from \"@/features/auth\";\nimport { createOpenAIStreamAdapter } from \"./api/chat-adapter\";\nimport { db } from \"./db\";\nimport { useChatRuntimeStore } from \"./stores/chat-runtime-store\";\nimport type { MessageRecord, ModelType } from \"./types\";\n\nconst DEFAULT_SUGGESTIONS = [\n  \"Draw an SVG of a cute sloth\",\n  \"Solve the integral of x²·sin(x) step by step\",\n  \"Write a Python function that finds the longest palindrome in a string\",\n  \"Format a comparison of 3 databases as a markdown table with pros and cons\",\n];\n\ntype TitleResponse = {\n  choices?: Array<{\n    message?: {\n      content?: string;\n    };\n  }>;\n};\n\nclass VisionImageAdapter implements AttachmentAdapter {\n  accept = \"image/jpeg,image/png,image/webp,image/gif\";\n\n  async add({ file }: { file: File }): Promise<PendingAttachment> {\n    const maxSize = 20 * 1024 * 1024;\n    if (file.size > maxSize) {\n      throw new Error(\"Image size exceeds 20MB limit\");\n    }\n\n    return {\n      id: crypto.randomUUID(),\n      type: \"image\",\n      name: file.name,\n      contentType: file.type,\n      file,\n      status: { type: \"requires-action\", reason: \"composer-send\" },\n    };\n  }\n\n  async send(attachment: PendingAttachment): Promise<CompleteAttachment> {\n    return {\n      id: attachment.id,\n      type: \"image\",\n      name: attachment.name,\n      contentType: attachment.contentType,\n      content: [\n        {\n          type: \"image\",\n          image: await this.fileToBase64DataURL(attachment.file),\n        },\n      ],\n      status: { type: \"complete\" },\n    };\n  }\n\n  async remove(): Promise<void> {\n    return Promise.resolve();\n  }\n\n  private async fileToBase64DataURL(file: File): Promise<string> {\n    return new Promise((resolve, reject) => {\n      const reader = new FileReader();\n      reader.onload = () => resolve(reader.result as string);\n      reader.onerror = () => reject(new Error(\"Failed to read image file\"));\n      reader.readAsDataURL(file);\n    });\n  }\n}\n\nclass PDFAttachmentAdapter implements AttachmentAdapter {\n  accept = \"application/pdf\";\n\n  add({ file }: { file: File }): Promise<PendingAttachment> {\n    return Promise.resolve({\n      id: crypto.randomUUID(),\n      type: \"document\",\n      name: file.name,\n      contentType: file.type,\n      file,\n      status: { type: \"requires-action\", reason: \"composer-send\" },\n    });\n  }\n\n  async send(attachment: PendingAttachment): Promise<CompleteAttachment> {\n    const buffer = new Uint8Array(await attachment.file.arrayBuffer());\n   
 const pdf = await getDocumentProxy(buffer);\n    const { text } = await extractText(pdf, { mergePages: true });\n    return {\n      id: attachment.id,\n      type: \"document\",\n      name: attachment.name,\n      contentType: attachment.contentType,\n      content: [{ type: \"text\", text: `[PDF: ${attachment.name}]\\n${text}` }],\n      status: { type: \"complete\" },\n    };\n  }\n\n  remove(): Promise<void> {\n    return Promise.resolve();\n  }\n}\n\nclass TextAttachmentAdapter implements AttachmentAdapter {\n  accept = \"text/plain,text/markdown,text/csv,text/xml,text/json,text/css\";\n\n  async add({ file }: { file: File }): Promise<PendingAttachment> {\n    return {\n      id: crypto.randomUUID(),\n      type: \"document\",\n      name: file.name,\n      contentType: file.type,\n      file,\n      status: { type: \"requires-action\", reason: \"composer-send\" },\n    };\n  }\n\n  async send(attachment: PendingAttachment): Promise<CompleteAttachment> {\n    const text = await attachment.file.text();\n    return {\n      id: attachment.id,\n      type: \"document\",\n      name: attachment.name,\n      contentType: attachment.contentType,\n      content: [\n        { type: \"text\", text: `<attachment name=${attachment.name}>\\n${text}\\n</attachment>` },\n      ],\n      status: { type: \"complete\" },\n    };\n  }\n\n  remove(): Promise<void> {\n    return Promise.resolve();\n  }\n}\n\nclass HtmlAttachmentAdapter implements AttachmentAdapter {\n  accept = \"text/html\";\n\n  async add({ file }: { file: File }): Promise<PendingAttachment> {\n    return {\n      id: crypto.randomUUID(),\n      type: \"document\",\n      name: file.name,\n      contentType: file.type,\n      file,\n      status: { type: \"requires-action\", reason: \"composer-send\" },\n    };\n  }\n\n  async send(attachment: PendingAttachment): Promise<CompleteAttachment> {\n    const html = await attachment.file.text();\n    // Strip HTML tags to extract readable text\n    const doc = new DOMParser().parseFromString(html, \"text/html\");\n    // Remove script and style elements\n    for (const el of doc.querySelectorAll(\"script, style\")) el.remove();\n    const text = (doc.body.textContent ?? 
\"\").replace(/\\s+/g, \" \").trim();\n    return {\n      id: attachment.id,\n      type: \"document\",\n      name: attachment.name,\n      contentType: attachment.contentType,\n      content: [\n        { type: \"text\", text: `[HTML: ${attachment.name}]\\n${text}` },\n      ],\n      status: { type: \"complete\" },\n    };\n  }\n\n  remove(): Promise<void> {\n    return Promise.resolve();\n  }\n}\n\nclass DocxAttachmentAdapter implements AttachmentAdapter {\n  accept =\n    \"application/vnd.openxmlformats-officedocument.wordprocessingml.document\";\n\n  add({ file }: { file: File }): Promise<PendingAttachment> {\n    return Promise.resolve({\n      id: crypto.randomUUID(),\n      type: \"document\",\n      name: file.name,\n      contentType: file.type,\n      file,\n      status: { type: \"requires-action\", reason: \"composer-send\" },\n    });\n  }\n\n  async send(attachment: PendingAttachment): Promise<CompleteAttachment> {\n    const arrayBuffer = await attachment.file.arrayBuffer();\n    const { value } = await mammoth.extractRawText({ arrayBuffer });\n    return {\n      id: attachment.id,\n      type: \"document\",\n      name: attachment.name,\n      contentType: attachment.contentType,\n      content: [{ type: \"text\", text: `[DOCX: ${attachment.name}]\\n${value}` }],\n      status: { type: \"complete\" },\n    };\n  }\n\n  remove(): Promise<void> {\n    return Promise.resolve();\n  }\n}\n\nfunction clip(input: string, maxLen: number): string {\n  const text = input.replace(/\\s+/g, \" \").trim();\n  if (text.length <= maxLen) return text;\n  return text.slice(0, maxLen).trimEnd();\n}\n\nfunction extractTextParts(m: ThreadMessage | undefined): string {\n  if (!m) return \"\";\n  const content = Array.isArray(m.content) ? m.content : [];\n  return content\n    .filter((p): p is Extract<typeof p, { type: \"text\" }> => p.type === \"text\")\n    .map((p) => p.text)\n    .join(\"\")\n    .trim();\n}\n\nasync function generateTitleWithModel(payload: {\n  userText: string;\n}): Promise<string | null> {\n  const params = useChatRuntimeStore.getState().params;\n  if (!params.checkpoint) return null;\n\n  const user = clip(payload.userText, 256);\n  const parts: string[] = [user];\n\n  function normalizeTitle(raw: string): string | null {\n    let title = raw.split(/\\r?\\n/, 1)[0] ?? \"\";\n    title = title.replace(/^\\s*title\\s*:\\s*/i, \"\");\n    title = title.replace(/[^\\x20-\\x7E]+/g, \" \");\n    title = title.replace(/[\"'`]+/g, \"\");\n    title = title.replace(/[.!?:;,]+/g, \" \");\n    title = title.replace(/\\s+/g, \" \").trim();\n\n    // Model echo fail-safe.\n    if (/\\b(user|base|lora|assistant)\\s*:/i.test(title)) {\n      return null;\n    }\n\n    const words = title.split(\" \").filter(Boolean).slice(0, 6);\n    const joined = words.join(\" \").trim();\n    if (!joined) return null;\n    return joined.length > 60 ? joined.slice(0, 60).trimEnd() : joined;\n  }\n\n  const response = await authFetch(\"/v1/chat/completions\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({\n      model: params.checkpoint,\n      stream: false,\n      temperature: 0.2,\n      top_p: 0.9,\n      max_tokens: 24,\n      top_k: 20,\n      repetition_penalty: 1.0,\n      messages: [\n        {\n          role: \"system\",\n          content:\n            \"Write 1 concise chat title for the user's message. Rules: 2-6 words, no quotes, no punctuation, ASCII only, do not echo input. 
Output title only.\",\n        },\n        { role: \"user\", content: parts.join(\"\\n\") },\n      ],\n    }),\n  });\n\n  const body = (await response.json().catch(() => null)) as TitleResponse | null;\n  if (!response.ok) return null;\n  const raw: string | undefined = body?.choices?.[0]?.message?.content;\n  if (!raw) return null;\n  return normalizeTitle(raw);\n}\n\nconst inflightTitleByKey = new Set<string>();\n\nfunction fallbackTitleFromUserText(userText: string): string {\n  const firstLine = (userText || \"\").split(/\\r?\\n/, 1)[0] ?? \"\";\n  const cleaned = firstLine.replace(/\\s+/g, \" \").trim();\n  const max = 48;\n  if (!cleaned) return \"New Chat\";\n  return cleaned.slice(0, max) + (cleaned.length > max ? \"...\" : \"\");\n}\n\nfunction cloneContent(content: ThreadMessage[\"content\"]): ThreadMessage[\"content\"] {\n  return Array.isArray(content)\n    ? JSON.parse(JSON.stringify(content))\n    : [];\n}\n\nfunction cloneAttachments(\n  attachments: readonly CompleteAttachment[] | undefined,\n): readonly CompleteAttachment[] {\n  if (!Array.isArray(attachments)) {\n    return [];\n  }\n  return JSON.parse(JSON.stringify(attachments));\n}\n\nfunction toThreadMessage(m: MessageRecord): ThreadMessage {\n  const content =\n    Array.isArray(m.content) && m.content.length > 0\n      ? cloneContent(m.content)\n      : [{ type: \"text\" as const, text: \"\" }];\n\n  if (m.role === \"user\") {\n    return {\n      id: m.id,\n      createdAt: new Date(m.createdAt),\n      role: \"user\" as const,\n      content: content as Extract<ThreadMessage, { role: \"user\" }>[\"content\"],\n      attachments: cloneAttachments(m.attachments),\n      metadata: { custom: {} },\n    };\n  }\n  return {\n    id: m.id,\n    createdAt: new Date(m.createdAt),\n    role: \"assistant\" as const,\n    content: content as Extract<ThreadMessage, { role: \"assistant\" }>[\"content\"],\n    status: { type: \"complete\" as const, reason: \"unknown\" as const },\n    metadata: {\n      custom: (m.metadata as Record<string, unknown>) ?? {},\n      steps: [],\n      unstable_annotations: [],\n      unstable_data: [],\n      unstable_state: null,\n    },\n  };\n}\n\nfunction createDexieAdapter(\n  modelType: ModelType,\n  pairId?: string,\n): unstable_RemoteThreadListAdapter {\n  return {\n    async fetch(remoteId: string) {\n      const thread = await db.threads.get(remoteId);\n      if (!thread) {\n        throw new Error(`Thread ${remoteId} not found`);\n      }\n      return {\n        remoteId: thread.id,\n        status: thread.archived ? \"archived\" : \"regular\",\n        title: thread.title,\n      };\n    },\n\n    async list() {\n      const threads = await db.threads\n        .where(\"modelType\")\n        .equals(modelType)\n        .reverse()\n        .sortBy(\"createdAt\");\n      return {\n        threads: threads.map((t) => ({\n          status: (t.archived ? \"archived\" : \"regular\") as\n            | \"archived\"\n            | \"regular\",\n          remoteId: t.id,\n          title: t.title,\n        })),\n      };\n    },\n\n    async initialize(threadId: string) {\n      const currentModelId =\n        useChatRuntimeStore.getState().params.checkpoint ?? 
\"\";\n      await db.threads.add({\n        id: threadId,\n        title: \"New Chat\",\n        modelType,\n        modelId: currentModelId,\n        pairId,\n        archived: false,\n        createdAt: Date.now(),\n      });\n      return { remoteId: threadId, externalId: undefined };\n    },\n\n    async rename(remoteId: string, newTitle: string) {\n      await db.threads.update(remoteId, { title: newTitle });\n    },\n\n    async archive(remoteId: string) {\n      await db.threads.update(remoteId, { archived: true });\n    },\n\n    async unarchive(remoteId: string) {\n      await db.threads.update(remoteId, { archived: false });\n    },\n\n    async delete(remoteId: string) {\n      await db.messages.where(\"threadId\").equals(remoteId).delete();\n      await db.threads.delete(remoteId);\n    },\n\n    async generateTitle(remoteId: string, messages: readonly ThreadMessage[]) {\n      const autoTitle = useChatRuntimeStore.getState().autoTitle;\n      const thread = await db.threads.get(remoteId);\n      const defaultTitle = \"New Chat\";\n\n      function streamTitle(title: string) {\n        return createAssistantStream((c) => {\n          c.appendText(title);\n          c.close();\n        });\n      }\n\n      async function persistTitle(title: string): Promise<void> {\n        await db.threads.update(remoteId, { title });\n        if (!pairId) return;\n        const paired = await db.threads\n          .where(\"pairId\")\n          .equals(pairId)\n          .filter((t) => t.id !== remoteId)\n          .first();\n        if (paired) await db.threads.update(paired.id, { title });\n      }\n\n      if (!thread) {\n        return streamTitle(defaultTitle);\n      }\n\n      // Only generate once per thread/pair.\n      if (thread.title && thread.title !== \"New Chat\") {\n        return streamTitle(thread.title);\n      }\n\n      const firstUser = messages.find((m) => m.role === \"user\");\n      const userText = extractTextParts(firstUser) || defaultTitle;\n\n      if (!autoTitle) {\n        const title = fallbackTitleFromUserText(userText);\n        await persistTitle(title);\n        return streamTitle(title);\n      }\n\n      const key = pairId ? 
`pair:${pairId}` : `thread:${remoteId}`;\n      if (inflightTitleByKey.has(key)) {\n        return streamTitle(thread.title || defaultTitle);\n      }\n\n      // Compare: wait until both threads done.\n      if (pairId) {\n        const paired = await db.threads\n          .where(\"pairId\")\n          .equals(pairId)\n          .filter((t) => t.id !== remoteId)\n          .first();\n\n        if (paired) {\n          const running = useChatRuntimeStore.getState().runningByThreadId;\n          if (running[paired.id]) {\n            setTimeout(() => {\n              void createDexieAdapter(modelType, pairId).generateTitle(remoteId, messages);\n            }, 600);\n            return streamTitle(thread.title || defaultTitle);\n          }\n        }\n      }\n\n      inflightTitleByKey.add(key);\n      try {\n        const title =\n          (await generateTitleWithModel({\n            userText,\n          })) ||\n          fallbackTitleFromUserText(userText);\n\n        await persistTitle(title);\n        return streamTitle(title);\n      } finally {\n        inflightTitleByKey.delete(key);\n      }\n    },\n  };\n}\n\nfunction ThreadHistoryProvider({\n  children,\n}: { children?: ReactNode }): ReactElement {\n  const aui = useAui();\n\n  const history = useMemo<ThreadHistoryAdapter>(\n    () => ({\n      async load() {\n        const { remoteId } = aui.threadListItem().getState();\n        if (!remoteId) {\n          return { messages: [] };\n        }\n        const roleOrder: Record<string, number> = {\n          system: 0,\n          user: 1,\n          assistant: 2,\n        };\n        const msgs = await db.messages.where(\"threadId\").equals(remoteId).toArray();\n        msgs.sort((a, b) => {\n          if (a.createdAt !== b.createdAt) return a.createdAt - b.createdAt;\n          const aOrder = roleOrder[a.role] ?? 99;\n          const bOrder = roleOrder[b.role] ?? 99;\n          if (aOrder !== bOrder) return aOrder - bOrder;\n          return a.id < b.id ? -1 : a.id > b.id ? 1 : 0;\n        });\n\n        return ExportedMessageRepository.fromArray(msgs.map(toThreadMessage));\n      },\n\n      async append({ message }: ExportedMessageRepositoryItem) {\n        const { remoteId } = await aui.threadListItem().initialize();\n        const content = cloneContent(message.content);\n        const attachments =\n          message.role === \"user\" ? cloneAttachments(message.attachments) : [];\n        const custom = message.metadata?.custom;\n        const existing = await db.messages.get(message.id);\n        const createdAt =\n          existing?.createdAt ??\n          message.createdAt?.getTime?.() ??\n          Date.now();\n        await db.messages.put({\n          id: message.id,\n          threadId: remoteId,\n          role: message.role,\n          content,\n          ...(attachments.length > 0 && { attachments }),\n          ...(custom && Object.keys(custom).length > 0 && { metadata: custom }),\n          createdAt,\n        });\n      },\n    }),\n    [aui],\n  );\n\n  const dictation = useMemo(\n    () =>\n      WebSpeechDictationAdapter.isSupported()\n        ? 
new WebSpeechDictationAdapter()\n        : undefined,\n    [],\n  );\n  const attachments = useMemo(\n    () =>\n      new CompositeAttachmentAdapter([\n        new VisionImageAdapter(),\n        new TextAttachmentAdapter(),\n        new HtmlAttachmentAdapter(),\n        new PDFAttachmentAdapter(),\n        new DocxAttachmentAdapter(),\n      ]),\n    [],\n  );\n  const adapters = useMemo(\n    () => ({ history, dictation, attachments }),\n    [history, dictation, attachments],\n  );\n\n  return (\n    <RuntimeAdapterProvider adapters={adapters}>\n      {children}\n    </RuntimeAdapterProvider>\n  );\n}\n\nconst chatAdapter = createOpenAIStreamAdapter();\n\nfunction useRuntimeHook(): ReturnType<typeof useLocalRuntime> {\n  return useLocalRuntime(chatAdapter);\n}\n\nfunction ThreadAutoSwitch({\n  threadId,\n}: { threadId: string }): ReactElement | null {\n  const aui = useAui();\n  const isLoading = useAuiState(({ threads }) => threads.isLoading);\n  const mainThreadId = useAuiState(({ threads }) => threads.mainThreadId);\n\n  useEffect(() => {\n    if (!isLoading && mainThreadId !== threadId) {\n      aui.threads().switchToThread(threadId);\n    }\n  }, [aui, isLoading, mainThreadId, threadId]);\n\n  return null;\n}\n\nfunction ThreadNewChatSwitch({\n  nonce,\n}: { nonce: string }): ReactElement | null {\n  const aui = useAui();\n  const isLoading = useAuiState(({ threads }) => threads.isLoading);\n\n  useEffect(() => {\n    if (!isLoading) {\n      aui.threads().switchToNewThread();\n    }\n  }, [aui, isLoading, nonce]);\n\n  return null;\n}\n\nfunction ActiveThreadSync({\n  enabled,\n}: { enabled: boolean }): ReactElement | null {\n  const mainThreadId = useAuiState(({ threads }) => threads.mainThreadId);\n  const setActiveThreadId = useChatRuntimeStore((state) => state.setActiveThreadId);\n\n  useEffect(() => {\n    if (!enabled) {\n      return;\n    }\n    setActiveThreadId(mainThreadId ?? null);\n  }, [enabled, mainThreadId, setActiveThreadId]);\n\n  return null;\n}\n\nexport function ChatRuntimeProvider({\n  children,\n  modelType = \"base\",\n  pairId,\n  initialThreadId,\n  newThreadNonce,\n}: {\n  children: ReactNode;\n  modelType?: ModelType;\n  pairId?: string;\n  initialThreadId?: string;\n  newThreadNonce?: string;\n}): ReactElement {\n  const runtime = useRemoteThreadListRuntime({\n    runtimeHook: useRuntimeHook,\n    adapter: {\n      ...createDexieAdapter(modelType, pairId),\n      unstable_Provider: ThreadHistoryProvider,\n    },\n  });\n\n  const aui = useAui({\n    suggestions: Suggestions(DEFAULT_SUGGESTIONS),\n  });\n\n  return (\n    <AssistantRuntimeProvider runtime={runtime} aui={aui}>\n      <ActiveThreadSync enabled={modelType === \"base\" && !pairId} />\n      {initialThreadId && <ThreadAutoSwitch threadId={initialThreadId} />}\n      {!initialThreadId && newThreadNonce && (\n        <ThreadNewChatSwitch nonce={newThreadNonce} />\n      )}\n      {children}\n    </AssistantRuntimeProvider>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/shared-composer.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { TooltipIconButton } from \"@/components/assistant-ui/tooltip-icon-button\";\nimport { Button } from \"@/components/ui/button\";\nimport { AUDIO_ACCEPT, MAX_AUDIO_SIZE, fileToBase64 } from \"@/lib/audio-utils\";\nimport { useAui } from \"@assistant-ui/react\";\nimport { cn } from \"@/lib/utils\";\nimport { ArrowUpIcon, GlobeIcon, HeadphonesIcon, LightbulbIcon, LightbulbOffIcon, MicIcon, PlusIcon, SquareIcon, TerminalIcon, XIcon } from \"lucide-react\";\nimport { toast } from \"sonner\";\nimport { loadModel } from \"./api/chat-api\";\nimport { useChatRuntimeStore } from \"./stores/chat-runtime-store\";\nimport {\n  type KeyboardEvent,\n  type MutableRefObject,\n  type ReactElement,\n  type ReactNode,\n  createContext,\n  useCallback,\n  useContext,\n  useEffect,\n  useRef,\n  useState,\n} from \"react\";\n\nexport type CompareMessagePart =\n  | { type: \"text\"; text: string }\n  | { type: \"image\"; image: string }\n  | { type: \"audio\"; audio: string };\n\nexport interface CompareHandle {\n  append: (content: CompareMessagePart[]) => void;\n  /** Append a user message without triggering generation. */\n  appendMessage: (content: CompareMessagePart[]) => void;\n  /** Trigger generation on the current thread (after appendMessage). */\n  startRun: () => void;\n  cancel: () => void;\n  isRunning: () => boolean;\n  /** Returns a promise that resolves when the current or next run finishes. */\n  waitForRunEnd: () => Promise<void>;\n}\n\nconst IMAGE_ACCEPT = \"image/jpeg,image/png,image/webp,image/gif\";\nconst MAX_IMAGE_SIZE = 20 * 1024 * 1024;\n\nfunction fileToBase64DataURL(file: File): Promise<string> {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = () => resolve(reader.result as string);\n    reader.onerror = () => reject(new Error(\"Failed to read image file\"));\n    reader.readAsDataURL(file);\n  });\n}\n\nfunction useDictation(\n  setText: (value: string | ((prev: string) => string)) => void,\n) {\n  const [isDictating, setIsDictating] = useState(false);\n  const recognitionRef = useRef<SpeechRecognition | null>(null);\n\n  const start = useCallback(() => {\n    const SpeechRecognitionAPI =\n      typeof window !== \"undefined\" &&\n      (window.SpeechRecognition ?? (window as unknown as { webkitSpeechRecognition?: typeof SpeechRecognition }).webkitSpeechRecognition);\n    if (!SpeechRecognitionAPI) {\n      return;\n    }\n    const recognition = new SpeechRecognitionAPI() as SpeechRecognition;\n    recognition.continuous = true;\n    recognition.interimResults = true;\n    recognition.lang = \"en-US\";\n    recognition.onresult = (event: SpeechRecognitionEvent) => {\n      const last = event.resultIndex;\n      const result = event.results[last];\n      if (!result?.isFinal) return;\n      const transcript = result[0]?.transcript?.trim();\n      if (transcript) {\n        setText((prev) => (prev ? 
`${prev} ${transcript}` : transcript));\n      }\n    };\n    recognition.onerror = () => {\n      setIsDictating(false);\n    };\n    recognition.onend = () => {\n      setIsDictating(false);\n    };\n    recognition.start();\n    recognitionRef.current = recognition;\n    setIsDictating(true);\n  }, [setText]);\n\n  const stop = useCallback(() => {\n    if (recognitionRef.current) {\n      recognitionRef.current.stop();\n      recognitionRef.current = null;\n    }\n    setIsDictating(false);\n  }, []);\n\n  useEffect(() => {\n    return () => {\n      if (recognitionRef.current) {\n        recognitionRef.current.abort();\n      }\n    };\n  }, []);\n\n  const supported =\n    typeof window !== \"undefined\" &&\n    !!(window.SpeechRecognition ?? (window as unknown as { webkitSpeechRecognition?: unknown }).webkitSpeechRecognition);\n\n  return { isDictating, start, stop, supported };\n}\n\nexport type CompareHandles = MutableRefObject<Record<string, CompareHandle>>;\n\nconst CompareHandlesContext = createContext<CompareHandles | null>(null);\n\nexport function CompareHandlesProvider({\n  handlesRef,\n  children,\n}: {\n  handlesRef: CompareHandles;\n  children: ReactNode;\n}): ReactElement {\n  return (\n    <CompareHandlesContext.Provider value={handlesRef}>\n      {children}\n    </CompareHandlesContext.Provider>\n  );\n}\n\nexport function RegisterCompareHandle({\n  name,\n}: {\n  name: string;\n}): ReactElement | null {\n  const handlesRef = useContext(CompareHandlesContext);\n  const aui = useAui();\n\n  useEffect(() => {\n    if (!handlesRef) {\n      return;\n    }\n    const currentHandles = handlesRef.current;\n    currentHandles[name] = {\n      // fixes occasional reorder on reload.\n      append: (content) =>\n        aui.thread().append({ role: \"user\", content, createdAt: new Date() } as never),\n      appendMessage: (content) =>\n        aui.thread().append({ role: \"user\", content, createdAt: new Date(), startRun: false } as never),\n      startRun: () => {\n        const msgs = aui.thread().getState().messages;\n        const lastId = msgs.length > 0 ? 
msgs[msgs.length - 1].id : null;\n        aui.thread().startRun({ parentId: lastId });\n      },\n      cancel: () => aui.thread().cancelRun(),\n      isRunning: () => aui.thread().getState().isRunning,\n      waitForRunEnd: () =>\n        new Promise<void>((resolve) => {\n          let wasRunning = false;\n          const unsub = useChatRuntimeStore.subscribe((state) => {\n            const anyRunning = Object.keys(state.runningByThreadId).length > 0;\n            if (anyRunning) wasRunning = true;\n            if (wasRunning && !anyRunning) {\n              unsub();\n              resolve();\n            }\n          });\n        }),\n    };\n    return () => {\n      delete currentHandles[name];\n    };\n  }, [handlesRef, name, aui]);\n\n  return null;\n}\n\ntype PendingImage = { id: string; file: File };\n\nfunction PendingImageThumb({\n  file,\n  onRemove,\n}: {\n  file: File;\n  onRemove: () => void;\n}): ReactElement {\n  const [src, setSrc] = useState<string | null>(null);\n  useEffect(() => {\n    const url = URL.createObjectURL(file);\n    setSrc(url);\n    return () => URL.revokeObjectURL(url);\n  }, [file]);\n  if (!src) return <div className=\"size-14 animate-pulse rounded-[14px] bg-muted\" />;\n  return (\n    <div className=\"relative size-14 shrink-0 overflow-hidden rounded-[14px] border border-foreground/20 bg-muted\">\n      <img src={src} alt={file.name} className=\"h-full w-full object-cover\" />\n      <button\n        type=\"button\"\n        onClick={onRemove}\n        className=\"absolute top-1 right-1 flex size-5 items-center justify-center rounded-full bg-white text-muted-foreground shadow-sm hover:bg-destructive hover:text-destructive-foreground\"\n        aria-label=\"Remove attachment\"\n      >\n        <XIcon className=\"size-3\" />\n      </button>\n    </div>\n  );\n}\n\ntype CompareModelSelection = {\n  id: string;\n  isLora: boolean;\n  ggufVariant?: string;\n};\n\nexport function SharedComposer({\n  handlesRef,\n  model1,\n  model2,\n}: {\n  handlesRef: CompareHandles;\n  model1?: CompareModelSelection;\n  model2?: CompareModelSelection;\n}): ReactElement {\n  const [text, setText] = useState(\"\");\n  const [running, setRunning] = useState(false);\n  const [comparing, setComparing] = useState(false);\n  const [pendingImages, setPendingImages] = useState<PendingImage[]>([]);\n  const [pendingAudio, setPendingAudio] = useState<{ name: string; base64: string } | null>(null);\n  const [dragging, setDragging] = useState(false);\n  const textareaRef = useRef<HTMLTextAreaElement>(null);\n  const fileInputRef = useRef<HTMLInputElement>(null);\n  const audioInputRef = useRef<HTMLInputElement>(null);\n\n  const activeModel = useChatRuntimeStore((s) => {\n    const checkpoint = s.params.checkpoint;\n    return s.models.find((m) => m.id === checkpoint);\n  });\n  const supportsReasoning = useChatRuntimeStore((s) => s.supportsReasoning);\n  const reasoningEnabled = useChatRuntimeStore((s) => s.reasoningEnabled);\n  const setReasoningEnabled = useChatRuntimeStore((s) => s.setReasoningEnabled);\n  const supportsTools = useChatRuntimeStore((s) => s.supportsTools);\n  const toolsEnabled = useChatRuntimeStore((s) => s.toolsEnabled);\n  const setToolsEnabled = useChatRuntimeStore((s) => s.setToolsEnabled);\n  const codeToolsEnabled = useChatRuntimeStore((s) => s.codeToolsEnabled);\n  const setCodeToolsEnabled = useChatRuntimeStore((s) => s.setCodeToolsEnabled);\n  const setPendingAudioStore = useChatRuntimeStore((s) => s.setPendingAudio);\n  const clearPendingAudioStore = 
useChatRuntimeStore((s) => s.clearPendingAudio);\n\n  const { isDictating, start: startDictation, stop: stopDictation, supported: dictationSupported } = useDictation(\n    setText,\n  );\n\n  useEffect(() => {\n    const id = setInterval(() => {\n      const handles = handlesRef.current;\n      const any = Object.values(handles).some((h) => h.isRunning());\n      setRunning(any);\n    }, 200);\n    return () => clearInterval(id);\n  }, [handlesRef]);\n\n  const addFiles = useCallback((files: FileList | null) => {\n    if (!files?.length) return;\n    const next: PendingImage[] = [];\n    for (let i = 0; i < files.length; i++) {\n      const file = files[i];\n      if (!file) continue;\n      // Handle audio files\n      if (file.type.match(/^audio\\//i) && file.size <= MAX_AUDIO_SIZE) {\n        fileToBase64(file).then((base64) => {\n          setPendingAudio({ name: file.name, base64 });\n          setPendingAudioStore(base64, file.name);\n        });\n        continue;\n      }\n      // Handle image files\n      if (!file.type.match(/^image\\/(jpeg|png|webp|gif)$/i)) continue;\n      if (file.size > MAX_IMAGE_SIZE) continue;\n      next.push({ id: crypto.randomUUID(), file });\n    }\n    setPendingImages((prev) => [...prev, ...next]);\n  }, [setPendingAudioStore]);\n\n  const removePendingImage = useCallback((id: string) => {\n    setPendingImages((prev) => prev.filter((p) => p.id !== id));\n  }, []);\n\n  async function send() {\n    const msg = text.trim();\n    if (!msg && pendingImages.length === 0 && !pendingAudio) return;\n\n    const content: CompareMessagePart[] = [];\n    for (const { file } of pendingImages) {\n      try {\n        const image = await fileToBase64DataURL(file);\n        content.push({ type: \"image\", image });\n      } catch {\n        // skip failed image\n      }\n    }\n    if (pendingAudio) {\n      content.push({ type: \"audio\", audio: pendingAudio.base64 });\n    }\n    if (msg) {\n      content.push({ type: \"text\", text: msg });\n    }\n    if (content.length === 0) return;\n\n    setText(\"\");\n    setPendingImages([]);\n    setPendingAudio(null);\n    clearPendingAudioStore();\n    textareaRef.current?.focus();\n\n    // Generalized compare: load each model before dispatching to its side\n    const hasCompareHandles = Boolean(handlesRef.current[\"model1\"] || handlesRef.current[\"model2\"]);\n    const isGeneralizedCompare = hasCompareHandles && Boolean(model1?.id || model2?.id);\n    if (isGeneralizedCompare) {\n      const store = useChatRuntimeStore.getState();\n      const maxSeqLength = store.params.maxSeqLength;\n      const trustRemoteCode = store.params.trustRemoteCode ?? false;\n      const chatTemplateOverride = store.chatTemplateOverride;\n\n      function modelDisplayName(id: string): string {\n        const parts = id.split(\"/\");\n        return parts[parts.length - 1] || id;\n      }\n\n      // Helper: load a model and update store checkpoint\n      async function ensureModelLoaded(sel: CompareModelSelection): Promise<string> {\n        const resp = await loadModel({\n          model_path: sel.id,\n          hf_token: null,\n          max_seq_length: maxSeqLength,\n          load_in_4bit: true,\n          is_lora: sel.isLora,\n          gguf_variant: sel.ggufVariant ?? null,\n          trust_remote_code: trustRemoteCode,\n          chat_template_override: chatTemplateOverride,\n        });\n        useChatRuntimeStore.getState().setCheckpoint(\n          resp.model,\n          resp.is_gguf ? (sel.ggufVariant ?? 
undefined) : null,\n        );\n        return resp.status;\n      }\n\n      const handle1 = handlesRef.current[\"model1\"];\n      const handle2 = handlesRef.current[\"model2\"];\n\n      // Show user messages immediately on both sides\n      if (handle1) handle1.appendMessage(content);\n      if (handle2) handle2.appendMessage(content);\n\n      const name1 = model1?.id ? modelDisplayName(model1.id) : \"\";\n      const name2 = model2?.id ? modelDisplayName(model2.id) : \"\";\n      const toastId = toast(\"Comparing models…\", { duration: Infinity });\n\n      setComparing(true);\n      try {\n        // Side 1: load → generate → wait\n        if (handle1 && model1?.id) {\n          toast(\"Loading Model 1…\", { id: toastId, description: name1, duration: Infinity });\n          const status1 = await ensureModelLoaded(model1);\n          toast(\"Generating with Model 1…\", { id: toastId, description: `${name1} (${status1})`, duration: Infinity });\n          const done = handle1.waitForRunEnd();\n          handle1.startRun();\n          await done;\n        }\n\n        // Side 2: load → generate → wait\n        if (handle2 && model2?.id) {\n          const needsLoad = model2.id.toLowerCase() !== (model1?.id || \"\").toLowerCase()\n            || (model2.ggufVariant ?? \"\") !== (model1?.ggufVariant ?? \"\");\n          if (needsLoad) {\n            toast(\"Loading Model 2…\", { id: toastId, description: name2, duration: Infinity });\n          }\n          const status2 = await ensureModelLoaded(model2);\n          toast(\"Generating with Model 2…\", { id: toastId, description: `${name2} (${status2})`, duration: Infinity });\n          const done = handle2.waitForRunEnd();\n          handle2.startRun();\n          await done;\n        }\n\n        toast.success(\"Compare complete\", { id: toastId, duration: 2000 });\n      } catch (err) {\n        toast.error(\"Compare failed\", {\n          id: toastId,\n          description: err instanceof Error ? err.message : \"Unknown error\",\n          duration: 4000,\n        });\n      } finally {\n        setComparing(false);\n      }\n    } else {\n      // Original behavior: fire all handles simultaneously\n      for (const handle of Object.values(handlesRef.current)) {\n        handle.append(content);\n      }\n    }\n  }\n\n  function stop() {\n    if (isDictating) stopDictation();\n    for (const handle of Object.values(handlesRef.current)) {\n      handle.cancel();\n    }\n  }\n\n  const busy = running || comparing;\n\n  function onKeyDown(e: KeyboardEvent) {\n    if (e.key === \"Enter\" && !e.shiftKey) {\n      e.preventDefault();\n      if (!busy) {\n        send();\n      }\n    }\n  }\n\n  const canSend = (text.trim().length > 0 || pendingImages.length > 0 || pendingAudio !== null) && !busy;\n\n  return (\n    <div\n      className={`shadow-border ring-1 ring-border relative flex w-full flex-col rounded-2xl bg-background px-1 pt-2 transition-shadow outline-none ${dragging ? 
\"ring-ring bg-accent/50\" : \"\"}`}\n      onDragOver={(e) => {\n        e.preventDefault();\n        setDragging(true);\n      }}\n      onDragLeave={() => setDragging(false)}\n      onDrop={(e) => {\n        e.preventDefault();\n        setDragging(false);\n        addFiles(e.dataTransfer.files);\n      }}\n    >\n      {(pendingImages.length > 0 || pendingAudio) && (\n        <div className=\"mb-2 flex w-full flex-row flex-wrap items-center gap-2 px-1.5 pt-0.5 pb-1\">\n          {pendingImages.map(({ id, file }) => (\n            <PendingImageThumb\n              key={id}\n              file={file}\n              onRemove={() => removePendingImage(id)}\n            />\n          ))}\n          {pendingAudio && (\n            <div className=\"flex items-center gap-2 rounded-lg border border-foreground/20 bg-muted px-3 py-1.5 text-xs\">\n              <HeadphonesIcon className=\"size-3.5 text-muted-foreground\" />\n              <span className=\"max-w-48 truncate\">{pendingAudio.name}</span>\n              <button\n                type=\"button\"\n                onClick={() => { setPendingAudio(null); clearPendingAudioStore(); }}\n                className=\"flex size-4 items-center justify-center rounded-full hover:bg-destructive hover:text-destructive-foreground\"\n                aria-label=\"Remove audio\"\n              >\n                <XIcon className=\"size-3\" />\n              </button>\n            </div>\n          )}\n        </div>\n      )}\n      <textarea\n        ref={textareaRef}\n        value={text}\n        onChange={(e) => setText(e.target.value)}\n        onKeyDown={onKeyDown}\n        placeholder=\"Send to both models...\"\n        className=\"mb-1 max-h-32 min-h-14 w-full resize-none bg-transparent px-4 pt-2 pb-3 text-sm outline-none placeholder:text-muted-foreground\"\n        rows={1}\n      />\n      <div className=\"relative mx-2 mb-2 flex items-center justify-between\">\n        <div className=\"flex items-center gap-1\">\n          <input\n            ref={fileInputRef}\n            type=\"file\"\n            accept={IMAGE_ACCEPT}\n            multiple\n            className=\"hidden\"\n            onChange={(e) => {\n              addFiles(e.target.files);\n              e.target.value = \"\";\n            }}\n          />\n          <TooltipIconButton\n            tooltip=\"Add attachment\"\n            side=\"bottom\"\n            variant=\"ghost\"\n            size=\"icon\"\n            className=\"size-8 rounded-full text-muted-foreground hover:bg-muted-foreground/15\"\n            onClick={() => fileInputRef.current?.click()}\n            aria-label=\"Add attachment\"\n          >\n            <PlusIcon className=\"size-5 stroke-[1.5px]\" />\n          </TooltipIconButton>\n          {activeModel?.hasAudioInput && (\n            <>\n              <input\n                ref={audioInputRef}\n                type=\"file\"\n                accept={AUDIO_ACCEPT}\n                className=\"hidden\"\n                onChange={(e) => {\n                  addFiles(e.target.files);\n                  e.target.value = \"\";\n                }}\n              />\n              <TooltipIconButton\n                tooltip=\"Upload audio\"\n                side=\"bottom\"\n                variant=\"ghost\"\n                size=\"icon\"\n                className=\"size-8 rounded-full text-muted-foreground hover:bg-muted-foreground/15\"\n                onClick={() => audioInputRef.current?.click()}\n                aria-label=\"Upload audio\"\n              
>\n                <HeadphonesIcon className=\"size-4 stroke-[1.5px]\" />\n              </TooltipIconButton>\n            </>\n          )}\n          {supportsReasoning && (\n            <button\n              type=\"button\"\n              onClick={() => {\n                const next = !reasoningEnabled;\n                setReasoningEnabled(next);\n                // Qwen3/3.5: adjust params for thinking on/off\n                const store = useChatRuntimeStore.getState();\n                const cp = store.params.checkpoint?.toLowerCase() ?? \"\";\n                if (cp.includes(\"qwen3\")) {\n                  const p = next\n                    ? { temperature: 0.6, topP: 0.95, topK: 20, minP: 0.0 }\n                    : { temperature: 0.7, topP: 0.8, topK: 20, minP: 0.0 };\n                  store.setParams({ ...store.params, ...p });\n                }\n              }}\n              className={cn(\n                \"flex items-center gap-0.5 rounded-full px-2 py-0.5 text-xs font-medium transition-colors\",\n                reasoningEnabled\n                  ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n                  : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n              )}\n              aria-label={reasoningEnabled ? \"Disable thinking\" : \"Enable thinking\"}\n            >\n              {reasoningEnabled ? (\n                <LightbulbIcon className=\"size-3\" />\n              ) : (\n                <LightbulbOffIcon className=\"size-3\" />\n              )}\n              <span>Think</span>\n            </button>\n          )}\n          {supportsTools && (\n            <button\n              type=\"button\"\n              onClick={() => setToolsEnabled(!toolsEnabled)}\n              className={cn(\n                \"flex items-center gap-1.5 rounded-full px-2.5 py-1 text-xs font-medium transition-colors\",\n                toolsEnabled\n                  ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n                  : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n              )}\n              aria-label={toolsEnabled ? \"Disable web search\" : \"Enable web search\"}\n            >\n              <GlobeIcon className=\"size-3.5\" />\n              <span>Search</span>\n            </button>\n          )}\n          {supportsTools && (\n            <button\n              type=\"button\"\n              onClick={() => setCodeToolsEnabled(!codeToolsEnabled)}\n              className={cn(\n                \"flex items-center gap-1.5 rounded-full px-2.5 py-1 text-xs font-medium transition-colors\",\n                codeToolsEnabled\n                  ? \"bg-primary/10 text-primary hover:bg-primary/20\"\n                  : \"bg-muted text-muted-foreground hover:bg-muted-foreground/15\",\n              )}\n              aria-label={codeToolsEnabled ? \"Disable code execution\" : \"Enable code execution\"}\n            >\n              <TerminalIcon className=\"size-3.5\" />\n              <span>Code</span>\n            </button>\n          )}\n        </div>\n        <div className=\"flex items-center gap-1\">\n          {dictationSupported && (\n            <>\n              {!isDictating ? 
(\n                <TooltipIconButton\n                  tooltip=\"Dictate\"\n                  side=\"bottom\"\n                  variant=\"ghost\"\n                  size=\"icon\"\n                  className=\"size-8 rounded-full text-muted-foreground hover:bg-muted-foreground/15\"\n                  onClick={startDictation}\n                  aria-label=\"Dictate\"\n                >\n                  <MicIcon className=\"size-4\" />\n                </TooltipIconButton>\n              ) : (\n                <TooltipIconButton\n                  tooltip=\"Stop dictation\"\n                  side=\"bottom\"\n                  variant=\"ghost\"\n                  size=\"icon\"\n                  className=\"size-8 rounded-full text-destructive\"\n                  onClick={stopDictation}\n                  aria-label=\"Stop dictation\"\n                >\n                  <SquareIcon className=\"size-3 animate-pulse fill-current\" />\n                </TooltipIconButton>\n              )}\n            </>\n          )}\n          {busy ? (\n            <Button\n              type=\"button\"\n              variant=\"default\"\n              size=\"icon\"\n              className=\"size-8 rounded-full\"\n              onClick={stop}\n            >\n              <SquareIcon className=\"size-3 fill-current\" />\n            </Button>\n          ) : (\n            <TooltipIconButton\n              tooltip=\"Send message\"\n              side=\"bottom\"\n              variant=\"default\"\n              size=\"icon\"\n              className=\"size-8 rounded-full\"\n              onClick={send}\n              disabled={!canSend}\n            >\n              <ArrowUpIcon className=\"size-4\" />\n            </TooltipIconButton>\n          )}\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/stores/chat-runtime-store.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\nimport {\n  DEFAULT_INFERENCE_PARAMS,\n  type ChatLoraSummary,\n  type ChatModelSummary,\n  type InferenceParams,\n} from \"../types/runtime\";\n\nconst AUTO_TITLE_KEY = \"unsloth_chat_auto_title\";\nconst AUTO_HEAL_TOOL_CALLS_KEY = \"unsloth_auto_heal_tool_calls\";\nconst MAX_TOOL_CALLS_KEY = \"unsloth_max_tool_calls_per_message\";\nconst TOOL_CALL_TIMEOUT_KEY = \"unsloth_tool_call_timeout\";\n\nfunction canUseStorage(): boolean {\n  return typeof window !== \"undefined\";\n}\n\nfunction loadBool(key: string, fallback: boolean): boolean {\n  if (!canUseStorage()) return fallback;\n  try {\n    const raw = localStorage.getItem(key);\n    if (raw === null) return fallback;\n    return raw === \"true\";\n  } catch {\n    return fallback;\n  }\n}\n\nfunction saveBool(key: string, value: boolean): void {\n  if (!canUseStorage()) return;\n  try {\n    localStorage.setItem(key, value ? \"true\" : \"false\");\n  } catch {\n    // ignore\n  }\n}\n\nfunction loadInt(key: string, fallback: number): number {\n  if (!canUseStorage()) return fallback;\n  try {\n    const raw = localStorage.getItem(key);\n    if (raw === null) return fallback;\n    const parsed = parseInt(raw, 10);\n    return Number.isNaN(parsed) ? fallback : parsed;\n  } catch {\n    return fallback;\n  }\n}\n\nfunction saveInt(key: string, value: number): void {\n  if (!canUseStorage()) return;\n  try {\n    localStorage.setItem(key, String(value));\n  } catch {\n    // ignore\n  }\n}\n\ntype ChatRuntimeStore = {\n  params: InferenceParams;\n  models: ChatModelSummary[];\n  loras: ChatLoraSummary[];\n  runningByThreadId: Record<string, boolean>;\n  autoTitle: boolean;\n  modelsError: string | null;\n  activeGgufVariant: string | null;\n  ggufContextLength: number | null;\n  supportsReasoning: boolean;\n  reasoningEnabled: boolean;\n  supportsTools: boolean;\n  toolsEnabled: boolean;\n  codeToolsEnabled: boolean;\n  toolStatus: string | null;\n  generatingStatus: string | null;\n  autoHealToolCalls: boolean;\n  maxToolCallsPerMessage: number;\n  toolCallTimeout: number;\n  kvCacheDtype: string | null;\n  defaultChatTemplate: string | null;\n  chatTemplateOverride: string | null;\n  activeThreadId: string | null;\n  pendingAudioBase64: string | null;\n  pendingAudioName: string | null;\n  modelLoading: boolean;\n  setModelLoading: (loading: boolean) => void;\n  setParams: (params: InferenceParams) => void;\n  setModels: (models: ChatModelSummary[]) => void;\n  setLoras: (loras: ChatLoraSummary[]) => void;\n  setThreadRunning: (threadId: string, running: boolean) => void;\n  setAutoTitle: (enabled: boolean) => void;\n  setModelsError: (error: string | null) => void;\n  setCheckpoint: (modelId: string, ggufVariant?: string | null) => void;\n  setActiveThreadId: (threadId: string | null) => void;\n  clearCheckpoint: () => void;\n  setReasoningEnabled: (enabled: boolean) => void;\n  setToolsEnabled: (enabled: boolean) => void;\n  setCodeToolsEnabled: (enabled: boolean) => void;\n  setToolStatus: (status: string | null) => void;\n  setGeneratingStatus: (status: string | null) => void;\n  setAutoHealToolCalls: (enabled: boolean) => void;\n  setMaxToolCallsPerMessage: (value: number) => void;\n  setToolCallTimeout: (value: number) => void;\n  setKvCacheDtype: (dtype: string | null) => void;\n  setChatTemplateOverride: (template: string | null) 
=> void;\n  setPendingAudio: (base64: string, name: string) => void;\n  clearPendingAudio: () => void;\n};\n\nexport const useChatRuntimeStore = create<ChatRuntimeStore>((set) => ({\n  params: DEFAULT_INFERENCE_PARAMS,\n  models: [],\n  loras: [],\n  runningByThreadId: {},\n  autoTitle: loadBool(AUTO_TITLE_KEY, false),\n  modelsError: null,\n  activeGgufVariant: null,\n  ggufContextLength: null,\n  supportsReasoning: false,\n  reasoningEnabled: true,\n  supportsTools: false,\n  toolsEnabled: false,\n  codeToolsEnabled: false,\n  toolStatus: null,\n  generatingStatus: null,\n  autoHealToolCalls: loadBool(AUTO_HEAL_TOOL_CALLS_KEY, true),\n  maxToolCallsPerMessage: loadInt(MAX_TOOL_CALLS_KEY, 10),\n  toolCallTimeout: loadInt(TOOL_CALL_TIMEOUT_KEY, 5),\n  kvCacheDtype: null,\n  defaultChatTemplate: null,\n  chatTemplateOverride: null,\n  activeThreadId: null,\n  pendingAudioBase64: null,\n  pendingAudioName: null,\n  modelLoading: false,\n  setModelLoading: (loading) => set({ modelLoading: loading }),\n  setParams: (params) => set({ params }),\n  setModels: (models) => set({ models }),\n  setLoras: (loras) => set({ loras }),\n  setThreadRunning: (threadId, running) =>\n    set((state) => {\n      const next = { ...state.runningByThreadId };\n      if (running) {\n        next[threadId] = true;\n      } else {\n        delete next[threadId];\n      }\n      return { runningByThreadId: next };\n    }),\n  setAutoTitle: (autoTitle) =>\n    set(() => {\n      saveBool(AUTO_TITLE_KEY, autoTitle);\n      return { autoTitle };\n    }),\n  setModelsError: (modelsError) => set({ modelsError }),\n  setCheckpoint: (modelId, ggufVariant) =>\n    set((state) => ({\n      params: {\n        ...state.params,\n        checkpoint: modelId,\n      },\n      activeGgufVariant: ggufVariant ?? 
null,\n    })),\n  setActiveThreadId: (activeThreadId) => set({ activeThreadId }),\n  clearCheckpoint: () =>\n    set((state) => ({\n      params: {\n        ...state.params,\n        checkpoint: \"\",\n      },\n      activeGgufVariant: null,\n      ggufContextLength: null,\n      supportsReasoning: false,\n      reasoningEnabled: true,\n      supportsTools: false,\n      toolsEnabled: false,\n      codeToolsEnabled: false,\n      toolStatus: null,\n      kvCacheDtype: null,\n      defaultChatTemplate: null,\n      chatTemplateOverride: null,\n    })),\n  setReasoningEnabled: (reasoningEnabled) => set({ reasoningEnabled }),\n  setToolsEnabled: (toolsEnabled) => set({ toolsEnabled }),\n  setCodeToolsEnabled: (codeToolsEnabled) => set({ codeToolsEnabled }),\n  setToolStatus: (toolStatus) => set({ toolStatus }),\n  setGeneratingStatus: (generatingStatus) => set({ generatingStatus }),\n  setAutoHealToolCalls: (autoHealToolCalls) =>\n    set(() => {\n      saveBool(AUTO_HEAL_TOOL_CALLS_KEY, autoHealToolCalls);\n      return { autoHealToolCalls };\n    }),\n  setMaxToolCallsPerMessage: (maxToolCallsPerMessage) =>\n    set(() => {\n      saveInt(MAX_TOOL_CALLS_KEY, maxToolCallsPerMessage);\n      return { maxToolCallsPerMessage };\n    }),\n  setToolCallTimeout: (toolCallTimeout) =>\n    set(() => {\n      saveInt(TOOL_CALL_TIMEOUT_KEY, toolCallTimeout);\n      return { toolCallTimeout };\n    }),\n  setKvCacheDtype: (kvCacheDtype) => set({ kvCacheDtype }),\n  setChatTemplateOverride: (chatTemplateOverride) => set({ chatTemplateOverride }),\n  setPendingAudio: (base64, name) =>\n    set({ pendingAudioBase64: base64, pendingAudioName: name }),\n  clearPendingAudio: () =>\n    set({ pendingAudioBase64: null, pendingAudioName: null }),\n}));\n"
  },
  {
    "path": "studio/frontend/src/features/chat/thread-sidebar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  SidebarContent,\n  SidebarGroup,\n  SidebarGroupContent,\n  SidebarGroupLabel,\n  SidebarHeader,\n  SidebarMenu,\n  SidebarMenuAction,\n  SidebarMenuButton,\n  SidebarMenuItem,\n} from \"@/components/ui/sidebar\";\nimport {\n  ColumnInsertIcon,\n  Delete02Icon,\n  PencilEdit02Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { db, useLiveQuery } from \"./db\";\nimport type { ChatView, ThreadRecord } from \"./types\";\n\ninterface SidebarItem {\n  type: \"single\" | \"compare\";\n  id: string;\n  title: string;\n  createdAt: number;\n}\n\nfunction groupThreads(threads: ThreadRecord[]): SidebarItem[] {\n  const items: SidebarItem[] = [];\n  const seenPairs = new Set<string>();\n\n  for (const t of threads) {\n    if (t.archived) {\n      continue;\n    }\n    if (t.pairId) {\n      if (seenPairs.has(t.pairId)) {\n        continue;\n      }\n      seenPairs.add(t.pairId);\n      items.push({\n        type: \"compare\",\n        id: t.pairId,\n        title: t.title,\n        createdAt: t.createdAt,\n      });\n    } else if (!t.pairId) {\n      items.push({\n        type: \"single\",\n        id: t.id,\n        title: t.title,\n        createdAt: t.createdAt,\n      });\n    }\n  }\n\n  return items.sort((a, b) => b.createdAt - a.createdAt);\n}\n\nexport function ThreadSidebar({\n  view,\n  onSelect,\n  onNewThread,\n  onNewCompare,\n  showCompare,\n}: {\n  view: ChatView;\n  onSelect: (view: ChatView) => void;\n  onNewThread: () => void;\n  onNewCompare: () => void;\n  showCompare: boolean;\n}) {\n  const allThreads = useLiveQuery(\n    () => db.threads.orderBy(\"createdAt\").reverse().toArray(),\n    [],\n  );\n  const items = groupThreads(allThreads ?? []);\n  const activeId = view.mode === \"single\" ? view.threadId : view.pairId;\n\n  function viewForItem(item: SidebarItem): ChatView {\n    return item.type === \"single\"\n      ? { mode: \"single\", threadId: item.id }\n      : { mode: \"compare\", pairId: item.id };\n  }\n\n  async function handleDelete(item: SidebarItem) {\n    if (item.type === \"single\") {\n      await db.messages.where(\"threadId\").equals(item.id).delete();\n      await db.threads.delete(item.id);\n    } else {\n      const paired = await db.threads.where(\"pairId\").equals(item.id).toArray();\n      for (const t of paired) {\n        await db.messages.where(\"threadId\").equals(t.id).delete();\n        await db.threads.delete(t.id);\n      }\n    }\n    if (activeId === item.id) {\n      onSelect({ mode: \"single\" });\n    }\n  }\n\n  return (\n    <>\n      <SidebarHeader className=\"px-4 py-3\">\n        <span className=\"text-base font-semibold tracking-tight\">Playground</span>\n      </SidebarHeader>\n      <SidebarContent>\n        <SidebarGroup className=\"px-4 pt-1\">\n          <SidebarGroupContent>\n            <SidebarMenu>\n              <SidebarMenuItem>\n                <SidebarMenuButton onClick={onNewThread}>\n                  <HugeiconsIcon icon={PencilEdit02Icon} />\n                  <span>New Chat</span>\n                </SidebarMenuButton>\n              </SidebarMenuItem>\n              {showCompare ? 
(\n                <SidebarMenuItem>\n                  <SidebarMenuButton data-tour=\"chat-compare\" onClick={onNewCompare}>\n                    <HugeiconsIcon icon={ColumnInsertIcon} />\n                    <span>Compare</span>\n                  </SidebarMenuButton>\n                </SidebarMenuItem>\n              ) : null}\n            </SidebarMenu>\n          </SidebarGroupContent>\n        </SidebarGroup>\n        <SidebarGroup className=\"flex-1 px-4\">\n          <SidebarGroupLabel className=\"text-xs font-medium text-muted-foreground/80\">Your Chats</SidebarGroupLabel>\n          <SidebarGroupContent>\n            <SidebarMenu>\n              {items.map((item) => (\n                <SidebarMenuItem key={item.id}>\n                  <SidebarMenuButton\n                    isActive={activeId === item.id}\n                    onClick={() => onSelect(viewForItem(item))}\n                  >\n                    <span>{item.title}</span>\n                  </SidebarMenuButton>\n                  <SidebarMenuAction\n                    showOnHover={true}\n                    onClick={() => handleDelete(item)}\n                    title=\"Delete\"\n                  >\n                    <HugeiconsIcon icon={Delete02Icon} />\n                  </SidebarMenuAction>\n                </SidebarMenuItem>\n              ))}\n            </SidebarMenu>\n            {items.length === 0 && (\n              <p className=\"px-2 py-6 text-center text-xs text-muted-foreground\">\n                No threads yet\n              </p>\n            )}\n          </SidebarGroupContent>\n        </SidebarGroup>\n      </SidebarContent>\n    </>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/tour/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { buildChatTourSteps } from \"./steps\";\n\n"
  },
  {
    "path": "studio/frontend/src/features/chat/tour/steps.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\n\nexport function buildChatTourSteps({\n  canCompare,\n  openModelSelector,\n  closeModelSelector,\n  openSettings,\n  closeSettings,\n  openSidebar,\n  enterCompare,\n  exitCompare,\n}: {\n  canCompare: boolean;\n  openModelSelector: () => void;\n  closeModelSelector: () => void;\n  openSettings: () => void;\n  closeSettings: () => void;\n  openSidebar: () => void;\n  enterCompare: () => void;\n  exitCompare: () => void;\n}): TourStep[] {\n  const steps: TourStep[] = [\n    {\n      id: \"model\",\n      target: \"chat-model-selector\",\n      title: \"Pick a model\",\n      body: (\n        <>\n          This selects what’s loaded for inference. Hub = base models. Fine-tuned\n          = your LoRA adapters from Studio.\n        </>\n      ),\n    },\n    {\n      id: \"model-tabs\",\n      target: \"chat-model-selector-popover\",\n      title: \"Two tabs\",\n      body: (\n        <>\n          Hub: search Hugging Face models. Fine-tuned: adapters (LoRA) you’ve\n          trained locally. If results look off, compare base vs LoRA to see what\n          changed.\n        </>\n      ),\n      onEnter: openModelSelector,\n      onExit: closeModelSelector,\n    },\n    {\n      id: \"settings\",\n      target: \"chat-settings\",\n      title: \"Settings sidebar\",\n      body: (\n        <>\n          Sampling (temperature/top-p/top-k) + system prompt live here. If you\n          want more deterministic outputs, lower temperature first.\n        </>\n      ),\n      onEnter: openSettings,\n      onExit: closeSettings,\n    },\n  ];\n\n  if (canCompare) {\n    steps.push(\n      {\n        id: \"compare-btn\",\n        target: \"chat-compare\",\n        title: \"Compare mode\",\n        body: (\n          <>\n            Compare any two models side-by-side.\n            Pick a different model for each side and see how they respond to the same prompt.\n          </>\n        ),\n        onEnter: openSidebar,\n      },\n      {\n        id: \"compare-view\",\n        target: \"chat-compare-view\",\n        title: \"Side-by-side threads\",\n        body: (\n          <>\n            Same prompt, 2 threads. If LoRA is worse than base, it’s usually\n            data formatting, too many epochs, or a bad checkpoint choice.\n          </>\n        ),\n        onEnter: enterCompare,\n        onExit: exitCompare,\n      },\n    );\n  }\n\n  return steps;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/types/api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport interface BackendModelDetails {\n  id: string;\n  name?: string | null;\n  is_vision?: boolean;\n  is_lora?: boolean;\n  is_gguf?: boolean;\n  is_audio?: boolean;\n  audio_type?: string | null;\n  has_audio_input?: boolean;\n}\n\nexport interface ListModelsResponse {\n  models: BackendModelDetails[];\n  default_models: string[];\n}\n\nexport interface BackendLoraInfo {\n  display_name: string;\n  adapter_path: string;\n  base_model?: string | null;\n  source?: \"training\" | \"exported\" | null;\n  export_type?: \"lora\" | \"merged\" | \"gguf\" | null;\n}\n\nexport interface ListLorasResponse {\n  loras: BackendLoraInfo[];\n  outputs_dir: string;\n}\n\nexport interface LoadModelRequest {\n  model_path: string;\n  hf_token: string | null;\n  max_seq_length: number;\n  load_in_4bit: boolean;\n  is_lora: boolean;\n  gguf_variant?: string | null;\n  /** Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust. */\n  trust_remote_code?: boolean;\n  chat_template_override?: string | null;\n  cache_type_kv?: string | null;\n}\n\nexport interface ValidateModelResponse {\n  valid: boolean;\n  message: string;\n  identifier?: string | null;\n  display_name?: string | null;\n  is_gguf?: boolean;\n  is_lora?: boolean;\n  is_vision?: boolean;\n}\n\nexport interface GgufVariantDetail {\n  filename: string;\n  quant: string;\n  size_bytes: number;\n  downloaded?: boolean;\n}\n\nexport interface GgufVariantsResponse {\n  repo_id: string;\n  variants: GgufVariantDetail[];\n  has_vision: boolean;\n  default_variant: string | null;\n}\n\nexport interface LoadModelResponse {\n  status: string;\n  model: string;\n  display_name: string;\n  is_vision: boolean;\n  is_lora: boolean;\n  is_gguf?: boolean;\n  is_audio?: boolean;\n  audio_type?: string | null;\n  has_audio_input?: boolean;\n  inference?: {\n    temperature?: number;\n    top_p?: number;\n    top_k?: number;\n    min_p?: number;\n    presence_penalty?: number;\n    trust_remote_code?: boolean;\n  };\n  context_length?: number | null;\n  supports_reasoning?: boolean;\n  supports_tools?: boolean;\n  cache_type_kv?: string | null;\n  chat_template?: string | null;\n}\n\nexport interface UnloadModelRequest {\n  model_path: string;\n}\n\nexport interface InferenceStatusResponse {\n  active_model: string | null;\n  is_vision: boolean;\n  is_gguf?: boolean;\n  gguf_variant?: string | null;\n  is_audio?: boolean;\n  audio_type?: string | null;\n  has_audio_input?: boolean;\n  loading: string[];\n  loaded: string[];\n  inference?: {\n    temperature?: number;\n    top_p?: number;\n    top_k?: number;\n    min_p?: number;\n    presence_penalty?: number;\n    trust_remote_code?: boolean;\n  };\n  supports_reasoning?: boolean;\n  supports_tools?: boolean;\n  context_length?: number | null;\n}\n\nexport interface AudioGenerationResponse {\n  id: string;\n  object: string;\n  model: string;\n  audio: {\n    data: string;\n    format: string;\n    sample_rate: number;\n  };\n  choices: Array<{\n    index: number;\n    message: { role: string; content: string };\n    finish_reason: string;\n  }>;\n}\n\nexport interface OpenAIChatMessage {\n  role: \"system\" | \"user\" | \"assistant\";\n  content: string;\n}\n\nexport interface OpenAIChatCompletionsRequest {\n  model: string;\n  messages: OpenAIChatMessage[];\n  stream: boolean;\n  temperature: number;\n  top_p: 
number;\n  max_tokens: number;\n  top_k: number;\n  min_p: number;\n  repetition_penalty: number;\n  presence_penalty: number;\n  image_base64?: string;\n  audio_base64?: string;\n  use_adapter?: boolean | string | null;\n  enable_thinking?: boolean | null;\n  enable_tools?: boolean | null;\n  enabled_tools?: string[];\n  auto_heal_tool_calls?: boolean;\n  max_tool_calls_per_message?: number;\n  tool_call_timeout?: number;\n  session_id?: string;\n}\n\nexport interface OpenAIChatDelta {\n  role?: string;\n  content?: string;\n}\n\nexport interface OpenAIChatChunkChoice {\n  delta?: OpenAIChatDelta;\n  finish_reason?: string | null;\n}\n\nexport interface OpenAIChatChunk {\n  choices?: OpenAIChatChunkChoice[];\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/types/runtime.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport interface InferenceParams {\n  temperature: number;\n  topP: number;\n  topK: number;\n  minP: number;\n  repetitionPenalty: number;\n  presencePenalty: number;\n  maxSeqLength: number;\n  maxTokens: number;\n  systemPrompt: string;\n  checkpoint: string;\n  /** Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust. */\n  trustRemoteCode?: boolean;\n}\n\nexport const DEFAULT_INFERENCE_PARAMS: InferenceParams = {\n  temperature: 0.6,\n  topP: 0.95,\n  topK: 20,\n  minP: 0.01,\n  repetitionPenalty: 1.0,\n  presencePenalty: 0.0,\n  maxSeqLength: 4096,\n  maxTokens: 8192,\n  systemPrompt: \"\",\n  checkpoint: \"\",\n  trustRemoteCode: false,\n};\n\nexport interface ChatModelSummary {\n  id: string;\n  name: string;\n  description?: string;\n  isVision: boolean;\n  isLora: boolean;\n  isGguf?: boolean;\n  isAudio?: boolean;\n  audioType?: string | null;\n  hasAudioInput?: boolean;\n}\n\nexport interface ChatLoraSummary {\n  id: string;\n  name: string;\n  baseModel: string;\n  updatedAt?: number;\n  source?: \"training\" | \"exported\";\n  exportType?: \"lora\" | \"merged\" | \"gguf\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type ModelType = \"base\" | \"lora\" | \"model1\" | \"model2\";\n\nexport type ChatView =\n  | { mode: \"single\"; threadId?: string; newThreadNonce?: string }\n  | { mode: \"compare\"; pairId: string };\n\nexport interface ThreadRecord {\n  id: string;\n  title: string;\n  modelType: ModelType;\n  modelId?: string;\n  pairId?: string;\n  archived: boolean;\n  createdAt: number;\n}\n\nexport interface MessageRecord {\n  id: string;\n  threadId: string;\n  role: import(\"@assistant-ui/react\").ThreadMessage[\"role\"];\n  content: import(\"@assistant-ui/react\").ThreadMessage[\"content\"];\n  attachments?: import(\"@assistant-ui/react\").ThreadMessage[\"attachments\"];\n  metadata?: Record<string, unknown>;\n  createdAt: number;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/chat/utils/parse-assistant-content.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ChatModelRunResult } from \"@assistant-ui/react\";\n\ntype ContentPart = NonNullable<ChatModelRunResult[\"content\"]>[number];\n\nconst THINK_OPEN_TAG = \"<think>\";\nconst THINK_CLOSE_TAG = \"</think>\";\n\nfunction appendTextPart(parts: ContentPart[], text: string): void {\n  if (text) {\n    parts.push({ type: \"text\", text });\n  }\n}\n\nfunction appendReasoningPart(parts: ContentPart[], text: string): void {\n  if (text) {\n    parts.push({ type: \"reasoning\", text });\n  }\n}\n\nexport function parseAssistantContent(\n  raw: string,\n): ContentPart[] {\n  const parts: ContentPart[] = [];\n  if (!raw) {\n    return parts;\n  }\n\n  let cursor = 0;\n  while (cursor < raw.length) {\n    const openIndex = raw.indexOf(THINK_OPEN_TAG, cursor);\n    if (openIndex === -1) {\n      appendTextPart(parts, raw.slice(cursor));\n      break;\n    }\n\n    appendTextPart(parts, raw.slice(cursor, openIndex));\n\n    const reasoningStart = openIndex + THINK_OPEN_TAG.length;\n    const closeIndex = raw.indexOf(THINK_CLOSE_TAG, reasoningStart);\n    if (closeIndex === -1) {\n      appendReasoningPart(parts, raw.slice(reasoningStart));\n      break;\n    }\n\n    appendReasoningPart(parts, raw.slice(reasoningStart, closeIndex));\n    cursor = closeIndex + THINK_CLOSE_TAG.length;\n  }\n\n  return parts;\n}\n\nexport function hasClosedThinkTag(raw: string): boolean {\n  return raw.includes(THINK_CLOSE_TAG);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/data/recipes-db.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { createEmptyRecipePayload } from \"@/features/recipe-studio\";\nimport { normalizeNonEmptyName } from \"@/utils\";\nimport Dexie, { type EntityTable, liveQuery } from \"dexie\";\nimport { useEffect, useState } from \"react\";\nimport type { RecipeRecord, SaveRecipeInput } from \"../types\";\n\nconst db = new Dexie(\"unsloth-data-recipes\") as Dexie & {\n  recipes: EntityTable<RecipeRecord, \"id\">;\n};\n\ndb.version(1).stores({\n  recipes: \"id, name, updatedAt, createdAt\",\n});\n\nconst recentRecipeCache = new Map<string, RecipeRecord>();\n\nexport function listRecipes(): Promise<RecipeRecord[]> {\n  return db.recipes.orderBy(\"updatedAt\").reverse().toArray();\n}\n\nexport function getRecipe(id: string): Promise<RecipeRecord | undefined> {\n  return db.recipes.get(id);\n}\n\nfunction writeRecipeCache(record: RecipeRecord): void {\n  recentRecipeCache.set(record.id, record);\n}\n\nexport function getCachedRecipe(id: string): RecipeRecord | null {\n  return recentRecipeCache.get(id) ?? null;\n}\n\nexport function primeRecipeCache(record: RecipeRecord): void {\n  writeRecipeCache(record);\n}\n\nexport async function saveRecipe(\n  input: SaveRecipeInput,\n): Promise<RecipeRecord> {\n  const now = Date.now();\n  const id = input.id ?? crypto.randomUUID();\n  const existing = input.id ? await db.recipes.get(input.id) : undefined;\n  const record: RecipeRecord = {\n    id,\n    name: normalizeNonEmptyName(input.name),\n    payload: input.payload,\n    createdAt: existing?.createdAt ?? now,\n    updatedAt: now,\n    learningRecipeId: input.learningRecipeId ?? existing?.learningRecipeId,\n    learningRecipeTitle:\n      input.learningRecipeTitle ?? existing?.learningRecipeTitle,\n  };\n  await db.recipes.put(record);\n  writeRecipeCache(record);\n  return record;\n}\n\nexport async function deleteRecipe(id: string): Promise<void> {\n  await db.recipes.delete(id);\n  recentRecipeCache.delete(id);\n}\n\nexport function createRecipeDraft(): Promise<RecipeRecord> {\n  return saveRecipe({\n    name: \"Unnamed\",\n    payload: createEmptyRecipePayload(),\n  });\n}\n\nexport function createRecipeFromLearningRecipe(input: {\n  templateId: string;\n  templateTitle: string;\n  payload: RecipeRecord[\"payload\"];\n}): Promise<RecipeRecord> {\n  return saveRecipe({\n    name: input.templateTitle,\n    payload: input.payload,\n    learningRecipeId: input.templateId,\n    learningRecipeTitle: input.templateTitle,\n  });\n}\n\nexport function useRecipes(): {\n  recipes: RecipeRecord[];\n  ready: boolean;\n} {\n  const [recipes, setRecipes] = useState<RecipeRecord[]>([]);\n  const [ready, setReady] = useState(false);\n\n  useEffect(() => {\n    const sub = liveQuery(() => listRecipes()).subscribe({\n      next: (value) => {\n        for (const recipe of value) {\n          writeRecipeCache(recipe);\n        }\n        setRecipes(value);\n        setReady(true);\n      },\n      error: (error) => {\n        console.error(\"data-recipes liveQuery:\", error);\n        setReady(true);\n      },\n    });\n    return () => sub.unsubscribe();\n  }, []);\n\n  return { recipes, ready };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { DataRecipesPage } from \"./pages/data-recipes-page\";\nexport { EditRecipePage } from \"./pages/edit-recipe-page\";\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/conversation.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"provider_1\",\n        \"endpoint\": \"https://openrouter.ai/api/v1\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"model_1\",\n        \"model\": \"mistralai/ministral-8b-2512\",\n        \"provider\": \"provider_1\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7,\n          \"max_tokens\": 2048\n        }\n      }\n    ],\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"domain\",\n        \"drop\": true,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"Tech Support\",\n            \"Personal Finance\",\n            \"Learning\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"topic\",\n        \"drop\": true,\n        \"sampler_type\": \"subcategory\",\n        \"params\": {\n          \"category\": \"domain\",\n          \"values\": {\n            \"Tech Support\": [\n              \"Wi-Fi keeps disconnecting\",\n              \"Laptop running very slow\",\n              \"Cannot install app update\"\n            ],\n            \"Personal Finance\": [\n              \"Monthly budget planning\",\n              \"Credit card debt payoff\",\n              \"Emergency fund setup\"\n            ],\n            \"Learning\": [\n              \"Exam study plan\",\n              \"Learn Python basics\",\n              \"Improve English writing\"\n            ]\n          }\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"conversation_length\",\n        \"drop\": true,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"4\",\n            \"6\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"llm-text\",\n        \"name\": \"user_goal\",\n        \"drop\": false,\n        \"model_alias\": \"model_1\",\n        \"prompt\": \"Write one user goal for a chat assistant.\\nDomain: {{ domain }}\\nTopic: {{ topic }}\\nConversation length target: {{ conversation_length }} messages total.\\nRules:\\n- 1 sentence.\\n- Specific and practical.\\n- Output only the goal text.\",\n        \"system_prompt\": \"You write realistic user goals for assistant conversations.\\n\",\n        \"with_trace\": \"none\"\n      },\n      {\n        \"column_type\": \"llm-structured\",\n        \"name\": \"output_format\",\n        \"drop\": false,\n        \"model_alias\": \"model_1\",\n        \"prompt\": \"Generate a realistic multi-turn conversation.\\nUser goal:\\n{{ user_goal }}\\nConstraints:\\n- Exactly {{ conversation_length }} messages total.\\n- Alternate roles strictly: user, assistant, user, assistant...\\n- First message must be user.\\n- Last message must be assistant.\\n- Keep responses grounded in {{ domain }} / {{ topic }}.\\n- End naturally with resolution or clear next step.\\n- No markdown, no extra keys.\",\n        \"output_format\": {\n          \"type\": \"object\",\n          \"properties\": {\n            \"conversation\": {\n              \"type\": \"array\",\n              \"minItems\": 4,\n              \"maxItems\": 6,\n              \"items\": {\n                \"type\": \"object\",\n                \"properties\": {\n                  
\"role\": {\n                    \"type\": \"string\",\n                    \"enum\": [\n                      \"user\",\n                      \"assistant\"\n                    ]\n                  },\n                  \"content\": {\n                    \"type\": \"string\",\n                    \"minLength\": 1\n                  }\n                },\n                \"required\": [\n                  \"role\",\n                  \"content\"\n                ],\n                \"additionalProperties\": false\n              }\n            }\n          },\n          \"required\": [\n            \"conversation\"\n          ],\n          \"additionalProperties\": false\n        }\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"provider_1\",\n        \"x\": -1056.848383841495,\n        \"y\": 519.6373927070263,\n        \"width\": 400\n      },\n      {\n        \"id\": \"model_1\",\n        \"x\": -543.7221365246206,\n        \"y\": 488.2975724283656,\n        \"width\": 400\n      },\n      {\n        \"id\": \"domain\",\n        \"x\": 0,\n        \"y\": 140,\n        \"width\": 400\n      },\n      {\n        \"id\": \"topic\",\n        \"x\": 0,\n        \"y\": 280,\n        \"width\": 400\n      },\n      {\n        \"id\": \"conversation_length\",\n        \"x\": 466.61510192672256,\n        \"y\": 139.68271861864798,\n        \"width\": 400\n      },\n      {\n        \"id\": \"user_goal\",\n        \"x\": 1.412158386197035,\n        \"y\": 508.77123580445596,\n        \"width\": 400\n      },\n      {\n        \"id\": \"output_format\",\n        \"x\": 1.1486983549970375,\n        \"y\": 754.4221089431811,\n        \"width\": 400\n      },\n      {\n        \"id\": \"note_1\",\n        \"x\": 210.01377182764494,\n        \"y\": -262.9440547613487,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"###  Start with controlled chat context\\nThis recipe uses sampler columns to shape each conversation:\\n\\n- `domain`\\n- `topic`\\n- `conversation_length` (4 or 6 messages)\\n\\n**Why this helps**:\\n\\n- You get varied conversations without manual writing\\n- Each row stays grounded in a clear scenario\\n- You can scale quickly while keeping data quality consistent\",\n        \"note_color\": \"#FFE4E6\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 515.9369583007435,\n        \"y\": 454.3936030274385,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"The **LLM Text** block (`user_goal`) creates one realistic user intent from sampler context.\\n\\n**It should be**:\\n\\n- **specific**\\n- **practical**\\n- **short**\\n\\nThis goal becomes the anchor for the full multi-turn conversation.\",\n        \"note_color\": \"#FFE4E6\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_3\",\n        \"x\": -12.952616065779239,\n        \"y\": 912.1316336111515,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_3\",\n        \"markdown\": \"The **LLM Structured** block (`output_format`) generates the conversation as strict JSON.\\n\\nIn this recipe, schema enforces:\\n\\n- `conversation` array\\n- message objects with `role` + `content`\\n- role enum: `user` / 
`assistant`\\n- no extra keys\\n\\nPrompt constraints also enforce:\\n\\n- exact length (`{{ conversation_length }}`)\\n- alternating roles\\n- first user message, last assistant message\\n- natural ending\\n\\nThis is key for training data: same shape, less cleanup.\",\n        \"note_color\": \"#FFE4E6\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_4\",\n        \"x\": -519.9585237323188,\n        \"y\": 81.84144119564277,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_4\",\n        \"markdown\": \"Sampler columns are useful during generation but usually noisy in final export.\\n\\nSet helper columns to `drop=true`, keep only core outputs such as:\\n\\n- `user_goal`\\n- `output_format`\\n\\nTip: Keep final schema close to your training format, not your generation scaffolding.\\n\",\n        \"note_color\": \"#FFE4E6\",\n        \"note_opacity\": \"35\"\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"domain\",\n        \"to\": \"topic\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"domain\",\n        \"to\": \"conversation_length\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"topic\",\n        \"to\": \"user_goal\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"user_goal\",\n        \"to\": \"output_format\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"provider_1\",\n        \"to\": \"model_1\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out\",\n        \"target_handle\": \"semantic-in\"\n      },\n      {\n        \"from\": \"model_1\",\n        \"to\": \"user_goal\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"model_1\",\n        \"to\": \"output_format\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in\"\n      }\n    ],\n    \"layout_direction\": \"LR\"\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipePayload } from \"@/features/recipe-studio\";\n\nconst structuredOutputsJinjaUrl = new URL(\n  \"./structured-outputs-jinja.json\",\n  import.meta.url,\n).href;\nconst pdfGroundedQaUrl = new URL(\"./pdf-grounded-qa.json\", import.meta.url)\n  .href;\nconst instructionFromAnswerUrl = new URL(\n  \"./instruction-from-answer.json\",\n  import.meta.url,\n).href;\nconst textToPythonUrl = new URL(\"./text-to-python.json\", import.meta.url).href;\nconst textToSqlUrl = new URL(\"./text-to-sql.json\", import.meta.url).href;\nconst ocrDocumentExtractionUrl = new URL(\n  \"./ocr-document-extraction.json\",\n  import.meta.url,\n).href;\n\nfunction isRecord(value: unknown): value is Record<string, unknown> {\n  return !!value && typeof value === \"object\" && !Array.isArray(value);\n}\n\nfunction toRecordArray(value: unknown): Record<string, unknown>[] {\n  if (!Array.isArray(value)) {\n    return [];\n  }\n  return value.filter((item): item is Record<string, unknown> =>\n    isRecord(item),\n  );\n}\n\nfunction coerceRecipePayload(value: unknown): RecipePayload {\n  if (!isRecord(value)) {\n    throw new Error(\"Template payload is invalid JSON object.\");\n  }\n\n  const recipeSource = isRecord(value.recipe) ? value.recipe : value;\n  if (!Array.isArray(recipeSource.columns)) {\n    throw new Error(\"Template payload must include recipe.columns.\");\n  }\n\n  if (isRecord(value.recipe) && isRecord(value.run) && isRecord(value.ui)) {\n    return value as unknown as RecipePayload;\n  }\n\n  const recipe: RecipePayload[\"recipe\"] = {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_providers: toRecordArray(recipeSource.model_providers),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    mcp_providers: toRecordArray(recipeSource.mcp_providers),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_configs: toRecordArray(recipeSource.model_configs),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    seed_config: isRecord(recipeSource.seed_config)\n      ? 
recipeSource.seed_config\n      : undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_configs: toRecordArray(recipeSource.tool_configs),\n    columns: toRecordArray(recipeSource.columns),\n    processors: toRecordArray(recipeSource.processors),\n  };\n\n  return {\n    recipe,\n    run: {\n      rows: 5,\n      preview: true,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      output_formats: [\"jsonl\"],\n    },\n    ui: {\n      nodes: [],\n      edges: [],\n    },\n  };\n}\n\nasync function loadPayloadFromUrl(url: string): Promise<RecipePayload> {\n  const response = await fetch(url);\n  if (!response.ok) {\n    throw new Error(`Failed to fetch template payload (${response.status})`);\n  }\n  const json = (await response.json()) as unknown;\n  return coerceRecipePayload(json);\n}\n\nexport type LearningRecipeDef = {\n  id: string;\n  title: string;\n  description: string;\n  loadPayload: () => Promise<RecipePayload>;\n};\n\nexport const LEARNING_RECIPES: LearningRecipeDef[] = [\n  {\n    id: \"structured-outputs-jinja\",\n    title: \"Structured Outputs + Jinja Expressions\",\n    description:\n      \"Support ticket triage with structured JSON outputs and Jinja conditionals.\",\n    loadPayload: () => loadPayloadFromUrl(structuredOutputsJinjaUrl),\n  },\n  {\n    id: \"pdf-grounded-qa\",\n    title: \"PDF Document QA\",\n    description: \"Build grounded question-answer examples from PDF chunks.\",\n    loadPayload: () => loadPayloadFromUrl(pdfGroundedQaUrl),\n  },\n  {\n    id: \"instruction-from-answer\",\n    title: \"Instruction from Answer\",\n    description:\n      \"Use seed answer columns to generate high-quality instruction targets.\",\n    loadPayload: () => loadPayloadFromUrl(instructionFromAnswerUrl),\n  },\n  {\n    id: \"text-to-python\",\n    title: \"Text to Python\",\n    description:\n      \"Generate instruction-to-code data with category sampling and LLM judging.\",\n    loadPayload: () => loadPayloadFromUrl(textToPythonUrl),\n  },\n  {\n    id: \"text-to-sql\",\n    title: \"Text to SQL\",\n    description:\n      \"Generate SQL tasks and runnable SQL outputs with prompt-driven generation.\",\n    loadPayload: () => loadPayloadFromUrl(textToSqlUrl),\n  },\n  {\n    id: \"ocr-document-extraction\",\n    title: \"OCR Document Extraction\",\n    description:\n      \"Use image context to generate OCR-style document extraction output.\",\n    loadPayload: () => loadPayloadFromUrl(ocrDocumentExtractionUrl),\n  },\n];\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/instruction-from-answer.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"openai_provider\",\n        \"endpoint\": \"\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"ministral\",\n        \"model\": \"\",\n        \"provider\": \"openai_provider\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7,\n          \"max_tokens\": 2048\n        }\n      }\n    ],\n    \"seed_config\": {\n      \"source\": {\n        \"seed_type\": \"hf\",\n        \"path\": \"unsloth/alpaca-cleaned\",\n        \"endpoint\": \"https://huggingface.co\"\n      },\n      \"sampling_strategy\": \"ordered\",\n      \"selection_strategy\": {\n        \"start\": 1,\n        \"end\": 100\n      }\n    },\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"llm-text\",\n        \"name\": \"generated_instruction\",\n        \"drop\": false,\n        \"model_alias\": \"ministral\",\n        \"prompt\": \"Based on this target answer:\\n{{ output }}\\n\\nWrite one high-quality plain text short and brief user instruction that this answer would satisfy.\\nReturn only the instruction.\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false\n      }\n    ],\n    \"processors\": [\n      {\n        \"processor_type\": \"drop_columns\",\n        \"name\": \"drop_seed_columns\",\n        \"column_names\": [\n          \"input\",\n          \"instruction\"\n        ]\n      }\n    ]\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n        \"x\": -567.3566303099885,\n        \"y\": 38.88875727651093,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"#### Hugging Face seed block\\nThis recipe uses a ** HuggingFace dataset ** as seed data.\\nYou select a Hugging Face dataset, load columns, then generate new fields from seed columns. Each column in the Hugging Face dataset becomes a valid variable that you can reference eg. `{{ topic }}`\\n\\n##### Setup:\\n\\n1. Search for a dataset and select one in the dropdown (example: `unsloth/alpaca-cleaned`)\\n2. Add token only if dataset is gated/private\\n3. 
Load columns + preview rows so variables are available in prompts\\n\\n##### Why this matters:\\n- Seed columns can drive generation quality\\n- You can reference seed values directly in prompts (for example `{{ output }}`)\",\n        \"note_color\": \"#DCFCE7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": -74.04047072330651,\n        \"y\": -265.3540670633283,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"##### Drop columns behavior:\\n\\n- You can mark specific seed columns to **drop from final output**\\n- Those columns are still used during generation\\n- They are removed only from exported final dataset\\n\\n##### Example:\\n- Keep `generated_instruction` from llm-text block\\n- Drop original `instruction`, `input`, `output` from the hugginface dataset from final dataset\\n- Result: clean training output while still using source columns as generation context\\n\",\n        \"note_color\": \"#DCFCE7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"seed\",\n        \"x\": -76.07288662013991,\n        \"y\": 143.39449780463954,\n        \"width\": 400\n      },\n      {\n        \"id\": \"openai_provider\",\n        \"x\": 461.00000000000006,\n        \"y\": -489.8750000000001,\n        \"width\": 400\n      },\n      {\n        \"id\": \"ministral\",\n        \"x\": 463.272022949692,\n        \"y\": -191.13601147484601,\n        \"width\": 400\n      },\n      {\n        \"id\": \"generated_instruction\",\n        \"x\": 464,\n        \"y\": 109.00000000000003,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"seed\",\n        \"to\": \"generated_instruction\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"ministral\",\n        \"to\": \"generated_instruction\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"openai_provider\",\n        \"to\": \"ministral\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      }\n    ],\n    \"layout_direction\": \"LR\",\n    \"seed_source_type\": \"hf\",\n    \"seed_columns\": [],\n    \"seed_drop_columns\": [],\n    \"seed_preview_rows\": [],\n    \"local_file_name\": \"\",\n    \"unstructured_file_name\": \"\",\n    \"unstructured_chunk_size\": \"1200\",\n    \"unstructured_chunk_overlap\": \"200\"\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/ocr-document-extraction.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"provider_1\",\n        \"endpoint\": \"https://openrouter.ai/api/v1\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"provider_column\",\n        \"model\": \"google/gemini-2.0-flash-001\",\n        \"provider\": \"provider_1\",\n        \"inference_parameters\": {\n          \"temperature\": 0.2,\n          \"max_tokens\": 4096\n        }\n      }\n    ],\n    \"seed_config\": {\n      \"source\": {\n        \"seed_type\": \"hf\",\n        \"path\": \"datasets/ylecun/mnist/mnist/**/*.parquet\"\n      },\n      \"sampling_strategy\": \"ordered\",\n      \"selection_strategy\": null\n    },\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"llm-text\",\n        \"name\": \"ocr_text\",\n        \"drop\": false,\n        \"model_alias\": \"provider_column\",\n        \"prompt\": \"Transcribe all text from this document image.\",\n        \"multi_modal_context\": [\n          {\n            \"modality\": \"image\",\n            \"column_name\": \"image\"\n          }\n        ]\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\"jsonl\"]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n        \"x\": -180,\n        \"y\": 43,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"This recipe uses **Gemini 2.0 Flash** via OpenRouter to transcribe document images into clean text.\\n\\nThe Seed block is prefilled with `ylecun/mnist` so you can run immediately. You can swap to any Hugging Face dataset that includes an `image` column.\\n\\nOutput: `ocr_text` column with the raw transcribed text per image.\",\n        \"note_color\": \"#DCFCE7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 283,\n        \"y\": -333,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"##### Setup\\n\\nAdd your OpenRouter API key to the **Model Provider** block — same as every other recipe.\\n\\nGemini 2.0 Flash is well-suited for OCR: fast, cheap, and strong on tables, receipts, forms, and multi-column layouts.\\n\\nWant a purpose-built OCR model? 
Swap the endpoint to a local vLLM server running `lightonai/LightOnOCR-2-1B` for maximum throughput.\",\n        \"note_color\": \"#DCFCE7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_3\",\n        \"x\": 303,\n        \"y\": 299,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_3\",\n        \"markdown\": \"##### Seed: HF dataset with image column\\n\\nThis template starts with `ylecun/mnist` so first run works without seed setup.\\n\\nTo use your own data: open Seed → keep **HF dataset** selected → choose a dataset that contains an `image` column → click **Load**.\\n\\nThen open the LLM Text block and set **Image Context** to the `image` column so each row image is sent with the prompt.\\n\\nTip: datasets with embedded image columns are more reliable than URL-only image fields.\",\n        \"note_color\": \"#DCFCE7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"seed\",\n        \"x\": 295,\n        \"y\": 108,\n        \"width\": 400\n      },\n      {\n        \"id\": \"provider_1\",\n        \"x\": 960,\n        \"y\": -465,\n        \"width\": 400\n      },\n      {\n        \"id\": \"provider_column\",\n        \"x\": 959,\n        \"y\": -180,\n        \"width\": 400\n      },\n      {\n        \"id\": \"ocr_text\",\n        \"x\": 960,\n        \"y\": 108,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"seed\",\n        \"to\": \"ocr_text\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"provider_1\",\n        \"to\": \"provider_column\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      },\n      {\n        \"from\": \"provider_column\",\n        \"to\": \"ocr_text\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      }\n    ],\n    \"layout_direction\": \"LR\",\n    \"seed_source_type\": \"hf\",\n    \"seed_columns\": [],\n    \"seed_drop_columns\": [],\n    \"seed_preview_rows\": [],\n    \"local_file_name\": \"\",\n    \"unstructured_file_name\": \"\",\n    \"unstructured_chunk_size\": \"900\",\n    \"unstructured_chunk_overlap\": \"150\"\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/pdf-grounded-qa.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"provider_1\",\n        \"endpoint\": \"\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"provider_column\",\n        \"model\": \"\",\n        \"provider\": \"provider_1\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7\n        }\n      }\n    ],\n    \"seed_config\": {\n      \"source\": {\n        \"seed_type\": \"unstructured\",\n        \"path\": \"\",\n        \"chunk_size\": 1200,\n        \"chunk_overlap\": 200\n      },\n      \"sampling_strategy\": \"ordered\",\n      \"selection_strategy\": null\n    },\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"llm-structured\",\n        \"name\": \"llm_structured_1\",\n        \"drop\": false,\n        \"model_alias\": \"provider_column\",\n        \"prompt\": \"Given ONLY this chunk: {{ chunk_text }} generate one answerable question, answer, and exact supporting quote from chunk. If not answerable, skip.\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"output_format\": {\n          \"type\": \"object\",\n          \"additionalProperties\": false,\n          \"required\": [\n            \"question\",\n            \"answer\",\n            \"evidence_quote\"\n          ],\n          \"properties\": {\n            \"question\": {\n              \"type\": \"string\"\n            },\n            \"answer\": {\n              \"type\": \"string\"\n            },\n            \"evidence_quote\": {\n              \"type\": \"string\"\n            }\n          }\n        }\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n        \"x\": 474.6120044693708,\n        \"y\": 1229.5810476890458,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"This recipe uses **seed data** from external documents.\\nInstead of starting from empty generation, we load real source text first.\\n\\nIn this flow, the seed source is **Unstructured Documents**:\\n\\n- Upload: `.pdf`, `.docx`, `.txt`\\n- Text is extracted and split on client into chunks\\n- Each chunk becomes a row-like seed record (`chunk_text`) that you can reference in prompts with `{{ chunk_text }} `\",\n        \"note_color\": \"#F3E8FF\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 26.758540311329455,\n        \"y\": 963.3465578835235,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"##### Chunking settings:\\n\\n- **Chunk size**: how much text per chunk\\n- **Chunk overlap**: shared text between neighboring chunks to preserve context\\n\\n##### Sampling settings:\\n\\n- **Ordered**: keep original document order\\n- **Shuffle**: randomize chunk order\\n- **Selection index / selection settings**: choose which part/subset of seed data to use\",\n        \"note_color\": \"#F3E8FF\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_3\",\n        \"x\": 473.62551435180245,\n        \"y\": 741.5352258256931,\n        \"width\": 400,\n        \"node_type\": 
\"markdown_note\",\n        \"name\": \"note_3\",\n        \"markdown\": \"- LLM prompt: `{{ chunk_text }}`\\n- Expression block: combine/format values using `{{ chunk_text }}`\\n- Processor templates: use `{{ chunk_text }}` during transforms\\n\\nTip:\\n- Start with medium chunk size + small overlap.\\n- Increase overlap only if answers lose context between chunks.\",\n        \"note_color\": \"#F3E8FF\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"seed\",\n        \"x\": 484.36210245413577,\n        \"y\": 1059.99180558796,\n        \"width\": 400\n      },\n      {\n        \"id\": \"provider_1\",\n        \"x\": 960,\n        \"y\": 622,\n        \"width\": 400\n      },\n      {\n        \"id\": \"provider_column\",\n        \"x\": 960,\n        \"y\": 816,\n        \"width\": 400\n      },\n      {\n        \"id\": \"llm_structured_1\",\n        \"x\": 960,\n        \"y\": 1077,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"provider_1\",\n        \"to\": \"provider_column\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      },\n      {\n        \"from\": \"provider_column\",\n        \"to\": \"llm_structured_1\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"llm_structured_1\",\n        \"to\": \"seed\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out-left\",\n        \"target_handle\": \"data-in-right\"\n      }\n    ],\n    \"layout_direction\": \"LR\",\n    \"seed_source_type\": \"unstructured\",\n    \"seed_columns\": [],\n    \"seed_drop_columns\": [],\n    \"seed_preview_rows\": [],\n    \"local_file_name\": \"\",\n    \"unstructured_file_name\": \"\",\n    \"unstructured_chunk_size\": \"1200\",\n    \"unstructured_chunk_overlap\": \"200\"\n  }\n}"
  },
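The notes in `pdf-grounded-qa.json` describe client-side chunking of uploaded documents (chunk size 1200, overlap 200) before each chunk is exposed to prompts as `{{ chunk_text }}`. Below is a minimal sketch of that sliding-window chunking, assuming plain extracted text; the `chunkText` helper and `seedRecords` name are illustrative, not part of the studio code.

```ts
// Sliding-window chunking as described in the recipe notes: fixed-size
// chunks with a small overlap so neighbouring chunks share context.
// Illustrative only; this is not the studio implementation.
export function chunkText(
  text: string,
  chunkSize = 1200, // characters per chunk, as configured in this recipe
  chunkOverlap = 200, // characters shared between neighbouring chunks
): string[] {
  if (chunkSize <= chunkOverlap) {
    throw new Error("chunkSize must be larger than chunkOverlap");
  }
  const chunks: string[] = [];
  const step = chunkSize - chunkOverlap;
  for (let start = 0; start < text.length; start += step) {
    chunks.push(text.slice(start, start + chunkSize));
    if (start + chunkSize >= text.length) {
      break;
    }
  }
  return chunks;
}

// Each chunk becomes a seed record whose text prompts reference as {{ chunk_text }}.
const seedRecords = chunkText("...extracted document text...").map((chunk) => ({
  chunk_text: chunk,
}));
```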
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/structured-outputs-jinja.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"provider_column\",\n        \"endpoint\": \"\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"ministral\",\n        \"model\": \"\",\n        \"provider\": \"provider_column\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7\n        }\n      }\n    ],\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"user\",\n        \"drop\": true,\n        \"sampler_type\": \"person_from_faker\",\n        \"params\": {}\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"platform\",\n        \"drop\": false,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"web\",\n            \"mobile\",\n            \"cli\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"impact_scope\",\n        \"drop\": false,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"single_user\",\n            \"team\",\n            \"org_wide\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"expression\",\n        \"name\": \"user_first_name\",\n        \"drop\": false,\n        \"expr\": \"{{ user.first_name }}\",\n        \"dtype\": \"str\"\n      },\n      {\n        \"column_type\": \"expression\",\n        \"name\": \"user_full_name\",\n        \"drop\": false,\n        \"expr\": \"{{ user.first_name }} {{ user.last_name }}\",\n        \"dtype\": \"str\"\n      },\n      {\n        \"column_type\": \"llm-structured\",\n        \"name\": \"ticket\",\n        \"drop\": false,\n        \"model_alias\": \"ministral\",\n        \"prompt\": \"Create a realistic support ticket from {{ user_full_name }} using the {{ platform }} platform. 
Impact scope is {{ impact_scope }}.\\n\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"output_format\": {\n          \"type\": \"object\",\n          \"additionalProperties\": false,\n          \"required\": [\n            \"issue_title\",\n            \"issue_summary\",\n            \"category\",\n            \"priority\"\n          ],\n          \"properties\": {\n            \"issue_title\": {\n              \"type\": \"string\",\n              \"description\": \"Short title of issue\"\n            },\n            \"issue_summary\": {\n              \"type\": \"string\",\n              \"description\": \"1-2 sentence summary\"\n            },\n            \"category\": {\n              \"type\": \"string\",\n              \"enum\": [\n                \"account\",\n                \"billing\",\n                \"api\",\n                \"infra\"\n              ],\n              \"description\": \"Issue category\"\n            },\n            \"priority\": {\n              \"type\": \"string\",\n              \"enum\": [\n                \"P1\",\n                \"P2\",\n                \"P3\"\n              ],\n              \"description\": \"Urgency level\"\n            }\n          }\n        }\n      },\n      {\n        \"column_type\": \"expression\",\n        \"name\": \"sla_target\",\n        \"drop\": false,\n        \"expr\": \"{% if impact_scope == 'org_wide' %}15m\\n{% elif impact_scope == 'team' %}1h\\n{% else %}4h\\n{% endif %}\",\n        \"dtype\": \"str\"\n      },\n      {\n        \"column_type\": \"llm-structured\",\n        \"name\": \"agent_reply\",\n        \"drop\": false,\n        \"model_alias\": \"ministral\",\n        \"prompt\": \"Write a concise support reply for ticket '{{ ticket.issue_title }}'. Category: {{ ticket.category }}. Priority: {{ ticket.priority }}. SLA target: {{ sla_target }}. 
{% if ticket.priority == 'P1' %}Tone must be urgent and action-first.{% else %}Tone must be calm and instructional.{% endif %}\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"output_format\": {\n          \"type\": \"object\",\n          \"additionalProperties\": false,\n          \"required\": [\n            \"response\",\n            \"next_action\"\n          ],\n          \"properties\": {\n            \"response\": {\n              \"type\": \"string\",\n              \"description\": \"Support response to user\"\n            },\n            \"next_action\": {\n              \"type\": \"string\",\n              \"enum\": [\n                \"ask_logs\",\n                \"reset_credentials\",\n                \"escalate\",\n                \"provide_steps\"\n              ],\n              \"description\": \"Primary next action\"\n            }\n          }\n        }\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n        \"x\": 990.3973509933774,\n        \"y\": 1487.5768211920529,\n        \"width\": 782,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"## Expression columns \\nAre like lightweight spreadsheet formulas.\\nUse them when you want to transform existing columns quickly, without calling an LLM.\\n\\n### What you can do:\\n\\n- Use values from other columns: `{{ first_name }} {{ last_name }}`\\n- Clean/format text: `{{ city | upper }}`, `{{ product_name | trim }}`\\n- Conditional logic:\\n  - `{% if order_total >= 100 %}VIP{% elif order_total >= 50 %}Standard{% else %}Starter{% endif %}`\\n- Simple math:\\n  - `{{ quantity * unit_price }}`\\n  - `{{ (subtotal - discount) | round(2) }}`\\n\\n### Good rule:\\n- If the value can be computed from existing data, use Expression first.\\n- Use LLM only when you need true language generation.\",\n        \"note_color\": \"#CFFAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 3217.6543046357615,\n        \"y\": 2081.596026490066,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"### LLM Structured block\\nGenerates JSON that matches your Output Format schema.\\nThink of Output Format as a contract for what the model must return.\\n\\n#### Prompt tips:\\n\\n- Reference existing columns with Jinja: `{{ column_name }}`\\n- You can reference nested values too: `{{ customer.first_name }}`\\n- Be explicit about what each field should contain.\\n\\n#### Example prompt pattern:\\n\\n```text\\nCreate a support ticket summary.\\nCustomer: {{ customer_name }}\\nIssue text: {{ issue_text }}\\n\\nReturn data for:\\n- priority\\n- short_title\\n- resolution_steps\\n```\",\n        \"note_color\": \"#CFFAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_3\",\n        \"x\": 2294.4516556291387,\n        \"y\": 2399.6099337748346,\n        \"width\": 638,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_3\",\n        \"markdown\": \"## Example output format shape (concept):\\n\\n```json\\n{\\n  \\\"type\\\": \\\"object\\\",\\n  \\\"properties\\\": {\\n    \\\"priority\\\": { \\\"type\\\": \\\"string\\\" },\\n    \\\"short_title\\\": { \\\"type\\\": \\\"string\\\" },\\n    
\\\"resolution_steps\\\": { \\\"type\\\": \\\"array\\\", \\\"items\\\": { \\\"type\\\": \\\"string\\\" } }\\n  },\\n  \\\"required\\\": [\\\"priority\\\", \\\"short_title\\\", \\\"resolution_steps\\\"]\\n}\\n```\",\n        \"note_color\": \"#CFFAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_4\",\n        \"x\": 2544.684105960265,\n        \"y\": 1126.5490066225163,\n        \"width\": 399,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_4\",\n        \"markdown\": \"### Model provider & Config\\nEvery LLM block needs a model alias.\\nThat alias comes from a Model Config.\\nModel Config points to a Model Provider.\\n\\n#### Minimum setup:\\n\\n1. Create **Model Provider**\\n   - Set endpoint/provider type\\n   - Prefer env var auth (`api_key_env`) over hardcoded keys\\n\\n2. Create **Model Config**\\n   - Set alias (example: `model_1`)\\n   - Set model id\\n   - Link to provider\\n   - Tune params (temperature, max_tokens, etc.)\\n\\n3. In each LLM block\\n   - Set `model_alias` to that alias\\n\\nIf alias/provider link is missing, validation/run will fail.\",\n        \"note_color\": \"#CFFAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"provider_column\",\n        \"x\": 2542,\n        \"y\": 1696,\n        \"width\": 400\n      },\n      {\n        \"id\": \"ministral\",\n        \"x\": 2542,\n        \"y\": 1890,\n        \"width\": 400\n      },\n      {\n        \"id\": \"user\",\n        \"x\": 191,\n        \"y\": 2423,\n        \"width\": 400\n      },\n      {\n        \"id\": \"platform\",\n        \"x\": 858.0384105960266,\n        \"y\": 2286.5,\n        \"width\": 400\n      },\n      {\n        \"id\": \"impact_scope\",\n        \"x\": 1342,\n        \"y\": 2286.5,\n        \"width\": 400\n      },\n      {\n        \"id\": \"user_first_name\",\n        \"x\": 1822,\n        \"y\": 2505,\n        \"width\": 400\n      },\n      {\n        \"id\": \"user_full_name\",\n        \"x\": 1822,\n        \"y\": 1959,\n        \"width\": 400\n      },\n      {\n        \"id\": \"ticket\",\n        \"x\": 2302,\n        \"y\": 2286.5,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sla_target\",\n        \"x\": 1822,\n        \"y\": 2232,\n        \"width\": 400\n      },\n      {\n        \"id\": \"agent_reply\",\n        \"x\": 2782,\n        \"y\": 2151,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"platform\",\n        \"to\": \"impact_scope\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"user\",\n        \"to\": \"user_first_name\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"user_full_name\",\n        \"to\": \"ticket\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"user\",\n        \"to\": \"platform\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"user\",\n        \"to\": \"user_full_name\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"user_first_name\",\n        \"to\": \"ticket\",\n        \"type\": 
\"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"impact_scope\",\n        \"to\": \"sla_target\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"sla_target\",\n        \"to\": \"ticket\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"ticket\",\n        \"to\": \"agent_reply\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"provider_column\",\n        \"to\": \"ministral\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      },\n      {\n        \"from\": \"ministral\",\n        \"to\": \"ticket\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"ministral\",\n        \"to\": \"agent_reply\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      }\n    ],\n    \"layout_direction\": \"LR\"\n  }\n}"
  },
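The `ticket` column above constrains the model to a fixed JSON shape via `output_format`, and the `sla_target` expression column maps `impact_scope` to a response-time string with a Jinja `{% if %}` block. Below is a minimal TypeScript mirror of that contract, derived only from the schema and expression shown in this recipe; the `Ticket` interface and `slaTarget` helper are illustrative names.

```ts
// Shape implied by the "ticket" column's output_format schema.
interface Ticket {
  issue_title: string; // short title of the issue
  issue_summary: string; // 1-2 sentence summary
  category: "account" | "billing" | "api" | "infra";
  priority: "P1" | "P2" | "P3";
}

// Same branching as the sla_target expression column:
// org_wide -> 15m, team -> 1h, everything else -> 4h.
type ImpactScope = "single_user" | "team" | "org_wide";

function slaTarget(impactScope: ImpactScope): string {
  if (impactScope === "org_wide") return "15m";
  if (impactScope === "team") return "1h";
  return "4h";
}
```

The recipe keeps this branching in an expression column so the value is computed during generation rather than in application code.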
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/text-to-python.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"openai-compatible\",\n        \"endpoint\": \"\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"coding-model\",\n        \"model\": \"\",\n        \"provider\": \"openai-compatible\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7\n        }\n      }\n    ],\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"domain\",\n        \"drop\": false,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"Data Processing\",\n            \"Web API\",\n            \"Automation\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"task_type\",\n        \"drop\": false,\n        \"sampler_type\": \"subcategory\",\n        \"params\": {\n          \"category\": \"domain\",\n          \"values\": {\n            \"Data Processing\": [\n              \"CSV cleaning\",\n              \"JSON transform\",\n              \"deduplicate rows\"\n            ],\n            \"Web API\": [\n              \"GET endpoint\",\n              \"POST validation\",\n              \"pagination helper\"\n            ],\n            \"Automation\": [\n              \"file organizer\",\n              \"log parser\",\n              \"daily report script\"\n            ]\n          }\n        }\n      },\n      {\n        \"column_type\": \"llm-text\",\n        \"name\": \"instruction\",\n        \"drop\": false,\n        \"model_alias\": \"coding-model\",\n        \"prompt\": \"Write one clear Python coding instruction.\\nDomain: {{ domain }}\\nTask type: {{ task_type }}\\n\\nKeep it practical and specific.\\nReturn only the instruction without any code.\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false\n      },\n      {\n        \"column_type\": \"llm-code\",\n        \"name\": \"code_implementation\",\n        \"drop\": false,\n        \"model_alias\": \"coding-model\",\n        \"prompt\": \"Write Python code for:\\n{{ instruction }}\\n\\nRequirements:\\n- runnable script or function\\n- include needed imports\\n- short comments only where useful\\n- no markdown fences\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"code_lang\": \"python\"\n      },\n      {\n        \"column_type\": \"llm-judge\",\n        \"name\": \"code_judge_result\",\n        \"drop\": false,\n        \"model_alias\": \"coding-model\",\n        \"prompt\": \"Evaluate generated Python code against the instruction.\\n\\nInstruction:\\n{{ instruction }}\\n\\nCode:\\n{{ code_implementation }}\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"scores\": [\n          {\n            \"name\": \"Correctness\",\n            \"description\": \"Follows instruction and is executable\",\n            \"options\": {\n              \"0\": \"bad\",\n              \"1\": \"partial\",\n              \"2\": \"good\",\n              \"3\": \"excellent\"\n            }\n          }\n        ]\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n  
      \"x\": 1526,\n        \"y\": 1790.75,\n        \"width\": 568,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"The **LLM Code** block is where Python code is generated from your instruction/prompt.\\n\\n##### How it works in this recipe:\\n\\n- You provide a clear prompt (often using Jinja references from earlier columns)\\n- The model returns a response\\n- The block extracts code content directly for the output column\\n\\n##### Current status:\\n\\n- We are **not** running Python lint/syntax validation in this recipe yet (Soon)\\n- Validation support is planned and will be added\\n\\n##### What this means:\\n\\n- You may get mostly correct code, but some rows can still have syntax/style issues\\n- Keep prompts specific and constrained to reduce bad outputs\\n\\n##### Tip:\\n\\n- Ask for one self-contained function/script\\n- Ask for required imports\\n- Ask for no markdown fences if you want cleaner extraction\\n\",\n        \"note_color\": \"#FEF3C7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 2597.376821192053,\n        \"y\": 1233.2039735099338,\n        \"width\": 471,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"The **LLM Judge** block evaluates generated outputs with rubric-style scores.\\n\\n##### Important:\\n\\n- A judge can have **one or many scores**\\n- Each score has:\\n  - a name (for example: `Correctness`)\\n  - a description\\n  - options (value + meaning)\\n\\n##### Example multi-score setup:\\n\\n- Correctness\\n- Readability\\n- Efficiency\\n\\n##### Why use multiple scores:\\n\\n- You get richer quality signals than a single pass/fail\\n- Easier filtering and weighting later in training data prep\\n\\n##### Practical pattern:\\n\\n1. Generate code with LLM Code\\n2. Judge with 2-4 focused scores\\n3. 
Keep high-quality rows based on score thresholds\\n\",\n        \"note_color\": \"#FEF3C7\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"openai-compatible\",\n        \"x\": 1627.1046357615896,\n        \"y\": 921.0301324503313,\n        \"width\": 400\n      },\n      {\n        \"id\": \"coding-model\",\n        \"x\": 1627.1046357615894,\n        \"y\": 1138.910927152318,\n        \"width\": 400\n      },\n      {\n        \"id\": \"domain\",\n        \"x\": 84,\n        \"y\": 1600.5,\n        \"width\": 400\n      },\n      {\n        \"id\": \"task_type\",\n        \"x\": 648,\n        \"y\": 1600.5,\n        \"width\": 400\n      },\n      {\n        \"id\": \"instruction\",\n        \"x\": 1128,\n        \"y\": 1567,\n        \"width\": 400\n      },\n      {\n        \"id\": \"code_implementation\",\n        \"x\": 1627.1046357615894,\n        \"y\": 1531.4728476821192,\n        \"width\": 400\n      },\n      {\n        \"id\": \"code_judge_result\",\n        \"x\": 2124.617218543046,\n        \"y\": 1567.076490066225,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"domain\",\n        \"to\": \"task_type\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"task_type\",\n        \"to\": \"instruction\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"openai-compatible\",\n        \"to\": \"coding-model\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      },\n      {\n        \"from\": \"instruction\",\n        \"to\": \"code_implementation\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"coding-model\",\n        \"to\": \"instruction\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"coding-model\",\n        \"to\": \"code_implementation\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"code_implementation\",\n        \"to\": \"code_judge_result\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"coding-model\",\n        \"to\": \"code_judge_result\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      }\n    ],\n    \"layout_direction\": \"LR\"\n  }\n}"
  },
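The `code_judge_result` column defines a single `Correctness` rubric with options 0-3, and its note suggests keeping high-quality rows by score threshold. Below is a small sketch of such a filter, assuming judge output is available as a score-name to numeric-value map; the `JudgedRow` shape and `filterByScore` helper are assumptions, not the studio's export format.

```ts
// Assumed row shape after judging: the judge result is treated here as a
// map of score name -> numeric option (0 = bad ... 3 = excellent).
interface JudgedRow {
  instruction: string;
  code_implementation: string;
  code_judge_result: Record<string, number>;
}

// Keep rows whose Correctness score meets a threshold, as suggested in the
// judge note ("keep high-quality rows based on score thresholds").
function filterByScore(
  rows: JudgedRow[],
  scoreName = "Correctness",
  minScore = 2, // 2 = "good", 3 = "excellent" in this recipe's rubric
): JudgedRow[] {
  return rows.filter((row) => (row.code_judge_result[scoreName] ?? 0) >= minScore);
}
```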
  {
    "path": "studio/frontend/src/features/data-recipes/learning-recipes/text-to-sql.json",
    "content": "{\n  \"recipe\": {\n    \"model_providers\": [\n      {\n        \"name\": \"vllm\",\n        \"endpoint\": \"\",\n        \"provider_type\": \"openai\",\n        \"extra_headers\": {},\n        \"extra_body\": {}\n      }\n    ],\n    \"mcp_providers\": [],\n    \"model_configs\": [\n      {\n        \"alias\": \"sql-pro\",\n        \"model\": \"\",\n        \"provider\": \"vllm\",\n        \"inference_parameters\": {\n          \"temperature\": 0.7\n        }\n      }\n    ],\n    \"tool_configs\": [],\n    \"columns\": [\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"domain\",\n        \"drop\": true,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"Ecommerce\",\n            \"Customer Support\",\n            \"Finance\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"topic\",\n        \"drop\": true,\n        \"sampler_type\": \"subcategory\",\n        \"params\": {\n          \"category\": \"domain\",\n          \"values\": {\n            \"Ecommerce\": [\n              \"Orders and Revenue\",\n              \"Returns and Refunds\",\n              \"Product Performance\"\n            ],\n            \"Customer Support\": [\n              \"Ticket Resolution\",\n              \"SLA Compliance\",\n              \"Agent Productivity\"\n            ],\n            \"Finance\": [\n              \"Invoices and Payments\",\n              \"Subscription Churn\",\n              \"Monthly Cashflow\"\n            ]\n          }\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"sql_task_type\",\n        \"drop\": true,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"Filtering\",\n            \"Aggregation\",\n            \"Join Analysis\",\n            \"Trend Reporting\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"sampler\",\n        \"name\": \"instruction_phrase\",\n        \"drop\": true,\n        \"sampler_type\": \"category\",\n        \"params\": {\n          \"values\": [\n            \"Write a SQL query that\",\n            \"Create a SQL statement to\",\n            \"Develop a SQL query to\"\n          ]\n        }\n      },\n      {\n        \"column_type\": \"llm-text\",\n        \"name\": \"sql_prompt\",\n        \"drop\": false,\n        \"model_alias\": \"sql-pro\",\n        \"prompt\": \"Generate one natural-language SQL task.\\nContext:\\n- Domain: {{ domain }}\\n- Topic: {{ topic }}\\n- Task type: {{ sql_task_type }}\\nRules:\\n- Must start exactly with: \\\"{{ instruction_phrase }}\\\"\\n- Make it specific and practical.\\n- Mention expected business outcome.\\n- Keep it 1-2 sentences.\\n- Do not include SQL code.\\n- Output only the instruction text.\",\n        \"system_prompt\": \"You create clear, realistic business SQL tasks for training data.\\n\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false\n      },\n      {\n        \"column_type\": \"llm-code\",\n        \"name\": \"sql\",\n        \"drop\": false,\n        \"model_alias\": \"sql-pro\",\n        \"prompt\": \"Write SQL for this instruction:\\n{{ sql_prompt }}\\nReturn ONE SQL script with this exact structure:\\n-- SCHEMA\\n[CREATE TABLE statements]\\n[INSERT statements with sample rows]\\n-- QUERY\\n[final SELECT query solving the instruction]\\nRules:\\n- Use 2-3 tables max.\\n- Use realistic snake_case names.\\n- 
Include 5-8 rows of sample data per table.\\n- Query must match task type \\\"{{ sql_task_type }}\\\".\\n- Use only tables/columns you created.\\n- No markdown fences.\\n- No explanation text outside SQL comments shown above.\",\n        \"system_prompt\": \"You are an expert SQL engineer. Produce correct, runnable SQL only.\\n\",\n        \"with_trace\": \"none\",\n        \"extract_reasoning_content\": false,\n        \"code_lang\": \"sql:ansi\"\n      },\n      {\n        \"column_type\": \"validation\",\n        \"name\": \"sql-validator\",\n        \"drop\": false,\n        \"target_columns\": [\n          \"sql\"\n        ],\n        \"validator_type\": \"code\",\n        \"validator_params\": {\n          \"code_lang\": \"sql:ansi\"\n        },\n        \"batch_size\": 10\n      }\n    ],\n    \"processors\": []\n  },\n  \"run\": {\n    \"rows\": 5,\n    \"preview\": true,\n    \"output_formats\": [\n      \"jsonl\"\n    ]\n  },\n  \"ui\": {\n    \"nodes\": [\n      {\n        \"id\": \"note_1\",\n        \"x\": 338,\n        \"y\": 1020,\n        \"width\": 600,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_1\",\n        \"markdown\": \"##### This recipe starts with **sampler columns** to create controlled SQL task context:\\n\\n- `domain`\\n- `topic` (subcategory from `domain`)\\n- `sql_task_type`\\n- `instruction_phrase`\\n\\n##### Why this is useful:\\n\\n- You get diverse tasks without writing every prompt by hand\\n- You can steer business context + task pattern in a predictable way\\n- LLM prompts become cleaner because context is already structured\",\n        \"note_color\": \"#DBEAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_2\",\n        \"x\": 1675.8410596026492,\n        \"y\": 1644.2185430463576,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_2\",\n        \"markdown\": \"The **LLM Text** block (`sql_prompt`) turns sampler context into one clean natural-language SQL task.\\n\\n##### Prompt pattern in this recipe:\\n\\n- references prior columns with Jinja (`{{ domain }}`, `{{ topic }}`, etc.)\\n- enforces start phrase with `{{ instruction_phrase }}`\\n- returns instruction text only (no SQL yet)\\n\\n##### Tip:\\n\\n- Keep this instruction block concise and specific\\n- Save implementation details for the next SQL generation block\",\n        \"note_color\": \"#DBEAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_3\",\n        \"x\": 2198.980132450331,\n        \"y\": 1723.1456953642385,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_3\",\n        \"markdown\": \"The **LLM Code** block (`sql`) generates SQL script from `{{ sql_prompt }}`.\\n\\n##### In this recipe it returns:\\n\\n- schema section (`CREATE TABLE`)\\n- sample seed rows (`INSERT`)\\n- final query (`SELECT`)\\n\",\n        \"note_color\": \"#DBEAFE\",\n        \"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"note_4\",\n        \"x\": 1264,\n        \"y\": 1037,\n        \"width\": 400,\n        \"node_type\": \"markdown_note\",\n        \"name\": \"note_4\",\n        \"markdown\": \"Sampler columns are useful during generation, but often noisy in final output.\\n\\nSet helper columns to **drop=true** (like in this recipe), keep only output columns you want to export.\\n\\n#### Final keep we have set here:\\n\\n- `sql_prompt`\\n- `sql`\\n\\n\",\n        \"note_color\": \"#DBEAFE\",\n        
\"note_opacity\": \"35\"\n      },\n      {\n        \"id\": \"vllm\",\n        \"x\": 1939.5364238410598,\n        \"y\": 781.25,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sql-pro\",\n        \"x\": 1939.5364238410593,\n        \"y\": 975.25,\n        \"width\": 400\n      },\n      {\n        \"id\": \"domain\",\n        \"x\": 680,\n        \"y\": 1495,\n        \"width\": 400\n      },\n      {\n        \"id\": \"topic\",\n        \"x\": 1160,\n        \"y\": 1413,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sql_task_type\",\n        \"x\": 1160,\n        \"y\": 1577,\n        \"width\": 400\n      },\n      {\n        \"id\": \"instruction_phrase\",\n        \"x\": 100,\n        \"y\": 1495,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sql_prompt\",\n        \"x\": 1672.6490066225165,\n        \"y\": 1457.6854304635763,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sql\",\n        \"x\": 2194.9006622516554,\n        \"y\": 1457.110927152318,\n        \"width\": 400\n      },\n      {\n        \"id\": \"sql-validator\",\n        \"x\": 2682.5827814569534,\n        \"y\": 1491.0413907284767,\n        \"width\": 400\n      }\n    ],\n    \"edges\": [\n      {\n        \"from\": \"domain\",\n        \"to\": \"topic\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"domain\",\n        \"to\": \"sql_task_type\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"instruction_phrase\",\n        \"to\": \"domain\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"topic\",\n        \"to\": \"sql_prompt\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"sql_prompt\",\n        \"to\": \"sql\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"vllm\",\n        \"to\": \"sql-pro\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"semantic-in-top\"\n      },\n      {\n        \"from\": \"sql-pro\",\n        \"to\": \"sql\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"sql\",\n        \"to\": \"sql-validator\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"data-out\",\n        \"target_handle\": \"data-in\"\n      },\n      {\n        \"from\": \"sql-pro\",\n        \"to\": \"sql_prompt\",\n        \"type\": \"semantic\",\n        \"source_handle\": \"semantic-out-bottom\",\n        \"target_handle\": \"data-in-top\"\n      },\n      {\n        \"from\": \"sql_prompt\",\n        \"to\": \"sql_task_type\",\n        \"type\": \"canvas\",\n        \"source_handle\": \"data-out-left\",\n        \"target_handle\": \"data-in-right\"\n      }\n    ],\n    \"layout_direction\": \"LR\"\n  }\n}"
  },
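The `sql` column's prompt asks for one script with `-- SCHEMA` and `-- QUERY` sections, no markdown fences, and a final `SELECT`, and the `sql-validator` column then runs the built-in `sql:ansi` code validator. Below is a rough structural pre-check that mirrors those prompt rules, as an illustration only; it is not the recipe's validator.

```ts
// Illustrative shape check for the script layout requested in the "sql"
// column prompt. Returns a list of problems; empty means the layout matches.
// This does not parse SQL and is not the built-in sql:ansi validator.
function checkSqlScriptShape(sql: string): string[] {
  const problems: string[] = [];
  if (!/^--\s*SCHEMA/m.test(sql)) problems.push("missing -- SCHEMA section");
  if (!/^--\s*QUERY/m.test(sql)) problems.push("missing -- QUERY section");
  if (/`{3}/.test(sql)) problems.push("contains markdown fences");
  if (!/\bSELECT\b/i.test(sql)) problems.push("no SELECT query found");
  return problems;
}
```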
  {
    "path": "studio/frontend/src/features/data-recipes/pages/data-recipes-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogDescription,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport {\n  DropdownMenu,\n  DropdownMenuContent,\n  DropdownMenuItem,\n  DropdownMenuTrigger,\n} from \"@/components/ui/dropdown-menu\";\nimport {\n  Empty,\n  EmptyContent,\n  EmptyDescription,\n  EmptyHeader,\n  EmptyMedia,\n  EmptyTitle,\n} from \"@/components/ui/empty\";\nimport { ShineBorder } from \"@/components/ui/shine-border\";\nimport { toastError } from \"@/shared/toast\";\nimport {\n  Album02Icon,\n  ArrowDown01Icon,\n  CodeIcon,\n  CookBookIcon,\n  Database02Icon,\n  Delete02Icon,\n  DocumentAttachmentIcon,\n  FunctionIcon,\n  Plant01Icon,\n  PlusSignIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useNavigate } from \"@tanstack/react-router\";\nimport type { ReactElement } from \"react\";\nimport { useEffect, useState } from \"react\";\nimport {\n  createRecipeDraft,\n  createRecipeFromLearningRecipe,\n  deleteRecipe,\n  primeRecipeCache,\n  useRecipes,\n} from \"../data/recipes-db\";\nimport { LEARNING_RECIPES } from \"../learning-recipes\";\n\nconst OPEN_LEARNING_RECIPES_ON_ARRIVAL_KEY =\n  \"data-recipes:open-learning-recipes\";\n\ntype TemplateCard = {\n  title: string;\n  description: string;\n  icon: typeof CookBookIcon;\n  difficulty: \"Easy\" | \"Starter\" | \"Intermediate\" | \"Advanced\";\n  learningBadges: string[];\n  surfaceClassName: string;\n  shineColor: string[];\n  learningRecipeId?: string;\n};\n\nconst TEMPLATE_CARDS: TemplateCard[] = [\n  {\n    title: \"Instruction from Answer\",\n    description:\n      \"Start from seed answer fields and generate matching user instructions for SFT pairs.\",\n    icon: Plant01Icon,\n    difficulty: \"Easy\",\n    learningBadges: [\"Seed Dataset\", \"LLM Text\", \"Prompting\"],\n    surfaceClassName:\n      \"from-emerald-500/15 via-green-500/5 to-transparent dark:from-emerald-400/30 dark:via-green-400/14 dark:to-emerald-950/16\",\n    shineColor: [\n      \"rgb(16 185 129 / 0.45)\",\n      \"rgb(34 197 94 / 0.4)\",\n      \"rgb(52 211 153 / 0.45)\",\n    ],\n    learningRecipeId: \"instruction-from-answer\",\n  },\n  {\n    title: \"PDF Document QA\",\n    description:\n      \"Unstructured PDF chunks transformed into grounded question-answer training pairs.\",\n    icon: DocumentAttachmentIcon,\n    difficulty: \"Easy\",\n    learningBadges: [\"Unstructured\", \"LLM Text\"],\n    surfaceClassName:\n      \"from-violet-500/15 via-fuchsia-500/5 to-transparent dark:from-violet-400/30 dark:via-fuchsia-400/14 dark:to-violet-950/16\",\n    shineColor: [\n      \"rgb(139 92 246 / 0.45)\",\n      \"rgb(217 70 239 / 0.4)\",\n      \"rgb(168 85 247 / 0.45)\",\n    ],\n    learningRecipeId: \"pdf-grounded-qa\",\n  },\n  {\n    title: \"OCR Document Extraction\",\n    description:\n      \"Use image context from seed data to generate OCR-style extraction outputs.\",\n    icon: Album02Icon,\n    difficulty: \"Starter\",\n    learningBadges: [\"Vision\", \"LLM Text\", \"Image Context\"],\n    surfaceClassName:\n      \"from-lime-500/15 via-emerald-500/5 to-transparent dark:from-lime-400/30 dark:via-emerald-400/14 dark:to-lime-950/16\",\n    shineColor: [\n      \"rgb(132 204 22 / 
0.45)\",\n      \"rgb(16 185 129 / 0.4)\",\n      \"rgb(74 222 128 / 0.45)\",\n    ],\n    learningRecipeId: \"ocr-document-extraction\",\n  },\n  {\n    title: \"Text to Python\",\n    description:\n      \"Instruction-to-code pairs for training models that generate clean Python implementations.\",\n    icon: CodeIcon,\n    difficulty: \"Intermediate\",\n    learningBadges: [\"LLM Judge\", \"LLM Code\", \"Subcategory\", \"Category\"],\n    surfaceClassName:\n      \"from-amber-500/15 via-orange-500/5 to-transparent dark:from-amber-400/30 dark:via-orange-400/14 dark:to-amber-950/16\",\n    shineColor: [\n      \"rgb(245 158 11 / 0.45)\",\n      \"rgb(249 115 22 / 0.4)\",\n      \"rgb(251 146 60 / 0.45)\",\n    ],\n    learningRecipeId: \"text-to-python\",\n  },\n  {\n    title: \"Text to SQL\",\n    description:\n      \"Natural language to SQL pairs, including schema-aware query construction patterns.\",\n    icon: Database02Icon,\n    difficulty: \"Intermediate\",\n    learningBadges: [\"LLM Code\", \"Prompting\", \"Drop Columns\"],\n    surfaceClassName:\n      \"from-blue-500/15 via-indigo-500/5 to-transparent dark:from-blue-400/30 dark:via-indigo-400/14 dark:to-blue-950/16\",\n    shineColor: [\n      \"rgb(59 130 246 / 0.45)\",\n      \"rgb(99 102 241 / 0.4)\",\n      \"rgb(96 165 250 / 0.45)\",\n    ],\n    learningRecipeId: \"text-to-sql\",\n  },\n  {\n    title: \"Structured Outputs + Jinja Expressions\",\n    description:\n      \"Support ticket triage dataset with structured JSON outputs and Jinja if/else refs.\",\n    icon: FunctionIcon,\n    difficulty: \"Advanced\",\n    learningBadges: [\"Structured LLM\", \"Expression\", \"Jinja\"],\n    surfaceClassName:\n      \"from-cyan-500/15 via-sky-500/5 to-transparent dark:from-cyan-400/30 dark:via-sky-400/14 dark:to-cyan-950/16\",\n    shineColor: [\n      \"rgb(6 182 212 / 0.45)\",\n      \"rgb(56 189 248 / 0.4)\",\n      \"rgb(34 211 238 / 0.45)\",\n    ],\n    learningRecipeId: \"structured-outputs-jinja\",\n  },\n];\n\nconst LEARNING_RECIPE_BY_ID = new Map(\n  LEARNING_RECIPES.map((recipe) => [recipe.id, recipe]),\n);\n\nfunction formatRelativeTime(value: number): string {\n  const now = Date.now();\n  const diffMs = Math.max(0, now - value);\n  const minute = 60 * 1000;\n  const hour = 60 * minute;\n  const day = 24 * hour;\n  const week = 7 * day;\n\n  if (diffMs < minute) {\n    return \"just now\";\n  }\n  if (diffMs < hour) {\n    const minutes = Math.floor(diffMs / minute);\n    return `${minutes} minute${minutes === 1 ? \"\" : \"s\"} ago`;\n  }\n  if (diffMs < day) {\n    const hours = Math.floor(diffMs / hour);\n    return `${hours} hour${hours === 1 ? \"\" : \"s\"} ago`;\n  }\n  if (diffMs < week) {\n    const days = Math.floor(diffMs / day);\n    return `${days} day${days === 1 ? \"\" : \"s\"} ago`;\n  }\n  const weeks = Math.floor(diffMs / week);\n  return `${weeks} week${weeks === 1 ? \"\" : \"s\"} ago`;\n}\n\nfunction LearningRecipeCards({\n  onSelect,\n  loadingTemplateId,\n}: {\n  onSelect: (template: TemplateCard) => void;\n  loadingTemplateId: string | null;\n}): ReactElement {\n  return (\n    <div className=\"grid w-full gap-4 sm:grid-cols-2 xl:grid-cols-3\">\n      {TEMPLATE_CARDS.map((template) => {\n        const learningRecipe = template.learningRecipeId\n          ? 
LEARNING_RECIPE_BY_ID.get(template.learningRecipeId)\n          : undefined;\n        const isReady = Boolean(learningRecipe);\n        const isLoading =\n          template.learningRecipeId !== undefined &&\n          loadingTemplateId === template.learningRecipeId;\n        const isDisabled = !isReady || isLoading || Boolean(loadingTemplateId);\n        const visibleLearningBadges = template.learningBadges.slice(0, 4);\n        const extraLearningBadgeCount = Math.max(\n          0,\n          template.learningBadges.length - 4,\n        );\n        return (\n          <button\n            key={template.title}\n            type=\"button\"\n            disabled={isDisabled}\n            onClick={() => onSelect(template)}\n            className={`group shadow-border relative overflow-hidden rounded-2xl bg-gradient-to-br text-left transition-transform ${template.surfaceClassName} enabled:cursor-pointer enabled:hover:-translate-y-0.5 enabled:hover:shadow-md disabled:cursor-not-allowed disabled:opacity-70`}\n          >\n            <ShineBorder\n              borderWidth={1.2}\n              duration={13}\n              shineColor={template.shineColor}\n            />\n            <div className=\"relative flex h-full min-h-40 flex-col justify-between gap-3 p-4\">\n              <Badge\n                className=\"absolute right-3 top-3\"\n                variant={\n                  template.difficulty === \"Advanced\" ? \"secondary\" : \"outline\"\n                }\n              >\n                {template.difficulty}\n              </Badge>\n              <div className=\"inline-flex size-10 items-center justify-center rounded-xl border border-foreground/10 bg-background/80\">\n                <HugeiconsIcon\n                  icon={template.icon}\n                  className=\"size-5 text-foreground/90\"\n                />\n              </div>\n              <div className=\"space-y-1\">\n                <p className=\"line-clamp-2 text-sm font-semibold leading-tight text-foreground\">\n                  {template.title}\n                </p>\n                <p className=\"line-clamp-2 text-xs text-muted-foreground\">\n                  {template.description}\n                </p>\n              </div>\n              <div className=\"flex items-center gap-1 overflow-hidden whitespace-nowrap\">\n                {isLoading ? (\n                  <Badge variant=\"outline\">Loading...</Badge>\n                ) : (\n                  <>\n                    {visibleLearningBadges.map((badge) => (\n                      <Badge\n                        key={`${template.title}-${badge}`}\n                        variant=\"outline\"\n                        className=\"h-5 shrink-0 px-1.5 text-[10px]\"\n                      >\n                        {badge}\n                      </Badge>\n                    ))}\n                    {extraLearningBadgeCount > 0 ? (\n                      <Badge\n                        variant=\"outline\"\n                        className=\"h-5 shrink-0 px-1.5 text-[10px]\"\n                      >\n                        +{extraLearningBadgeCount}\n                      </Badge>\n                    ) : null}\n                    {isReady ? 
null : (\n                      <Badge\n                        variant=\"secondary\"\n                        className=\"h-5 shrink-0 px-1.5 text-[10px]\"\n                      >\n                        Soon\n                      </Badge>\n                    )}\n                  </>\n                )}\n              </div>\n            </div>\n          </button>\n        );\n      })}\n    </div>\n  );\n}\n\nexport function DataRecipesPage(): ReactElement {\n  const navigate = useNavigate();\n  const { recipes, ready } = useRecipes();\n  const [creatingRecipe, setCreatingRecipe] = useState(false);\n  const [learningDialogOpen, setLearningDialogOpen] = useState(false);\n  const [loadingTemplateId, setLoadingTemplateId] = useState<string | null>(\n    null,\n  );\n\n  useEffect(() => {\n    if (sessionStorage.getItem(OPEN_LEARNING_RECIPES_ON_ARRIVAL_KEY) !== \"1\") {\n      return;\n    }\n    sessionStorage.removeItem(OPEN_LEARNING_RECIPES_ON_ARRIVAL_KEY);\n    setLearningDialogOpen(true);\n  }, []);\n\n  async function openNewRecipe(): Promise<void> {\n    if (creatingRecipe || loadingTemplateId) {\n      return;\n    }\n    setCreatingRecipe(true);\n    try {\n      const recipe = await createRecipeDraft();\n      primeRecipeCache(recipe);\n      await navigate({\n        to: \"/data-recipes/$recipeId\",\n        params: { recipeId: recipe.id },\n      });\n    } finally {\n      setCreatingRecipe(false);\n    }\n  }\n\n  async function openLearningRecipe(template: TemplateCard): Promise<void> {\n    if (creatingRecipe || loadingTemplateId) {\n      return;\n    }\n    if (!template.learningRecipeId) {\n      toastError(\"Learning recipe not ready yet.\");\n      return;\n    }\n    const recipeTemplate = LEARNING_RECIPE_BY_ID.get(template.learningRecipeId);\n    if (!recipeTemplate) {\n      toastError(\"Learning recipe not found.\");\n      return;\n    }\n\n    setLoadingTemplateId(template.learningRecipeId);\n    try {\n      const payload = await recipeTemplate.loadPayload();\n      const recipe = await createRecipeFromLearningRecipe({\n        templateId: recipeTemplate.id,\n        templateTitle: recipeTemplate.title,\n        payload,\n      });\n      primeRecipeCache(recipe);\n      setLearningDialogOpen(false);\n      await navigate({\n        to: \"/data-recipes/$recipeId\",\n        params: { recipeId: recipe.id },\n      });\n    } catch (error) {\n      toastError(\n        \"Failed to start learning recipe.\",\n        error instanceof Error ? 
error.message : undefined,\n      );\n    } finally {\n      setLoadingTemplateId(null);\n    }\n  }\n\n  function openRecipe(recipe: (typeof recipes)[number]): void {\n    primeRecipeCache(recipe);\n    navigate({\n      to: \"/data-recipes/$recipeId\",\n      params: { recipeId: recipe.id },\n    }).catch(() => undefined);\n  }\n\n  async function handleDeleteRecipe(recipeId: string): Promise<void> {\n    await deleteRecipe(recipeId);\n  }\n\n  const isBusy = creatingRecipe || Boolean(loadingTemplateId);\n\n  return (\n    <div className=\"min-h-screen bg-background\">\n      <main className=\"mx-auto w-full max-w-7xl px-6 py-8\">\n        <div className=\"flex items-center justify-between gap-4\">\n          <div>\n            <h1 className=\"text-2xl font-semibold tracking-tight\">\n              Data Recipes\n            </h1>\n            <p className=\"mt-1 text-sm text-muted-foreground\">\n              Create and manage local recipe workflows.\n            </p>\n          </div>\n          <DropdownMenu>\n            <DropdownMenuTrigger asChild={true}>\n              <Button type=\"button\" disabled={isBusy}>\n                <HugeiconsIcon icon={PlusSignIcon} className=\"size-4\" />\n                New Recipe\n                <HugeiconsIcon icon={ArrowDown01Icon} className=\"size-4\" />\n              </Button>\n            </DropdownMenuTrigger>\n            <DropdownMenuContent align=\"end\">\n              <DropdownMenuItem\n                onSelect={() => {\n                  openNewRecipe().catch(() => undefined);\n                }}\n              >\n                <HugeiconsIcon icon={PlusSignIcon} className=\"size-4\" />\n                Start Empty\n              </DropdownMenuItem>\n              <DropdownMenuItem\n                onSelect={() => {\n                  setLearningDialogOpen(true);\n                }}\n              >\n                <HugeiconsIcon icon={CookBookIcon} className=\"size-4\" />\n                Start from Learning Recipe\n              </DropdownMenuItem>\n            </DropdownMenuContent>\n          </DropdownMenu>\n        </div>\n\n        {!ready ? (\n          <div className=\"mt-8 rounded-2xl border border-border/70 bg-card px-6 py-10 text-center\">\n            <p className=\"text-sm font-medium text-foreground\">\n              Loading recipes\n            </p>\n            <p className=\"mt-1 text-xs text-muted-foreground\">\n              Fetching your saved recipes and learning templates.\n            </p>\n          </div>\n        ) : recipes.length === 0 ? 
(\n          <Empty className=\"mt-8 border border-dashed border-border/70\">\n            <EmptyHeader>\n              <EmptyMedia variant=\"icon\">\n                <HugeiconsIcon icon={CookBookIcon} className=\"size-5\" />\n              </EmptyMedia>\n              <EmptyTitle>No recipes yet</EmptyTitle>\n              <EmptyDescription>\n                Browse Learning Recipes below to understand how recipe workflows\n                work.\n              </EmptyDescription>\n            </EmptyHeader>\n            <EmptyContent className=\"max-w-6xl items-stretch\">\n              {/*<Button*/}\n              {/*  type=\"button\"*/}\n              {/*  variant=\"secondary\"*/}\n              {/*  className=\"mx-auto\"*/}\n              {/*  onClick={() => setLearningDialogOpen(true)}*/}\n              {/*  disabled={isBusy}*/}\n              {/*>*/}\n              {/*  <HugeiconsIcon icon={CookBookIcon} className=\"size-4\" />*/}\n              {/*  Start Tutorial*/}\n              {/*</Button>*/}\n              <LearningRecipeCards\n                onSelect={(template) => {\n                  openLearningRecipe(template).catch(() => undefined);\n                }}\n                loadingTemplateId={loadingTemplateId}\n              />\n            </EmptyContent>\n          </Empty>\n        ) : (\n          <div className=\"mt-8 space-y-2\">\n            {recipes.map((recipe) => (\n              <div\n                key={recipe.id}\n                className=\"flex items-center gap-3 rounded-xl border bg-card px-4 py-3\"\n              >\n                <button\n                  type=\"button\"\n                  className=\"flex min-w-0 flex-1 items-center gap-3 text-left\"\n                  onClick={() => openRecipe(recipe)}\n                >\n                  <div className=\"flex size-9 shrink-0 items-center justify-center rounded-lg border border-border/70 bg-muted/20\">\n                    <HugeiconsIcon\n                      icon={CookBookIcon}\n                      className=\"size-4 text-muted-foreground\"\n                    />\n                  </div>\n                  <div className=\"min-w-0\">\n                    <div className=\"flex items-center gap-2\">\n                      <p className=\"truncate text-sm font-medium\">\n                        {recipe.name}\n                      </p>\n                      {recipe.learningRecipeId ? 
(\n                        <Badge variant=\"outline\">Learning Recipe</Badge>\n                      ) : null}\n                    </div>\n                    <p className=\"text-xs text-muted-foreground\">\n                      Last updated {formatRelativeTime(recipe.updatedAt)} |\n                      Created {formatRelativeTime(recipe.createdAt)}\n                    </p>\n                  </div>\n                </button>\n                <Button\n                  type=\"button\"\n                  variant=\"ghost\"\n                  size=\"icon\"\n                  className=\"size-8\"\n                  onClick={() => {\n                    handleDeleteRecipe(recipe.id).catch(() => undefined);\n                  }}\n                  aria-label={`Delete ${recipe.name}`}\n                >\n                  <HugeiconsIcon icon={Delete02Icon} className=\"size-4\" />\n                </Button>\n              </div>\n            ))}\n          </div>\n        )}\n      </main>\n\n      <Dialog open={learningDialogOpen} onOpenChange={setLearningDialogOpen}>\n        <DialogContent\n          className=\"sm:max-w-5xl\"\n          overlayClassName=\"bg-background/45 supports-backdrop-filter:backdrop-blur-[1px]\"\n        >\n          <DialogHeader>\n            <DialogTitle>Learning Recipes</DialogTitle>\n            <DialogDescription>\n              Start from a prebuilt recipe to learn patterns, then edit and run.\n            </DialogDescription>\n          </DialogHeader>\n          <LearningRecipeCards\n            onSelect={(template) => {\n              openLearningRecipe(template).catch(() => undefined);\n            }}\n            loadingTemplateId={loadingTemplateId}\n          />\n        </DialogContent>\n      </Dialog>\n    </div>\n  );\n}\n"
  },
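`DataRecipesPage` opens the Learning Recipes dialog on arrival when the one-shot sessionStorage flag `data-recipes:open-learning-recipes` equals `"1"`, then clears it. Below is a hypothetical sketch of how another screen could arm that flag before navigating here; the caller and its `navigate` signature are assumptions.

```ts
// Hypothetical caller: arm the one-shot flag that DataRecipesPage checks
// (and removes) on mount, then go to the recipes list so the Learning
// Recipes dialog opens automatically.
function goToLearningRecipes(
  navigate: (opts: { to: string }) => Promise<void>,
): Promise<void> {
  sessionStorage.setItem("data-recipes:open-learning-recipes", "1");
  return navigate({ to: "/data-recipes" });
}
```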
  {
    "path": "studio/frontend/src/features/data-recipes/pages/edit-recipe-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { RecipeStudioPage, type RecipePayload } from \"@/features/recipe-studio\";\nimport { useNavigate } from \"@tanstack/react-router\";\nimport type { ReactElement } from \"react\";\nimport { useCallback, useEffect, useState } from \"react\";\nimport { getCachedRecipe, getRecipe, primeRecipeCache, saveRecipe } from \"../data/recipes-db\";\nimport type { RecipeRecord } from \"../types\";\n\ntype EditRecipePageProps = {\n  recipeId: string;\n};\n\ntype LoadState =\n  | { status: \"loading\" }\n  | { status: \"missing\" }\n  | { status: \"ready\"; record: RecipeRecord };\n\nfunction RecipeLoadState({\n  title,\n  description,\n  onBack,\n}: {\n  title: string;\n  description: string;\n  onBack: () => void;\n}): ReactElement {\n  return (\n    <div className=\"min-h-screen bg-background\">\n      <main className=\"mx-auto flex min-h-[70vh] w-full max-w-4xl items-center justify-center px-6 py-8\">\n        <div className=\"w-full rounded-2xl border bg-card p-8 text-center\">\n          <h1 className=\"text-lg font-semibold\">{title}</h1>\n          <p className=\"mt-2 text-sm text-muted-foreground\">{description}</p>\n          <Button type=\"button\" variant=\"outline\" className=\"mt-5\" onClick={onBack}>\n            Back to Recipes\n          </Button>\n        </div>\n      </main>\n    </div>\n  );\n}\n\nexport function EditRecipePage({ recipeId }: EditRecipePageProps): ReactElement {\n  const navigate = useNavigate();\n  const [loadState, setLoadState] = useState<LoadState>(() => {\n    const cachedRecipe = getCachedRecipe(recipeId);\n    if (cachedRecipe) {\n      return { status: \"ready\", record: cachedRecipe };\n    }\n    return { status: \"loading\" };\n  });\n\n  useEffect(() => {\n    let active = true;\n    const cachedRecipe = getCachedRecipe(recipeId);\n    if (cachedRecipe) {\n      setLoadState({ status: \"ready\", record: cachedRecipe });\n    } else {\n      setLoadState({ status: \"loading\" });\n    }\n\n    void getRecipe(recipeId).then((record) => {\n      if (!active) {\n        return;\n      }\n      if (!record) {\n        setLoadState({ status: \"missing\" });\n        return;\n      }\n      primeRecipeCache(record);\n      setLoadState({ status: \"ready\", record });\n    });\n    return () => {\n      active = false;\n    };\n  }, [recipeId]);\n\n  const handlePersist = useCallback(\n    async (input: { id: string | null; name: string; payload: RecipePayload }) => {\n      const record = await saveRecipe({\n        id: input.id ?? 
recipeId,\n        name: input.name,\n        payload: input.payload,\n      });\n      primeRecipeCache(record);\n      return { id: record.id, updatedAt: record.updatedAt };\n    },\n    [recipeId],\n  );\n\n  if (loadState.status === \"loading\") {\n    return (\n      <RecipeLoadState\n        title=\"Loading recipe...\"\n        description=\"Please wait while we load your recipe.\"\n        onBack={() => void navigate({ to: \"/data-recipes\" })}\n      />\n    );\n  }\n\n  if (loadState.status === \"missing\") {\n    return (\n      <RecipeLoadState\n        title=\"Recipe not found\"\n        description=\"This recipe may have been deleted.\"\n        onBack={() => void navigate({ to: \"/data-recipes\" })}\n      />\n    );\n  }\n\n  return (\n    <RecipeStudioPage\n      key={loadState.record.id}\n      recipeId={loadState.record.id}\n      initialRecipeName={loadState.record.name}\n      initialPayload={loadState.record.payload}\n      initialSavedAt={loadState.record.updatedAt}\n      onPersistRecipe={handlePersist}\n    />\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/data-recipes/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipePayload } from \"@/features/recipe-studio\";\n\nexport type RecipeRecord = {\n  id: string;\n  name: string;\n  payload: RecipePayload;\n  createdAt: number;\n  updatedAt: number;\n  learningRecipeId?: string;\n  learningRecipeTitle?: string;\n};\n\nexport type SaveRecipeInput = {\n  id?: string | null;\n  name: string;\n  payload: RecipePayload;\n  learningRecipeId?: string;\n  learningRecipeTitle?: string;\n};\n"
  },
  {
    "path": "studio/frontend/src/features/export/anim.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const collapseAnim = {\n  initial: { height: 0, opacity: 0 },\n  animate: { height: \"auto\" as const, opacity: 1 },\n  exit: { height: 0, opacity: 0 },\n  transition: { duration: 0.3, ease: [0.25, 0.1, 0.25, 1] as const },\n};\n"
  },
  {
    "path": "studio/frontend/src/features/export/api/export-api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\n\nasync function readError(response: Response): Promise<string> {\n  try {\n    const payload = (await response.json()) as { detail?: string; message?: string };\n    return payload.detail || payload.message || `Request failed (${response.status})`;\n  } catch {\n    return `Request failed (${response.status})`;\n  }\n}\n\nasync function parseJson<T>(response: Response): Promise<T> {\n  if (!response.ok) {\n    throw new Error(await readError(response));\n  }\n  return (await response.json()) as T;\n}\n\nexport interface CheckpointInfo {\n  display_name: string;\n  path: string;\n  loss?: number | null;\n}\n\nexport interface ModelCheckpoints {\n  name: string;\n  checkpoints: CheckpointInfo[];\n  base_model?: string | null;\n  peft_type?: string | null;\n  lora_rank?: number | null;\n}\n\nexport interface CheckpointListResponse {\n  outputs_dir: string;\n  models: ModelCheckpoints[];\n}\n\nexport interface ExportOperationResponse {\n  success: boolean;\n  message: string;\n  details?: Record<string, unknown> | null;\n}\n\nexport async function fetchCheckpoints(): Promise<CheckpointListResponse> {\n  const response = await authFetch(\"/api/models/checkpoints\");\n  return parseJson<CheckpointListResponse>(response);\n}\n\nexport async function loadCheckpoint(params: {\n  checkpoint_path: string;\n  max_seq_length?: number;\n  load_in_4bit?: boolean;\n  /** Allow loading models with custom code. Only enable for checkpoints you trust. */\n  trust_remote_code?: boolean;\n}): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/load-checkpoint\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(params),\n  });\n  return parseJson<ExportOperationResponse>(response);\n}\n\nexport async function exportMerged(params: {\n  save_directory: string;\n  format_type?: string;\n  push_to_hub?: boolean;\n  repo_id?: string | null;\n  hf_token?: string | null;\n  private?: boolean;\n}): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/export/merged\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(params),\n  });\n  return parseJson<ExportOperationResponse>(response);\n}\n\nexport async function exportBase(params: {\n  save_directory: string;\n  push_to_hub?: boolean;\n  repo_id?: string | null;\n  hf_token?: string | null;\n  private?: boolean;\n  base_model_id?: string | null;\n}): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/export/base\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(params),\n  });\n  return parseJson<ExportOperationResponse>(response);\n}\n\nexport async function exportGGUF(params: {\n  save_directory: string;\n  quantization_method: string;\n  push_to_hub?: boolean;\n  repo_id?: string | null;\n  hf_token?: string | null;\n}): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/export/gguf\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(params),\n  });\n  return parseJson<ExportOperationResponse>(response);\n}\n\nexport async function exportLoRA(params: {\n  save_directory: string;\n  
push_to_hub?: boolean;\n  repo_id?: string | null;\n  hf_token?: string | null;\n  private?: boolean;\n}): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/export/lora\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(params),\n  });\n  return parseJson<ExportOperationResponse>(response);\n}\n\nexport async function cleanupExport(): Promise<ExportOperationResponse> {\n  const response = await authFetch(\"/api/export/cleanup\", { method: \"POST\" });\n  return parseJson<ExportOperationResponse>(response);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/export/components/export-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogDescription,\n  DialogFooter,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  InputGroup,\n  InputGroupAddon,\n  InputGroupInput,\n} from \"@/components/ui/input-group\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { AlertCircleIcon, ArrowRight01Icon, CheckmarkCircle02Icon, Key01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { AnimatePresence, motion } from \"motion/react\";\nimport { collapseAnim } from \"../anim\";\nimport { EXPORT_METHODS, type ExportMethod } from \"../constants\";\n\ntype Destination = \"local\" | \"hub\";\n\ninterface ExportDialogProps {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  checkpoint: string | null;\n  exportMethod: ExportMethod | null;\n  quantLevels: string[];\n  estimatedSize: string;\n  baseModelName: string;\n  isAdapter: boolean;\n  destination: Destination;\n  onDestinationChange: (v: Destination) => void;\n  hfUsername: string;\n  onHfUsernameChange: (v: string) => void;\n  modelName: string;\n  onModelNameChange: (v: string) => void;\n  hfToken: string;\n  onHfTokenChange: (v: string) => void;\n  privateRepo: boolean;\n  onPrivateRepoChange: (v: boolean) => void;\n  onExport: () => void;\n  exporting: boolean;\n  exportError: string | null;\n  exportSuccess: boolean;\n}\n\nexport function ExportDialog({\n  open,\n  onOpenChange,\n  checkpoint,\n  exportMethod,\n  quantLevels,\n  estimatedSize: _estimatedSize,\n  baseModelName,\n  isAdapter,\n  destination,\n  onDestinationChange,\n  hfUsername,\n  onHfUsernameChange,\n  modelName,\n  onModelNameChange,\n  hfToken,\n  onHfTokenChange,\n  privateRepo,\n  onPrivateRepoChange,\n  onExport,\n  exporting,\n  exportError,\n  exportSuccess,\n}: ExportDialogProps) {\n  return (\n    <Dialog\n      open={open}\n      onOpenChange={(v) => {\n        if (exporting) return;\n        onOpenChange(v);\n      }}\n    >\n      <DialogContent className=\"sm:max-w-lg\" onInteractOutside={(e) => { if (exporting) e.preventDefault(); }}>\n        {exportSuccess ? (\n          <>\n            <div className=\"flex flex-col items-center gap-3 py-6\">\n              <div className=\"flex size-12 items-center justify-center rounded-full bg-emerald-500/10\">\n                <HugeiconsIcon icon={CheckmarkCircle02Icon} className=\"size-6 text-emerald-500\" />\n              </div>\n              <div className=\"text-center\">\n                <h3 className=\"text-lg font-semibold\">Export Complete</h3>\n                <p className=\"mt-1 text-sm text-muted-foreground\">\n                  {destination === \"hub\"\n                    ? 
\"Model successfully pushed to Hugging Face Hub.\"\n                    : \"Model saved locally.\"}\n                </p>\n              </div>\n            </div>\n            <DialogFooter>\n              <Button onClick={() => onOpenChange(false)}>Done</Button>\n            </DialogFooter>\n          </>\n        ) : (\n          <>\n            <DialogHeader>\n              <DialogTitle>Export Model</DialogTitle>\n              <DialogDescription>\n                Choose where to save your exported model.\n              </DialogDescription>\n            </DialogHeader>\n\n            <div className=\"flex gap-2\">\n              <Button\n                variant={destination === \"local\" ? \"dark\" : \"outline\"}\n                onClick={() => onDestinationChange(\"local\")}\n                disabled={exporting}\n                className=\"flex-1\"\n              >\n                Save Locally\n              </Button>\n              <Button\n                variant={destination === \"hub\" ? \"dark\" : \"outline\"}\n                onClick={() => onDestinationChange(\"hub\")}\n                disabled={exporting}\n                className=\"flex-1\"\n              >\n                Push to Hub\n              </Button>\n            </div>\n\n            <AnimatePresence>\n              {destination === \"hub\" && (\n                <motion.div {...collapseAnim} className=\"overflow-hidden\">\n                  <div className=\"flex flex-col gap-4 px-0.5\">\n                    <div className=\"grid grid-cols-2 gap-3\">\n                      <div className=\"flex flex-col gap-1.5\">\n                        <label className=\"text-xs font-medium text-muted-foreground\">\n                          Username / Org\n                        </label>\n                        <Input\n                          placeholder=\"your-username\"\n                          value={hfUsername}\n                          onChange={(e) => onHfUsernameChange(e.target.value)}\n                          disabled={exporting}\n                        />\n                      </div>\n                      <div className=\"flex flex-col gap-1.5\">\n                        <label className=\"text-xs font-medium text-muted-foreground\">\n                          Model Name\n                        </label>\n                        <Input\n                          placeholder=\"my-model-gguf\"\n                          value={modelName}\n                          onChange={(e) => onModelNameChange(e.target.value)}\n                          disabled={exporting}\n                        />\n                      </div>\n                    </div>\n\n                    <div className=\"flex flex-col gap-1.5\">\n                      <div className=\"flex items-center justify-between\">\n                        <label className=\"text-xs font-medium text-muted-foreground\">\n                          HF Write Token\n                        </label>\n                        <a\n                          href=\"https://huggingface.co/settings/tokens\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"flex items-center gap-1 text-[11px] text-emerald-600 hover:text-emerald-700 transition-colors\"\n                        >\n                          Get token\n                          <HugeiconsIcon\n                            icon={ArrowRight01Icon}\n                            className=\"size-3\"\n                       
   />\n                        </a>\n                      </div>\n                      <InputGroup>\n                        <InputGroupAddon>\n                          <HugeiconsIcon icon={Key01Icon} className=\"size-4\" />\n                        </InputGroupAddon>\n                        <InputGroupInput\n                          type=\"password\"\n                          autoComplete=\"new-password\"\n                          name=\"hf-token\"\n                          placeholder=\"hf_...\"\n                          value={hfToken}\n                          onChange={(e) => onHfTokenChange(e.target.value)}\n                          disabled={exporting}\n                        />\n                      </InputGroup>\n                      <p className=\"text-[11px] text-muted-foreground/70\">\n                        Leave empty if already logged in via CLI.\n                      </p>\n                    </div>\n\n                    <div className=\"flex items-center gap-3\">\n                      <Switch\n                        id=\"private-repo\"\n                        size=\"sm\"\n                        checked={privateRepo}\n                        onCheckedChange={onPrivateRepoChange}\n                        disabled={exporting}\n                      />\n                      <label\n                        htmlFor=\"private-repo\"\n                        className=\"text-xs font-medium cursor-pointer\"\n                      >\n                        Private Repository\n                      </label>\n                    </div>\n                  </div>\n                </motion.div>\n              )}\n            </AnimatePresence>\n\n            {/* Error banner */}\n            {exportError && (\n              <div className=\"flex items-start gap-2 rounded-lg bg-destructive/10 p-3 text-sm text-destructive\">\n                <HugeiconsIcon icon={AlertCircleIcon} className=\"size-4 mt-0.5 shrink-0\" />\n                <span>{exportError}</span>\n              </div>\n            )}\n\n            {/* Summary */}\n            <div className=\"rounded-xl bg-muted/50 p-3 text-xs text-muted-foreground flex flex-col gap-1\">\n              <div className=\"flex justify-between\">\n                <span>Base Model</span>\n                <span className=\"font-medium text-foreground\">{baseModelName}</span>\n              </div>\n              <div className=\"flex justify-between\">\n                <span>{isAdapter ? \"Checkpoint\" : \"Model\"}</span>\n                <span className=\"font-medium text-foreground\">{checkpoint}</span>\n              </div>\n              <div className=\"flex justify-between\">\n                <span>Export Method</span>\n                <span className=\"font-medium text-foreground\">\n                  {EXPORT_METHODS.find((m) => m.value === exportMethod)?.title}\n                </span>\n              </div>\n              {exportMethod === \"gguf\" && quantLevels.length > 0 && (\n                <div className=\"flex justify-between\">\n                  <span>Quantizations</span>\n                  <span className=\"font-medium text-foreground\">\n                    {quantLevels.join(\", \")}\n                  </span>\n                </div>\n              )}\n              {/* TODO: unhide once estimated size comes from the backend API */}\n              {/* <div className=\"flex justify-between\">\n            <span>Est. 
size</span>\n            <span className=\"font-medium text-foreground\">{estimatedSize}</span>\n          </div> */}\n            </div>\n\n            <DialogFooter>\n              <Button\n                variant=\"outline\"\n                onClick={() => onOpenChange(false)}\n                disabled={exporting}\n              >\n                Cancel\n              </Button>\n              <Button onClick={onExport} disabled={exporting}>\n                {exporting ? (\n                  <span className=\"flex items-center gap-2\">\n                    <Spinner className=\"size-4\" />\n                    Exporting…\n                  </span>\n                ) : (\n                  \"Start Export\"\n                )}\n              </Button>\n            </DialogFooter>\n          </>\n        )}\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/export/components/method-picker.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  CheckmarkCircle01Icon,\n  InformationCircleIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { EXPORT_METHODS, type ExportMethod } from \"../constants\";\n\ninterface MethodPickerProps {\n  value: ExportMethod | null;\n  onChange: (v: ExportMethod) => void;\n}\n\nexport function MethodPicker({ value, onChange }: MethodPickerProps) {\n  return (\n    <div data-tour=\"export-method\" className=\"flex flex-col gap-3\">\n      <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n        Export Method\n        <Tooltip>\n          <TooltipTrigger asChild={true}>\n            <button\n              type=\"button\"\n              className=\"text-foreground/70 hover:text-foreground\"\n            >\n              <HugeiconsIcon icon={InformationCircleIcon} className=\"size-3\" />\n            </button>\n          </TooltipTrigger>\n          <TooltipContent>\n            How your model is packaged for deployment.{\" \"}\n            <a\n              href=\"https://unsloth.ai/docs/basics/inference-and-deployment\"\n              target=\"_blank\"\n              rel=\"noopener noreferrer\"\n              className=\"text-primary underline\"\n            >\n              Read more\n            </a>\n          </TooltipContent>\n        </Tooltip>\n      </span>\n      <div className=\"grid grid-cols-3 gap-3\">\n        {EXPORT_METHODS.map((m) => {\n          const selected = value === m.value;\n          return (\n            <button\n              key={m.value}\n              type=\"button\"\n              onClick={() => onChange(m.value)}\n              className={cn(\n                \"flex items-start gap-3 rounded-xl p-4 text-left ring-1 transition-all\",\n                selected\n                  ? \"ring-2 ring-primary bg-primary/5\"\n                  : \"ring-border hover:-translate-y-0.5 hover:shadow-sm\",\n              )}\n            >\n              <div\n                className={cn(\n                  \"mt-0.5 flex size-5 shrink-0 items-center justify-center rounded-full border-2 transition-colors\",\n                  selected\n                    ? 
\"border-primary bg-primary\"\n                    : \"border-muted-foreground/30\",\n                )}\n              >\n                {selected && (\n                  <HugeiconsIcon\n                    icon={CheckmarkCircle01Icon}\n                    className=\"size-3 text-primary-foreground\"\n                  />\n                )}\n              </div>\n              <div className=\"flex flex-col gap-1\">\n                <div className=\"flex items-center gap-2\">\n                  <span className=\"text-sm font-medium\">{m.title}</span>\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"shrink-0 text-foreground/50 hover:text-foreground cursor-help\"\n                        onClick={(e) => e.stopPropagation()}\n                        aria-label={`${m.title} info`}\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent className=\"max-w-xs\">\n                      {m.tooltip}{\" \"}\n                      <a\n                        href={\n                          m.value === \"gguf\"\n                            ? \"https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf\"\n                            : \"https://unsloth.ai/docs/basics/inference-and-deployment\"\n                        }\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </TooltipContent>\n                  </Tooltip>\n                  {m.badge && (\n                    <Badge\n                      variant=\"secondary\"\n                      className=\"text-[10px] px-1.5 py-0\"\n                    >\n                      {m.badge}\n                    </Badge>\n                  )}\n                </div>\n                <span className=\"text-xs text-muted-foreground\">\n                  {m.description}\n                </span>\n              </div>\n            </button>\n          );\n        })}\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/export/components/quant-picker.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  CheckmarkCircle01Icon,\n  InformationCircleIcon,\n  LayersIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { QUANT_OPTIONS } from \"../constants\";\n\ninterface QuantPickerProps {\n  value: string[];\n  onChange: (v: string[]) => void;\n}\n\nexport function QuantPicker({ value, onChange }: QuantPickerProps) {\n  const toggle = (qv: string) => {\n    onChange(\n      value.includes(qv) ? value.filter((q) => q !== qv) : [...value, qv],\n    );\n  };\n\n  return (\n    <div className=\"flex flex-col gap-3\">\n      <div className=\"flex items-center gap-2\">\n        <HugeiconsIcon\n          icon={LayersIcon}\n          className=\"size-4 text-muted-foreground\"\n        />\n        <span className=\"text-xs font-medium text-muted-foreground\">\n          Quantization Levels\n        </span>\n        <Tooltip>\n          <TooltipTrigger asChild={true}>\n            <button\n              type=\"button\"\n              className=\"text-foreground/70 hover:text-foreground\"\n            >\n              <HugeiconsIcon icon={InformationCircleIcon} className=\"size-3\" />\n            </button>\n          </TooltipTrigger>\n          <TooltipContent className=\"max-w-xs\">\n            Lower quantization (Q2, Q3) = smaller files but reduced quality.\n            Q4–Q5 is a good balance.{\" \"}\n            <a\n              href=\"https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf\"\n              target=\"_blank\"\n              rel=\"noopener noreferrer\"\n              className=\"text-primary underline\"\n            >\n              Read more\n            </a>\n          </TooltipContent>\n        </Tooltip>\n        <span className=\"text-[11px] text-muted-foreground/70\">\n          — select one or more\n        </span>\n      </div>\n      <div className=\"flex flex-wrap gap-2 py-1 pl-1\">\n        {QUANT_OPTIONS.map((q) => {\n          const active = value.includes(q.value);\n          return (\n            <button\n              key={q.value}\n              type=\"button\"\n              onClick={() => toggle(q.value)}\n              className={cn(\n                \"inline-flex items-center gap-1.5 rounded-full px-3 py-1.5 text-xs font-medium ring-1 transition-all\",\n                active\n                  ? 
\"ring-primary bg-primary/10 text-foreground\"\n                  : \"ring-border text-muted-foreground hover:text-foreground hover:ring-foreground/20\",\n              )}\n            >\n              {active && (\n                <HugeiconsIcon\n                  icon={CheckmarkCircle01Icon}\n                  className=\"size-3 text-primary\"\n                />\n              )}\n              {q.label}\n              <span className=\"text-[10px] opacity-60\">{q.size}</span>\n              {q.recommended && !active && (\n                <span className=\"rounded-full bg-emerald-100 px-1.5 py-0 text-[9px] font-semibold text-emerald-700 dark:bg-emerald-900 dark:text-emerald-300\">\n                  rec\n                </span>\n              )}\n            </button>\n          );\n        })}\n      </div>\n      {value.length > 0 && (\n        <div className=\"flex items-center gap-3\">\n          <span className=\"text-[11px] text-muted-foreground\">\n            {value.length} selected\n          </span>\n          <button\n            type=\"button\"\n            onClick={() => onChange([])}\n            className=\"text-[11px] text-muted-foreground/70 hover:text-foreground transition-colors\"\n          >\n            Clear all\n          </button>\n        </div>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/export/constants.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TrainingMethod } from \"@/types/training\";\n\nexport type ExportMethod = \"merged\" | \"lora\" | \"gguf\";\n\nexport const EXPORT_METHODS: {\n  value: ExportMethod;\n  title: string;\n  description: string;\n  tooltip: string;\n  badge?: string;\n}[] = [\n  {\n    value: \"merged\",\n    title: \"Merged Model\",\n    description: \"Full 16-bit model ready for inference.\",\n    tooltip:\n      \"Merges adapter weights into the base model. Best for direct deployment with vLLM or TGI.\",\n  },\n  {\n    value: \"lora\",\n    title: \"LoRA Only\",\n    description: \"Lightweight adapter files (~100 MB). Needs base model.\",\n    tooltip:\n      \"Exports only the trained adapter. Pair with the base model at inference time to save storage.\",\n  },\n  {\n    value: \"gguf\",\n    title: \"GGUF / Llama.cpp\",\n    description: \"Quantized formats for local AI runners.\",\n    tooltip:\n      \"Converts to GGUF for llama.cpp, Ollama, and other local runners. Pick a quantization level below.\",\n  },\n];\n\nexport const QUANT_OPTIONS = [\n  { value: \"q3_k_m\", label: \"Q3_K_M\", size: \"~3.5 GB\" },\n  { value: \"q4_0\", label: \"Q4_0\", size: \"~4.1 GB\" },\n  { value: \"q4_k_m\", label: \"Q4_K_M\", size: \"~4.8 GB\", recommended: true },\n  { value: \"q5_0\", label: \"Q5_0\", size: \"~5.0 GB\" },\n  { value: \"q5_k_m\", label: \"Q5_K_M\", size: \"~5.6 GB\" },\n  { value: \"q8_0\", label: \"Q8_0\", size: \"~8.2 GB\" },\n  { value: \"f16\", label: \"F16\", size: \"~14.2 GB\" },\n  { value: \"f32\", label: \"F32\", size: \"~28.4 GB\" },\n];\n\nexport function getEstimatedSize(\n  method: ExportMethod | null,\n  quantLevels: string[],\n) {\n  const sizeOf = (v: string) =>\n    QUANT_OPTIONS.find((q) => q.value === v)?.size ?? \"—\";\n  if (method === \"gguf\" && quantLevels.length > 0) {\n    if (quantLevels.length === 1) {\n      return sizeOf(quantLevels[0]);\n    }\n    const total = quantLevels\n      .map((q) => Number.parseFloat(sizeOf(q).replace(/[^0-9.]/g, \"\")))\n      .reduce((a, b) => a + b, 0);\n    return `~${total.toFixed(1)} GB (${quantLevels.length} files)`;\n  }\n  if (method === \"merged\") {\n    return \"~14.2 GB\";\n  }\n  if (method === \"lora\") {\n    return \"~100 MB\";\n  }\n  return \"—\";\n}\n\nexport const METHOD_LABELS: Record<TrainingMethod, string> = {\n  qlora: \"QLoRA\",\n  lora: \"LoRA\",\n  full: \"Full Fine-tune\",\n};\n\nexport const GUIDE_STEPS = [\n  \"Select a training checkpoint to export from\",\n  \"Choose an export method based on your use case\",\n  \"Pick quantization levels if using GGUF\",\n  \"Click Export and choose your destination\",\n  \"Test your model and compare outputs in Chat\",\n];\n"
  },
  {
    "path": "studio/frontend/src/features/export/export-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Separator } from \"@/components/ui/separator\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport { AlertCircleIcon, InformationCircleIcon, PackageIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { AnimatePresence, motion } from \"motion/react\";\nimport { useCallback, useEffect, useMemo, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { collapseAnim } from \"./anim\";\nimport type { ModelCheckpoints } from \"./api/export-api\";\nimport {\n  cleanupExport,\n  exportBase,\n  exportGGUF,\n  exportLoRA,\n  exportMerged,\n  fetchCheckpoints,\n  loadCheckpoint,\n} from \"./api/export-api\";\nimport { ExportDialog } from \"./components/export-dialog\";\nimport { MethodPicker } from \"./components/method-picker\";\nimport { QuantPicker } from \"./components/quant-picker\";\nimport {\n  type ExportMethod,\n  GUIDE_STEPS,\n  getEstimatedSize,\n} from \"./constants\";\nimport { GuidedTour, useGuidedTourController } from \"@/features/tour\";\nimport { exportTourSteps } from \"./tour\";\n\nexport function ExportPage() {\n  const { hfToken, setHfToken } = useTrainingConfigStore(\n    useShallow((s) => ({\n      hfToken: s.hfToken,\n      setHfToken: s.setHfToken,\n    })),\n  );\n\n  // ---- API-driven checkpoint state ----\n  const [models, setModels] = useState<ModelCheckpoints[]>([]);\n  const [loadingCheckpoints, setLoadingCheckpoints] = useState(true);\n  const [checkpointError, setCheckpointError] = useState<string | null>(null);\n\n  const [selectedModelIdx, setSelectedModelIdx] = useState<string | null>(null);\n  const [checkpoint, setCheckpoint] = useState<string | null>(null);\n\n  const [exportMethod, setExportMethod] = useState<ExportMethod | null>(null);\n  const [quantLevels, setQuantLevels] = useState<string[]>([]);\n  const [dialogOpen, setDialogOpen] = useState(false);\n\n  const [destination, setDestination] = useState<\"local\" | \"hub\">(\"local\");\n  const [hfUsername, setHfUsername] = useState(\"\");\n  const [modelName, setModelName] = useState(\"\");\n  const [privateRepo, setPrivateRepo] = useState(false);\n\n  const [exporting, setExporting] = useState(false);\n  const [exportError, setExportError] = useState<string | null>(null);\n  const [exportSuccess, setExportSuccess] = useState(false);\n\n  const tour = useGuidedTourController({\n    id: \"export\",\n    steps: exportTourSteps,\n  });\n\n  // ---- Fetch checkpoints on mount ----\n  useEffect(() => {\n    let cancelled = false;\n    setLoadingCheckpoints(true);\n    setCheckpointError(null);\n    fetchCheckpoints()\n      .then((data) => {\n        if (!cancelled) {\n          setModels(data.models);\n        }\n      })\n      .catch((err) => {\n        if (!cancelled) {\n          setCheckpointError(\n            err instanceof Error ? 
err.message : \"Failed to load checkpoints\",\n          );\n        }\n      })\n      .finally(() => {\n        if (!cancelled) setLoadingCheckpoints(false);\n      });\n    return () => {\n      cancelled = true;\n    };\n  }, []);\n\n  // ---- Derived state ----\n  const selectedModelData = useMemo(\n    () =>\n      selectedModelIdx != null\n        ? models.find((m) => m.name === selectedModelIdx) ?? null\n        : null,\n    [models, selectedModelIdx],\n  );\n\n  const checkpointsForModel = useMemo(\n    () => selectedModelData?.checkpoints ?? [],\n    [selectedModelData],\n  );\n\n  // Derive training info from selected model's API metadata\n  const baseModelName = selectedModelData?.base_model ?? \"—\";\n  const isAdapter = !!selectedModelData?.peft_type;\n  const loraRank = selectedModelData?.lora_rank ?? null;\n  const trainingMethodLabel = selectedModelData?.peft_type\n    ? \"LoRA / QLoRA\"\n    : \"Full Fine-tune\";\n\n  // Reset checkpoint when the selected model changes\n  useEffect(() => {\n    setCheckpoint(null);\n  }, [selectedModelIdx]);\n\n  const handleMethodChange = (method: ExportMethod) => {\n    setExportMethod(method);\n    if (method !== \"gguf\") {\n      setQuantLevels([]);\n    }\n  };\n\n  const estimatedSize = getEstimatedSize(exportMethod, quantLevels);\n  const canExport =\n    checkpoint &&\n    exportMethod &&\n    (exportMethod !== \"gguf\" || quantLevels.length > 0);\n\n  // ---- Export handler ----\n  const handleExport = useCallback(async () => {\n    if (!checkpoint) return;\n\n    const selectedCp = checkpointsForModel.find(\n      (cp) => cp.display_name === checkpoint,\n    );\n    if (!selectedCp) return;\n\n    setExporting(true);\n    setExportError(null);\n    setExportSuccess(false);\n\n    // For GGUF, use a flat folder like \"exports/gemma-3-4b-it-finetune-gguf\"\n    // For other formats, nest under training-run/checkpoint\n    const saveDir =\n      exportMethod === \"gguf\"\n        ? `${baseModelName.split(\"/\").pop() ?? selectedModelIdx ?? \"model\"}-finetune-gguf`\n        : `${selectedModelIdx ?? \"model\"}/${checkpoint}`;\n    const pushToHub = destination === \"hub\";\n    const repoId = pushToHub && hfUsername && modelName\n      ? `${hfUsername}/${modelName}`\n      : undefined;\n    const token = pushToHub && hfToken ? hfToken : undefined;\n\n    try {\n      // 1. Load checkpoint\n      await loadCheckpoint({ checkpoint_path: selectedCp.path });\n\n      // 2. 
Run export based on method\n      if (exportMethod === \"merged\") {\n        if (isAdapter) {\n          await exportMerged({\n            save_directory: saveDir,\n            push_to_hub: pushToHub,\n            repo_id: repoId,\n            hf_token: token,\n            private: privateRepo,\n          });\n        } else {\n          await exportBase({\n            save_directory: saveDir,\n            push_to_hub: pushToHub,\n            repo_id: repoId,\n            hf_token: token,\n            private: privateRepo,\n            base_model_id: selectedModelData?.base_model,\n          });\n        }\n      } else if (exportMethod === \"gguf\") {\n        for (const quant of quantLevels) {\n          await exportGGUF({\n            save_directory: saveDir,\n            quantization_method: quant,\n            push_to_hub: pushToHub,\n            repo_id: repoId,\n            hf_token: token,\n          });\n        }\n      } else if (exportMethod === \"lora\") {\n        await exportLoRA({\n          save_directory: saveDir,\n          push_to_hub: pushToHub,\n          repo_id: repoId,\n          hf_token: token,\n          private: privateRepo,\n        });\n      }\n\n      setExportSuccess(true);\n    } catch (err) {\n      setExportError(\n        err instanceof Error ? err.message : \"Export failed\",\n      );\n    } finally {\n      try {\n        await cleanupExport();\n      } catch {\n        // cleanup is best-effort\n      }\n      setExporting(false);\n    }\n  }, [\n    checkpoint,\n    checkpointsForModel,\n    selectedModelIdx,\n    selectedModelData,\n    exportMethod,\n    isAdapter,\n    quantLevels,\n    destination,\n    hfUsername,\n    modelName,\n    hfToken,\n    privateRepo,\n  ]);\n\n  // ---- Render ----\n  return (\n    <div className=\"min-h-screen bg-background\">\n      <main className=\"mx-auto max-w-7xl px-4 py-4 sm:px-6\">\n        <GuidedTour {...tour.tourProps} />\n\n        <div className=\"mb-8 flex flex-col gap-0.5\">\n          <h1 className=\"text-2xl font-semibold tracking-tight\">\n            Export Model\n          </h1>\n          <p className=\"text-sm text-muted-foreground\">\n            Export your fine-tuned model for deployment\n          </p>\n        </div>\n\n        <SectionCard\n          icon={<HugeiconsIcon icon={PackageIcon} className=\"size-5\" />}\n          title=\"Export Configuration\"\n          description=\"Select checkpoint, method, and quantization\"\n          accent=\"emerald\"\n          featured={true}\n          className=\"shadow-border ring-1 ring-border\"\n        >\n          {/* Loading / error states */}\n          {loadingCheckpoints && (\n            <div className=\"flex items-center gap-2 py-6 justify-center text-sm text-muted-foreground\">\n              <Spinner className=\"size-4\" />\n              Loading checkpoints…\n            </div>\n          )}\n\n          {checkpointError && (\n            <div className=\"flex items-center gap-2 py-6 justify-center text-sm text-destructive\">\n              <HugeiconsIcon icon={AlertCircleIcon} className=\"size-4\" />\n              {checkpointError}\n            </div>\n          )}\n\n          {!loadingCheckpoints && !checkpointError && (\n            <>\n              {/* Top row: Dropdowns + metadata | Guide */}\n              <div className=\"grid grid-cols-1 gap-6 md:grid-cols-2 md:gap-8\">\n                <div className=\"flex flex-col gap-4\">\n                  {/* Training run dropdown */}\n                  <div 
data-tour=\"export-training-run\" className=\"flex flex-col gap-2\">\n                    <label className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                      Training Run\n                      <Tooltip>\n                        <TooltipTrigger asChild={true}>\n                          <button\n                            type=\"button\"\n                            className=\"text-foreground/70 hover:text-foreground\"\n                          >\n                            <HugeiconsIcon\n                              icon={InformationCircleIcon}\n                              className=\"size-3\"\n                            />\n                          </button>\n                        </TooltipTrigger>\n                        <TooltipContent>\n                          Select the training run that produced the checkpoints\n                          you want to export.\n                        </TooltipContent>\n                      </Tooltip>\n                    </label>\n                    <Select\n                      value={selectedModelIdx ?? \"\"}\n                      onValueChange={setSelectedModelIdx}\n                    >\n                      <SelectTrigger className=\"w-full\">\n                        <SelectValue\n                          placeholder={\n                            models.length === 0\n                              ? \"No training runs found\"\n                              : \"Select a training run…\"\n                          }\n                        />\n                      </SelectTrigger>\n                      <SelectContent>\n                        {models.map((m) => {\n                          const tsMatch = m.name.match(/_(\\d{10,})$/);\n                          const displayName = tsMatch ? m.name.slice(0, tsMatch.index) : m.name;\n                          const timeStr = tsMatch\n                            ? new Date(Number(tsMatch[1]) * 1000).toLocaleString(undefined, {\n                                dateStyle: \"medium\",\n                                timeStyle: \"short\",\n                              })\n                            : null;\n                          return (\n                            <SelectItem key={m.name} value={m.name}>\n                              <span className=\"flex items-center gap-2\">\n                                {displayName}\n                                <span className=\"text-muted-foreground text-xs\">\n                                  {m.checkpoints.length} checkpoint\n                                  {m.checkpoints.length !== 1 ? 
\"s\" : \"\"}\n                                </span>\n                                {timeStr && (\n                                  <span className=\"text-muted-foreground text-xs\">\n                                    · {timeStr}\n                                  </span>\n                                )}\n                              </span>\n                            </SelectItem>\n                          );\n                        })}\n                      </SelectContent>\n                    </Select>\n                  </div>\n\n                  {/* Checkpoint dropdown */}\n                  <div data-tour=\"export-checkpoint\" className=\"flex flex-col gap-2\">\n                    <label className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                      Checkpoint\n                      <Tooltip>\n                        <TooltipTrigger asChild={true}>\n                          <button\n                            type=\"button\"\n                            className=\"text-foreground/70 hover:text-foreground\"\n                          >\n                            <HugeiconsIcon\n                              icon={InformationCircleIcon}\n                              className=\"size-3\"\n                            />\n                          </button>\n                        </TooltipTrigger>\n                        <TooltipContent>\n                          Choose a saved checkpoint to export. Lower loss\n                          generally means better quality.{\" \"}\n                          <a\n                            href=\"https://unsloth.ai/docs/basics/inference-and-deployment\"\n                            target=\"_blank\"\n                            rel=\"noopener noreferrer\"\n                            className=\"text-primary underline\"\n                          >\n                            Read more\n                          </a>\n                        </TooltipContent>\n                      </Tooltip>\n                    </label>\n                    <Select\n                      value={checkpoint ?? \"\"}\n                      onValueChange={setCheckpoint}\n                      disabled={!selectedModelIdx}\n                    >\n                      <SelectTrigger className=\"w-full\">\n                        <SelectValue\n                          placeholder={\n                            !selectedModelIdx\n                              ? \"Select a training run first\"\n                              : checkpointsForModel.length === 0\n                                ? 
\"No checkpoints found\"\n                                : \"Select a checkpoint…\"\n                          }\n                        />\n                      </SelectTrigger>\n                      <SelectContent>\n                        {checkpointsForModel.map((cp) => (\n                          <SelectItem key={cp.path} value={cp.display_name}>\n                            <span className=\"flex items-center gap-2\">\n                              {cp.display_name}\n                              {cp.loss != null && (\n                                <span className=\"text-muted-foreground text-xs\">\n                                  loss: {cp.loss.toFixed(4)}\n                                </span>\n                              )}\n                            </span>\n                          </SelectItem>\n                        ))}\n                      </SelectContent>\n                    </Select>\n                  </div>\n\n                  <div className=\"rounded-xl bg-muted/50 p-3 flex flex-col gap-2\">\n                    <span className=\"text-[11px] font-medium text-muted-foreground uppercase tracking-wider\">\n                      Training Info\n                    </span>\n                    <div className=\"grid grid-cols-1 gap-x-6 gap-y-1.5 text-xs sm:grid-cols-2\">\n                      <div className=\"flex justify-between\">\n                        <span className=\"text-muted-foreground\">Base Model</span>\n                        <span className=\"font-medium\">{baseModelName}</span>\n                      </div>\n                      <div className=\"flex justify-between\">\n                        <span className=\"text-muted-foreground\">Method</span>\n                        <span className=\"font-medium\">\n                          {trainingMethodLabel}\n                        </span>\n                      </div>\n                      <div className=\"flex justify-between\">\n                        <span className=\"text-muted-foreground\">Checkpoints</span>\n                        <span className=\"font-medium\">\n                          {checkpointsForModel.length}\n                        </span>\n                      </div>\n                      {isAdapter && (\n                        <div className=\"flex justify-between\">\n                          <span className=\"text-muted-foreground\">LoRA Rank</span>\n                          <span className=\"font-medium\">{loraRank}</span>\n                        </div>\n                      )}\n                    </div>\n                  </div>\n                </div>\n\n                <div className=\"flex flex-col gap-2.5\">\n                  <span className=\"text-xs font-medium text-muted-foreground\">\n                    Quick Guide\n                  </span>\n                  <ol className=\"flex flex-col gap-3\">\n                    {GUIDE_STEPS.map((step, i) => (\n                      <li\n                        key={step}\n                        className=\"flex items-start gap-2 text-xs text-muted-foreground\"\n                      >\n                        <span className=\"flex size-5 shrink-0 items-center justify-center rounded-full bg-muted text-[10px] font-semibold\">\n                          {i + 1}\n                        </span>\n                        {step}\n                      </li>\n                    ))}\n                  </ol>\n                </div>\n              </div>\n\n              <MethodPicker value={exportMethod} 
onChange={handleMethodChange} />\n\n              <AnimatePresence>\n                {exportMethod === \"gguf\" && (\n                  <motion.div {...collapseAnim} className=\"overflow-hidden\">\n                    <QuantPicker value={quantLevels} onChange={setQuantLevels} />\n                  </motion.div>\n                )}\n              </AnimatePresence>\n\n              <Separator />\n              <div className=\"flex items-center justify-end\">\n                {/* TODO: unhide once estimated size comes from the backend API */}\n                {/* <div className=\"flex items-center gap-1.5 text-xs text-muted-foreground\">\n                  <HugeiconsIcon\n                    icon={InformationCircleIcon}\n                    className=\"size-3.5\"\n                  />\n                  <span>Est. size: {estimatedSize} · Free disk space: 120 GB</span>\n                </div> */}\n                <Button\n                  data-tour=\"export-cta\"\n                  disabled={!canExport}\n                  onClick={() => { setExportSuccess(false); setExportError(null); setDialogOpen(true); }}\n                >\n                  Export Model\n                </Button>\n              </div>\n            </>\n          )}\n        </SectionCard>\n      </main>\n\n      <ExportDialog\n        open={dialogOpen}\n        onOpenChange={setDialogOpen}\n        checkpoint={checkpoint}\n        exportMethod={exportMethod}\n        quantLevels={quantLevels}\n        estimatedSize={estimatedSize}\n        baseModelName={baseModelName}\n        isAdapter={isAdapter}\n        destination={destination}\n        onDestinationChange={setDestination}\n        hfUsername={hfUsername}\n        onHfUsernameChange={setHfUsername}\n        modelName={modelName}\n        onModelNameChange={setModelName}\n        hfToken={hfToken}\n        onHfTokenChange={setHfToken}\n        privateRepo={privateRepo}\n        onPrivateRepoChange={setPrivateRepo}\n        onExport={handleExport}\n        exporting={exporting}\n        exportError={exportError}\n        exportSuccess={exportSuccess}\n      />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/export/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { ExportPage } from \"./export-page\";\n"
  },
  {
    "path": "studio/frontend/src/features/export/tour/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { exportTourSteps } from \"./steps\";\n\n"
  },
  {
    "path": "studio/frontend/src/features/export/tour/steps.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\n\nexport const exportTourSteps: TourStep[] = [\n  {\n    id: \"training-run\",\n    target: \"export-training-run\",\n    title: \"Pick training run\",\n    body: (\n      <>\n        Start by selecting the training run. Each run groups the checkpoints\n        produced by that specific fine-tuning job.\n      </>\n    ),\n  },\n  {\n    id: \"checkpoint\",\n    target: \"export-checkpoint\",\n    title: \"Pick checkpoint\",\n    body: (\n      <>\n        Pick which checkpoint to export. If you trained multiple checkpoints,\n        it’s worth exporting 1-2 candidates and testing in Chat.\n      </>\n    ),\n  },\n  {\n    id: \"method\",\n    target: \"export-method\",\n    title: \"Export method\",\n    body: (\n      <>\n        Choose the packaging. GGUF is for llama.cpp-style runtimes (pick a\n        quant). Safetensors is for HF/Transformers-style usage. If you’re unsure,\n        start with safetensors.\n      </>\n    ),\n  },\n  {\n    id: \"cta\",\n    target: \"export-cta\",\n    title: \"Export\",\n    body: (\n      <>\n        Export to local or push to HF Hub. After export, test in Chat and compare\n        against base to confirm behavior is what you expect.\n      </>\n    ),\n  },\n];\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/splash-screen.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Card } from \"@/components/ui/card\";\nimport { motion } from \"motion/react\";\n\ninterface SplashScreenProps {\n  onStartOnboarding: () => void;\n  onGoToStudio: () => void;\n}\n\nexport function SplashScreen({\n  onStartOnboarding,\n  onGoToStudio,\n}: SplashScreenProps) {\n  return (\n    <div className=\"fixed inset-0 z-50 flex items-center justify-center bg-gradient-to-b from-background via-background to-primary/5 p-6\">\n      <Card className=\"w-full max-w-md px-8 py-8 shadow-border ring-1 ring-border\">\n        {/* Mascot */}\n        <div className=\"flex justify-center\">\n          <motion.img\n            src=\"/Sloth emojis/Sloth loca pc.png\"\n            alt=\"Sloth mascot\"\n            className=\"size-30\"\n            initial={{ opacity: 0, y: 40, scale: 0.95 }}\n            animate={{ opacity: 1, y: 0, scale: 1 }}\n            transition={{\n              type: \"spring\",\n              duration: 0.7,\n              bounce: 0.3,\n              delay: 0.1,\n            }}\n          />\n        </div>\n\n        {/* Brand text */}\n        <motion.div\n          className=\"mt-4 flex flex-col items-center gap-1\"\n          initial={{ opacity: 0, y: 10 }}\n          animate={{ opacity: 1, y: 0 }}\n          transition={{\n            duration: 0.4,\n            ease: [0.165, 0.84, 0.44, 1],\n            delay: 0.4,\n          }}\n        >\n          <h1 className=\"text-2xl font-semibold tracking-tight\">\n            Unsloth Studio\n          </h1>\n          <p className=\"text-sm text-muted-foreground\">Train and run LLMs locally</p>\n        </motion.div>\n\n        {/* Buttons */}\n        <motion.div\n          className=\"mt-8 flex flex-col gap-3\"\n          initial={{ opacity: 0, y: 10 }}\n          animate={{ opacity: 1, y: 0 }}\n          transition={{\n            duration: 0.4,\n            ease: [0.165, 0.84, 0.44, 1],\n            delay: 0.8,\n          }}\n        >\n          <Button size=\"lg\" onClick={onStartOnboarding}>\n            Start Onboarding\n          </Button>\n          <Button size=\"lg\" variant=\"outline\" onClick={onGoToStudio}>\n            Skip Onboarding\n          </Button>\n        </motion.div>\n      </Card>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/steps/dataset-step.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  Field,\n  FieldDescription,\n  FieldGroup,\n  FieldLabel,\n} from \"@/components/ui/field\";\nimport {\n  InputGroup,\n  InputGroupAddon,\n  InputGroupInput,\n} from \"@/components/ui/input-group\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport {\n  useDebouncedValue,\n  useHfDatasetSearch,\n  useHfTokenValidation,\n  useInfiniteScroll,\n} from \"@/hooks\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  HfDatasetSubsetSplitSelectors,\n  useTrainingConfigStore,\n} from \"@/features/training\";\nimport type { DatasetFormat } from \"@/types/training\";\nimport {\n  InformationCircleIcon,\n  Key01Icon,\n  Search01Icon,\n  SparklesIcon,\n  Upload04Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useMemo, useRef, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nconst FORMAT_OPTIONS: { value: DatasetFormat; label: string }[] = [\n  { value: \"auto\", label: \"Auto Detect\" },\n  { value: \"alpaca\", label: \"Alpaca\" },\n  { value: \"chatml\", label: \"ChatML\" },\n  { value: \"sharegpt\", label: \"ShareGPT\" },\n];\n\nexport function DatasetStep() {\n  const {\n    hfToken,\n    setHfToken,\n    datasetSource,\n    selectHfDataset,\n    selectLocalDataset,\n    datasetFormat,\n    setDatasetFormat,\n    dataset,\n    setDataset,\n    datasetSubset,\n    setDatasetSubset,\n    datasetSplit,\n    setDatasetSplit,\n    datasetEvalSplit,\n    setDatasetEvalSplit,\n    uploadedFile,\n    setUploadedFile,\n    modelType,\n  } = useTrainingConfigStore(\n    useShallow((s) => ({\n      hfToken: s.hfToken,\n      setHfToken: s.setHfToken,\n      datasetSource: s.datasetSource,\n      selectHfDataset: s.selectHfDataset,\n      selectLocalDataset: s.selectLocalDataset,\n      datasetFormat: s.datasetFormat,\n      setDatasetFormat: s.setDatasetFormat,\n      dataset: s.dataset,\n      setDataset: s.setDataset,\n      datasetSubset: s.datasetSubset,\n      setDatasetSubset: s.setDatasetSubset,\n      datasetSplit: s.datasetSplit,\n      setDatasetSplit: s.setDatasetSplit,\n      datasetEvalSplit: s.datasetEvalSplit,\n      setDatasetEvalSplit: s.setDatasetEvalSplit,\n      uploadedFile: s.uploadedFile,\n      setUploadedFile: s.setUploadedFile,\n      modelType: s.modelType,\n    })),\n  );\n\n  const [inputValue, setInputValue] = useState(\"\");\n  const selectingRef = useRef(false);\n  const debouncedQuery = useDebouncedValue(inputValue);\n  const {\n    results: hfResults,\n    isLoading,\n    isLoadingMore,\n    fetchMore,\n    error: hfSearchError,\n  } = useHfDatasetSearch(debouncedQuery, {\n    modelType,\n    accessToken: hfToken || undefined,\n  });\n\n  const { error: tokenValidationError, isChecking: isCheckingToken } =\n    useHfTokenValidation(hfToken);\n\n  const resultIds = useMemo(() => hfResults.map((r) => r.id), [hfResults]);\n\n  const 
comboboxAnchorRef = useRef<HTMLDivElement>(null);\n  const { scrollRef, sentinelRef } = useInfiniteScroll(\n    fetchMore,\n    hfResults.length,\n  );\n\n  const handleFileUpload = () => {\n    setUploadedFile(\"my_dataset.jsonl\");\n  };\n\n  return (\n    <FieldGroup>\n      <Field>\n        <FieldLabel>Source</FieldLabel>\n        <div className=\"flex gap-2\">\n          <Button\n            variant={datasetSource === \"huggingface\" ? \"dark\" : \"outline\"}\n            onClick={() =>\n              selectHfDataset(datasetSource === \"huggingface\" ? dataset : null)\n            }\n            className=\"flex-1\"\n          >\n            <img\n              src=\"/huggingface.svg\"\n              alt=\"\"\n              className=\"size-4 invert\"\n              data-icon=\"inline-start\"\n            />\n            Hugging Face\n          </Button>\n          <Button\n            variant={datasetSource === \"upload\" ? \"dark\" : \"outline\"}\n            onClick={() =>\n              selectLocalDataset(\n                datasetSource === \"upload\" ? uploadedFile : null,\n              )\n            }\n            className=\"flex-1\"\n          >\n            <HugeiconsIcon icon={Upload04Icon} data-icon=\"inline-start\" />\n            Upload\n          </Button>\n        </div>\n      </Field>\n\n      {datasetSource === \"huggingface\" ? (\n        <>\n          <Field>\n            <FieldLabel>\n              Hugging Face Token{\" \"}\n              <span className=\"text-muted-foreground font-normal\">\n                (Optional)\n              </span>\n            </FieldLabel>\n            <FieldDescription>\n              Required for gated or private datasets.{\" \"}\n              <a\n                href=\"https://huggingface.co/settings/tokens\"\n                target=\"_blank\"\n                rel=\"noopener noreferrer\"\n                className=\"text-primary hover:underline\"\n              >\n                Get token\n              </a>\n            </FieldDescription>\n            <InputGroup>\n              <InputGroupAddon>\n                <HugeiconsIcon icon={Key01Icon} className=\"size-4\" />\n              </InputGroupAddon>\n              <InputGroupInput\n                type=\"password\"\n                autoComplete=\"new-password\"\n                name=\"hf-token\"\n                placeholder=\"hf_...\"\n                value={hfToken}\n                onChange={(e) => setHfToken(e.target.value)}\n              />\n            </InputGroup>\n            {(tokenValidationError ?? hfSearchError) && (\n              <p className=\"text-xs text-destructive\">\n                {tokenValidationError ?? 
hfSearchError}\n                {\" — \"}\n                <a\n                  href=\"https://huggingface.co/settings/tokens\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"underline\"\n                >\n                  Get or update token\n                </a>\n              </p>\n            )}\n            {isCheckingToken && (\n              <p className=\"text-xs text-muted-foreground\">Checking token…</p>\n            )}\n          </Field>\n\n          <Field>\n            <FieldLabel>Search datasets</FieldLabel>\n            <div ref={comboboxAnchorRef}>\n              <Combobox\n                items={resultIds}\n                filteredItems={resultIds}\n                filter={null}\n                value={dataset}\n                onValueChange={(id) => {\n                  selectingRef.current = true;\n                  setDataset(id);\n                }}\n                onInputValueChange={(val) => {\n                  if (selectingRef.current) {\n                    selectingRef.current = false;\n                    return;\n                  }\n                  setInputValue(val);\n                }}\n                itemToStringValue={(id) => id}\n                autoHighlight={true}\n              >\n                <ComboboxInput\n                  placeholder=\"Search datasets...\"\n                  className=\"w-full\"\n                >\n                  <InputGroupAddon>\n                    <HugeiconsIcon icon={Search01Icon} className=\"size-4\" />\n                  </InputGroupAddon>\n                </ComboboxInput>\n                <ComboboxContent anchor={comboboxAnchorRef}>\n                  {isLoading ? (\n                    <div className=\"flex items-center justify-center py-4 gap-2 text-xs text-muted-foreground\">\n                      <Spinner className=\"size-4\" /> Searching...\n                    </div>\n                  ) : (\n                    <ComboboxEmpty>No datasets found</ComboboxEmpty>\n                  )}\n                  <div\n                    ref={scrollRef}\n                    className=\"max-h-64 overflow-y-auto overscroll-contain [scrollbar-width:thin]\"\n                  >\n                    <ComboboxList className=\"p-1 !max-h-none !overflow-visible\">\n                      {(id: string) => {\n                        return (\n                          <ComboboxItem key={id} value={id} className=\"gap-2\">\n                            <Tooltip>\n                              <TooltipTrigger asChild={true}>\n                                <span className=\"block min-w-0 flex-1 truncate\">\n                                  {id}\n                                </span>\n                              </TooltipTrigger>\n                              <TooltipContent\n                                side=\"left\"\n                                className=\"max-w-xs break-all\"\n                              >\n                                {id}\n                              </TooltipContent>\n                            </Tooltip>\n                          </ComboboxItem>\n                        );\n                      }}\n                    </ComboboxList>\n                    <div ref={sentinelRef} className=\"h-px\" />\n                    {isLoadingMore && (\n                      <div className=\"flex items-center justify-center py-2\">\n                        <Spinner className=\"size-3.5 text-muted-foreground\" />\n                
      </div>\n                    )}\n                  </div>\n                </ComboboxContent>\n              </Combobox>\n            </div>\n          </Field>\n\n          <HfDatasetSubsetSplitSelectors\n            variant=\"wizard\"\n            enabled={datasetSource === \"huggingface\"}\n            datasetName={dataset}\n            accessToken={hfToken || undefined}\n            datasetSubset={datasetSubset}\n            setDatasetSubset={setDatasetSubset}\n            datasetSplit={datasetSplit}\n            setDatasetSplit={setDatasetSplit}\n            datasetEvalSplit={datasetEvalSplit}\n            setDatasetEvalSplit={setDatasetEvalSplit}\n          />\n        </>\n      ) : (\n        <>\n          <Field>\n            <FieldLabel>Upload Dataset</FieldLabel>\n            <FieldDescription>\n              Supports JSONL, JSON, CSV formats\n            </FieldDescription>\n            <button\n              type=\"button\"\n              className={cn(\n                \"border-2 border-dashed rounded-xl p-8 text-center transition-colors cursor-pointer hover:border-primary/50 hover:bg-muted/50\",\n                uploadedFile && \"border-primary/50 bg-primary/5\",\n              )}\n              onClick={handleFileUpload}\n            >\n              {uploadedFile ? (\n                <div className=\"flex flex-col items-center gap-2\">\n                  <Badge variant=\"secondary\" className=\"text-sm\">\n                    {uploadedFile}\n                  </Badge>\n                  <span className=\"text-xs text-muted-foreground\">\n                    Click to replace\n                  </span>\n                </div>\n              ) : (\n                <div className=\"flex flex-col items-center gap-2\">\n                  <HugeiconsIcon\n                    icon={Upload04Icon}\n                    className=\"size-8 text-muted-foreground\"\n                  />\n                  <span className=\"text-sm text-muted-foreground\">\n                    Click to upload or drag and drop\n                  </span>\n                </div>\n              )}\n            </button>\n          </Field>\n        </>\n      )}\n\n      <Field>\n        <div className=\"flex items-center justify-between\">\n          <FieldLabel className=\"flex items-center gap-1.5\">\n            Format\n            <Tooltip>\n              <TooltipTrigger asChild={true}>\n                <button\n                  type=\"button\"\n                  className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                >\n                  <HugeiconsIcon\n                    icon={InformationCircleIcon}\n                    className=\"size-3.5\"\n                  />\n                </button>\n              </TooltipTrigger>\n              <TooltipContent className=\"max-w-xs\">\n                Auto will try to identify and convert your dataset to a\n                supported format.{\" \"}\n                <a\n                  href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/datasets-guide\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"text-primary underline\"\n                >\n                  Read more\n                </a>\n              </TooltipContent>\n            </Tooltip>\n          </FieldLabel>\n          <Select\n            value={datasetFormat}\n            onValueChange={(v) => setDatasetFormat(v as DatasetFormat)}\n          >\n            <SelectTrigger 
className=\"w-40\">\n              <SelectValue />\n            </SelectTrigger>\n            <SelectContent>\n              {FORMAT_OPTIONS.map((opt) => (\n                <SelectItem key={opt.value} value={opt.value}>\n                  {opt.value === \"auto\" && (\n                    <HugeiconsIcon\n                      icon={SparklesIcon}\n                      className=\"mr-1.5 inline size-3.5 align-text-bottom\"\n                    />\n                  )}\n                  {opt.label}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n        </div>\n      </Field>\n    </FieldGroup>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/steps/hyperparameters-step.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  FieldGroup,\n  FieldLabel,\n  FieldLegend,\n  FieldSet,\n} from \"@/components/ui/field\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Separator } from \"@/components/ui/separator\";\nimport { Slider } from \"@/components/ui/slider\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { CONTEXT_LENGTHS } from \"@/config/training\";\nimport { useMaxStepsEpochsToggle, useTrainingConfigStore } from \"@/features/training\";\nimport { InformationCircleIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useMemo } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\n/** Format a number in scientific notation like 2e-4, 5e-3, etc. */\nfunction formatLR(value: number): string {\n  if (value <= 0) return \"0\";\n  const exp = Math.floor(Math.log10(value));\n  const mantissa = value / 10 ** exp;\n  const rounded = Math.round(mantissa * 10) / 10;\n  if (rounded === 10) return `1e${exp + 1}`;\n  if (rounded === Math.round(rounded)) return `${Math.round(rounded)}e${exp}`;\n  return `${rounded}e${exp}`;\n}\n\n/**\n * Step learning rate up in a scientific-notation-friendly sequence:\n * 1e-4 -> 2e-4 -> 3e-4 -> ... -> 9e-4 -> 1e-3 -> 2e-3 -> ...\n */\nfunction stepLR(value: number, direction: 1 | -1): number {\n  if (value <= 0) return 1e-5;\n  const exp = Math.floor(Math.log10(value) + 1e-9);\n  const mantissa = Math.round(value / 10 ** exp);\n  let newMantissa = mantissa + direction;\n  let newExp = exp;\n  if (newMantissa > 9) {\n    newMantissa = 1;\n    newExp = exp + 1;\n  } else if (newMantissa < 1) {\n    newMantissa = 9;\n    newExp = exp - 1;\n  }\n  return newMantissa * 10 ** newExp;\n}\n\nexport function HyperparametersStep() {\n  const {\n    trainingMethod,\n    maxSteps,\n    setMaxSteps,\n    epochs,\n    setEpochs,\n    saveSteps,\n    setSaveSteps,\n    contextLength,\n    setContextLength,\n    learningRate,\n    setLearningRate,\n    loraRank,\n    setLoraRank,\n    loraAlpha,\n    setLoraAlpha,\n    loraDropout,\n    setLoraDropout,\n    maxPositionEmbeddings,\n  } = useTrainingConfigStore(\n    useShallow((s) => ({\n      trainingMethod: s.trainingMethod,\n      maxSteps: s.maxSteps,\n      setMaxSteps: s.setMaxSteps,\n      epochs: s.epochs,\n      setEpochs: s.setEpochs,\n      saveSteps: s.saveSteps,\n      setSaveSteps: s.setSaveSteps,\n      contextLength: s.contextLength,\n      setContextLength: s.setContextLength,\n      learningRate: s.learningRate,\n      setLearningRate: s.setLearningRate,\n      loraRank: s.loraRank,\n      setLoraRank: s.setLoraRank,\n      loraAlpha: s.loraAlpha,\n      setLoraAlpha: s.setLoraAlpha,\n      loraDropout: s.loraDropout,\n      setLoraDropout: s.setLoraDropout,\n      maxPositionEmbeddings: s.maxPositionEmbeddings,\n    })),\n  );\n\n  const showLoraParams =\n    trainingMethod === \"lora\" || trainingMethod === \"qlora\";\n  const { useEpochs, toggleUseEpochs } = useMaxStepsEpochsToggle({\n    maxSteps,\n    epochs,\n    saveSteps,\n    setMaxSteps,\n    setEpochs,\n    setSaveSteps,\n  });\n\n  const maxStepsSliderMax = Math.max(500, maxSteps, 30);\n  const epochsSliderMax = Math.max(10, epochs, 1);\n\n  // Use model's max_position_embeddings to cap 
context length options.\n  // Fall back to 65536 (64K) if not available.\n  const maxCtx = maxPositionEmbeddings ?? 65536;\n  const contextLengthOptions = useMemo(\n    () => CONTEXT_LENGTHS.filter((len) => len <= maxCtx),\n    [maxCtx],\n  );\n\n  return (\n    <FieldGroup>\n      <FieldSet>\n        <FieldLegend variant=\"label\">Choose your training parameters</FieldLegend>\n        <div className=\"flex flex-col gap-4\">\n          <div\n            key={useEpochs ? \"epochs\" : \"steps\"}\n            className=\"flex flex-col gap-2 animate-in fade-in-1 slide-in-from-bottom-1 duration-200\"\n          >\n            <div className=\"flex items-center justify-between\">\n              <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n                {useEpochs ? \"Epochs\" : \"Max Steps\"}\n                <Tooltip>\n                  <TooltipTrigger asChild={true}>\n                    <button\n                      type=\"button\"\n                      className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                    >\n                      <HugeiconsIcon\n                        icon={InformationCircleIcon}\n                        className=\"size-3.5\"\n                      />\n                    </button>\n                  </TooltipTrigger>\n                  <TooltipContent>\n                    {useEpochs\n                      ? \"Number of full passes over the dataset.\"\n                      : \"Override total optimizer steps.\"}{\" \"}\n                    <a\n                      href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                      target=\"_blank\"\n                      rel=\"noopener noreferrer\"\n                      className=\"text-primary underline\"\n                    >\n                      Read more\n                    </a>\n                  </TooltipContent>\n                </Tooltip>\n              </FieldLabel>\n              <div className=\"flex items-center gap-3\">\n                <button\n                  type=\"button\"\n                  onClick={toggleUseEpochs}\n                  className=\"text-xs text-primary underline cursor-pointer\"\n                >\n                  {useEpochs ? \"Use Max Steps\" : \"Use Epochs\"}\n                </button>\n                <Slider\n                  value={[\n                    useEpochs\n                      ? Math.min(epochsSliderMax, Math.max(1, epochs))\n                      : Math.min(maxStepsSliderMax, Math.max(1, maxSteps)),\n                  ]}\n                  onValueChange={([v]) =>\n                    useEpochs ? setEpochs(v) : setMaxSteps(v)\n                  }\n                  min={1}\n                  max={useEpochs ? epochsSliderMax : maxStepsSliderMax}\n                  step={1}\n                  className=\"w-40\"\n                />\n                <input\n                  type=\"number\"\n                  value={useEpochs ? 
epochs : maxSteps}\n                  onChange={(e) => {\n                    const raw = e.target.value;\n                    if (raw === \"\") return;\n\n                    const value = Number(raw);\n                    if (!Number.isFinite(value) || value < 1) return;\n\n                    if (useEpochs) {\n                      setEpochs(value);\n                    } else {\n                      setMaxSteps(value);\n                    }\n                  }}\n                  min={1}\n                  max={useEpochs ? epochsSliderMax : maxStepsSliderMax}\n                  step={1}\n                  className=\"w-16 text-right font-mono text-xs font-medium bg-muted/50 border border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n                />\n              </div>\n            </div>\n          </div>\n\n          <div className=\"flex items-center justify-between\">\n            <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n              Context Length\n              <Tooltip>\n                <TooltipTrigger asChild={true}>\n                  <button\n                    type=\"button\"\n                    className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                  >\n                    <HugeiconsIcon\n                      icon={InformationCircleIcon}\n                      className=\"size-3.5\"\n                    />\n                  </button>\n                </TooltipTrigger>\n                <TooltipContent>\n                  Maximum number of tokens per training sample.{\" \"}\n                  <a\n                    href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                    target=\"_blank\"\n                    rel=\"noopener noreferrer\"\n                    className=\"text-primary underline\"\n                  >\n                    Read more\n                  </a>\n                </TooltipContent>\n              </Tooltip>\n            </FieldLabel>\n            <Select\n              value={String(contextLength)}\n              onValueChange={(v) => setContextLength(Number(v))}\n            >\n              <SelectTrigger className=\"w-32 font-mono\">\n                <SelectValue />\n              </SelectTrigger>\n              <SelectContent>\n                {contextLengthOptions.map((len) => (\n                  <SelectItem key={len} value={String(len)}>\n                    {len.toLocaleString()}\n                  </SelectItem>\n                ))}\n              </SelectContent>\n            </Select>\n          </div>\n\n          <div className=\"flex items-center justify-between\">\n            <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n              Learning Rate\n              <Tooltip>\n                <TooltipTrigger asChild={true}>\n                  <button\n                    type=\"button\"\n                    className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                  >\n                    <HugeiconsIcon\n                      icon={InformationCircleIcon}\n                      className=\"size-3.5\"\n                    />\n                  </button>\n                </TooltipTrigger>\n                <TooltipContent>\n                  Step size for weight updates. 
Lower = slower but more stable.{\" \"}\n                  <a\n                    href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                    target=\"_blank\"\n                    rel=\"noopener noreferrer\"\n                    className=\"text-primary underline\"\n                  >\n                    Read more\n                  </a>\n                </TooltipContent>\n              </Tooltip>\n            </FieldLabel>\n            <div className=\"flex items-center gap-1\">\n              <button\n                type=\"button\"\n                className=\"flex size-7 items-center justify-center rounded-md border border-border text-muted-foreground hover:bg-muted cursor-pointer\"\n                onClick={() => setLearningRate(stepLR(learningRate, -1))}\n              >\n                -\n              </button>\n              <span className=\"w-16 text-center font-mono text-sm\">\n                {formatLR(learningRate)}\n              </span>\n              <button\n                type=\"button\"\n                className=\"flex size-7 items-center justify-center rounded-md border border-border text-muted-foreground hover:bg-muted cursor-pointer\"\n                onClick={() => setLearningRate(stepLR(learningRate, 1))}\n              >\n                +\n              </button>\n            </div>\n          </div>\n\n        </div>\n      </FieldSet>\n\n      {showLoraParams && (\n        <>\n          <Separator />\n          <FieldSet>\n            <FieldLegend variant=\"label\">LoRA Parameters</FieldLegend>\n            <div className=\"flex flex-col gap-4\">\n              <div className=\"flex items-center justify-between\">\n                <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n                  Rank\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3.5\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent>\n                      Dimension of the low-rank matrices. 
Higher = more\n                      capacity.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </TooltipContent>\n                  </Tooltip>\n                </FieldLabel>\n                <div className=\"flex items-center gap-3\">\n                  <Slider\n                    value={[loraRank]}\n                    onValueChange={([v]) => setLoraRank(v)}\n                    min={4}\n                    max={128}\n                    step={4}\n                    className=\"w-40\"\n                  />\n                  <input\n                    type=\"number\"\n                    value={loraRank}\n                    onChange={(e) => setLoraRank(Number(e.target.value))}\n                    min={4}\n                    max={128}\n                    step={4}\n                    className=\"w-12 text-right font-mono text-xs font-medium bg-muted/50 border border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n                  />\n                </div>\n              </div>\n\n              <div className=\"flex items-center justify-between\">\n                <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n                  Alpha\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3.5\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent>\n                      Scaling factor. 
Typically set to 2x the rank value.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </TooltipContent>\n                  </Tooltip>\n                </FieldLabel>\n                <div className=\"flex items-center gap-3\">\n                  <Slider\n                    value={[loraAlpha]}\n                    onValueChange={([v]) => setLoraAlpha(v)}\n                    min={8}\n                    max={256}\n                    step={8}\n                    className=\"w-40\"\n                  />\n                  <input\n                    type=\"number\"\n                    value={loraAlpha}\n                    onChange={(e) => setLoraAlpha(Number(e.target.value))}\n                    min={8}\n                    max={256}\n                    step={8}\n                    className=\"w-12 text-right font-mono text-xs font-medium bg-muted/50 border border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n                  />\n                </div>\n              </div>\n\n              <div className=\"flex items-center justify-between\">\n                <FieldLabel className=\"flex items-center gap-1.5 !text-sm text-muted-foreground\">\n                  Dropout\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3.5\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent>\n                      Probability of dropping neurons during training for\n                      regularization.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </TooltipContent>\n                  </Tooltip>\n                </FieldLabel>\n                <div className=\"flex items-center gap-3\">\n                  <Slider\n                    value={[loraDropout]}\n                    onValueChange={([v]) => setLoraDropout(v)}\n                    min={0}\n                    max={0.5}\n                    step={0.01}\n                    className=\"w-40\"\n                  />\n                  <input\n                    type=\"number\"\n                    value={loraDropout}\n                    onChange={(e) => setLoraDropout(Number(e.target.value))}\n                    min={0}\n                    max={0.5}\n                    step={0.01}\n                    className=\"w-12 text-right font-mono text-xs font-medium bg-muted/50 border 
border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n                  />\n                </div>\n              </div>\n            </div>\n          </FieldSet>\n        </>\n      )}\n    </FieldGroup>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/steps/model-selection-step.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  Field,\n  FieldDescription,\n  FieldGroup,\n  FieldLabel,\n} from \"@/components/ui/field\";\nimport {\n  InputGroup,\n  InputGroupAddon,\n  InputGroupInput,\n} from \"@/components/ui/input-group\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { MODEL_TYPE_TO_HF_TASK, PRIORITY_TRAINING_MODELS, applyPriorityOrdering } from \"@/config/training\";\nimport {\n  useDebouncedValue,\n  useGpuInfo,\n  useHfModelSearch,\n  useHfTokenValidation,\n  useInfiniteScroll,\n} from \"@/hooks\";\nimport { formatCompact } from \"@/lib/utils\";\nimport {\n  type TrainingMethod as VramTrainingMethod,\n  type VramFitStatus,\n  buildModelVramMap,\n} from \"@/lib/vram\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport type { TrainingMethod } from \"@/types/training\";\nimport {\n  InformationCircleIcon,\n  Key01Icon,\n  Search01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useEffect, useMemo, useRef, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nexport function ModelSelectionStep() {\n  const gpu = useGpuInfo();\n  const {\n    modelType,\n    selectedModel,\n    setSelectedModel,\n    ensureModelDefaultsLoaded,\n    trainingMethod,\n    setTrainingMethod,\n    hfToken,\n    setHfToken,\n  } = useTrainingConfigStore(\n    useShallow((s) => ({\n      modelType: s.modelType,\n      selectedModel: s.selectedModel,\n      setSelectedModel: s.setSelectedModel,\n      ensureModelDefaultsLoaded: s.ensureModelDefaultsLoaded,\n      trainingMethod: s.trainingMethod,\n      setTrainingMethod: s.setTrainingMethod,\n      hfToken: s.hfToken,\n      setHfToken: s.setHfToken,\n    })),\n  );\n\n  const [inputValue, setInputValue] = useState(\"\");\n  const selectingRef = useRef(false);\n  const debouncedQuery = useDebouncedValue(inputValue);\n  const task = modelType ? MODEL_TYPE_TO_HF_TASK[modelType] : undefined;\n  const {\n    results: hfResults,\n    isLoading,\n    isLoadingMore,\n    fetchMore,\n    error: hfSearchError,\n  } = useHfModelSearch(debouncedQuery, {\n    task,\n    accessToken: hfToken || undefined,\n    excludeGguf: true,\n    priorityIds: PRIORITY_TRAINING_MODELS,\n  });\n\n  const { error: tokenValidationError, isChecking: isCheckingToken } =\n    useHfTokenValidation(hfToken);\n\n  const resultIds = useMemo(() => {\n    const ids = hfResults.map((r) => r.id);\n    return applyPriorityOrdering(ids);\n  }, [hfResults]);\n\n  // Match Studio behavior: only show exception signals (OOM/TIGHT) in training flows.\n  const vramMap = useMemo(() => {\n    const fitMap = buildModelVramMap(\n      hfResults,\n      trainingMethod as VramTrainingMethod,\n      gpu,\n    );\n    const map = new Map<string, { status: VramFitStatus | null; detail: string | null }>();\n    for (const r of hfResults) {\n      const fit = fitMap.get(r.id);\n      map.set(r.id, {\n        status: fit?.status ?? null,\n        detail: r.totalParams ? 
formatCompact(r.totalParams) : null,\n      });\n    }\n    return map;\n  }, [hfResults, gpu, trainingMethod]);\n\n  const comboboxAnchorRef = useRef<HTMLDivElement>(null);\n  const { scrollRef, sentinelRef } = useInfiniteScroll(\n    fetchMore,\n    hfResults.length,\n  );\n\n  useEffect(() => {\n    ensureModelDefaultsLoaded();\n  }, [selectedModel, ensureModelDefaultsLoaded]);\n\n  return (\n    <FieldGroup>\n      <Field>\n        <FieldLabel>\n          Hugging Face Token{\" \"}\n          <span className=\"text-muted-foreground font-normal\">(Optional)</span>\n        </FieldLabel>\n        <FieldDescription>\n          Required for gated or private models.{\" \"}\n          <a\n            href=\"https://huggingface.co/settings/tokens\"\n            target=\"_blank\"\n            rel=\"noopener noreferrer\"\n            className=\"text-primary hover:underline\"\n          >\n            Get token\n          </a>\n        </FieldDescription>\n        <InputGroup>\n          <InputGroupAddon>\n            <HugeiconsIcon icon={Key01Icon} className=\"size-4\" />\n          </InputGroupAddon>\n          <InputGroupInput\n            type=\"password\"\n            autoComplete=\"new-password\"\n            name=\"hf-token\"\n            placeholder=\"hf_...\"\n            value={hfToken}\n            onChange={(e) => setHfToken(e.target.value)}\n          />\n        </InputGroup>\n        {(tokenValidationError ?? hfSearchError) && (\n          <p className=\"text-xs text-destructive\">\n            {tokenValidationError ?? hfSearchError}\n            {\" — \"}\n            <a\n              href=\"https://huggingface.co/settings/tokens\"\n              target=\"_blank\"\n              rel=\"noopener noreferrer\"\n              className=\"underline\"\n            >\n              Get or update token\n            </a>\n          </p>\n        )}\n        {isCheckingToken && (\n          <p className=\"text-xs text-muted-foreground\">Checking token…</p>\n        )}\n      </Field>\n\n      <Field>\n        <FieldLabel className=\"flex items-center gap-1.5\">\n          Search models\n          <Tooltip>\n            <TooltipTrigger asChild={true}>\n              <button\n                type=\"button\"\n                className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n              >\n                <HugeiconsIcon\n                  icon={InformationCircleIcon}\n                  className=\"size-3.5\"\n                />\n              </button>\n            </TooltipTrigger>\n            <TooltipContent>\n              Search Hugging Face models or pick from our recommended list.{\" \"}\n              <a\n                href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/what-model-should-i-use\"\n                target=\"_blank\"\n                rel=\"noopener noreferrer\"\n                className=\"text-primary underline\"\n              >\n                Read more\n              </a>\n            </TooltipContent>\n          </Tooltip>\n        </FieldLabel>\n        <div ref={comboboxAnchorRef}>\n          <Combobox\n            items={resultIds}\n            filteredItems={resultIds}\n            filter={null}\n            value={selectedModel}\n            onValueChange={(id) => {\n              selectingRef.current = true;\n              setSelectedModel(id);\n            }}\n            onInputValueChange={(val) => {\n              if (selectingRef.current) {\n                selectingRef.current = false;\n                return;\n         
     }\n              setInputValue(val);\n            }}\n            itemToStringValue={(id) => id}\n            autoHighlight={true}\n          >\n            <ComboboxInput placeholder=\"Search models...\" className=\"w-full\">\n              <InputGroupAddon>\n                <HugeiconsIcon icon={Search01Icon} className=\"size-4\" />\n              </InputGroupAddon>\n            </ComboboxInput>\n            <ComboboxContent anchor={comboboxAnchorRef}>\n              {isLoading ? (\n                <div className=\"flex items-center justify-center py-4 gap-2 text-xs text-muted-foreground\">\n                  <Spinner className=\"size-4\" /> Searching…\n                </div>\n              ) : (\n                <ComboboxEmpty>No models found</ComboboxEmpty>\n              )}\n              <div\n                ref={scrollRef}\n                className=\"max-h-64 overflow-y-auto overscroll-contain [scrollbar-width:thin]\"\n              >\n                <ComboboxList className=\"p-1 !max-h-none !overflow-visible\">\n                  {(id: string) => {\n                    const entry = vramMap.get(id);\n                    const sizeLabel = entry?.detail ?? null;\n                    const fitStatus = entry?.status ?? null;\n                    const exceeds = fitStatus === \"exceeds\";\n                    return (\n                      <ComboboxItem\n                        key={id}\n                        value={id}\n                        className={`justify-between ${exceeds ? \"opacity-50\" : \"\"}`}\n                      >\n                        <Tooltip>\n                          <TooltipTrigger asChild={true}>\n                            <span\n                              className={`min-w-0 flex-1 truncate ${exceeds ? \"line-through decoration-muted-foreground/50\" : \"\"}`}\n                            >\n                              {id}\n                            </span>\n                          </TooltipTrigger>\n                          <TooltipContent\n                            side=\"left\"\n                            className=\"max-w-xs break-all\"\n                          >\n                            {id}\n                          </TooltipContent>\n                        </Tooltip>\n                        <span className=\"flex items-center gap-1.5 shrink-0\">\n                          {fitStatus === \"exceeds\" && (\n                            <span className=\"text-[9px] font-medium text-red-400\">\n                              OOM\n                            </span>\n                          )}\n                          {fitStatus === \"tight\" && (\n                            <span className=\"text-[9px] font-medium text-amber-400\">\n                              TIGHT\n                            </span>\n                          )}\n                          {sizeLabel ? 
(\n                            <span className=\"text-xs text-muted-foreground\">\n                              {sizeLabel}\n                            </span>\n                          ) : null}\n                        </span>\n                      </ComboboxItem>\n                    );\n                  }}\n                </ComboboxList>\n                <div ref={sentinelRef} className=\"h-px\" />\n                {isLoadingMore && (\n                  <div className=\"flex items-center justify-center py-2\">\n                    <Spinner className=\"size-3.5 text-muted-foreground\" />\n                  </div>\n                )}\n              </div>\n            </ComboboxContent>\n          </Combobox>\n        </div>\n      </Field>\n\n      {selectedModel && (\n        <Field>\n          <div className=\"flex items-center justify-between\">\n            <div>\n              <FieldLabel className=\"flex items-center gap-1.5\">\n                Training method\n                <Tooltip>\n                  <TooltipTrigger asChild={true}>\n                    <button\n                      type=\"button\"\n                      className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                    >\n                      <HugeiconsIcon\n                        icon={InformationCircleIcon}\n                        className=\"size-3.5\"\n                      />\n                    </button>\n                  </TooltipTrigger>\n                  <TooltipContent className=\"max-w-xs\">\n                    QLoRA uses 4-bit quantization for lowest VRAM. LoRA uses\n                    16-bit for better quality. Full fine-tune updates all\n                    weights.{\" \"}\n                    <a\n                      href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                      target=\"_blank\"\n                      rel=\"noopener noreferrer\"\n                      className=\"text-primary underline\"\n                    >\n                      Read more\n                    </a>\n                  </TooltipContent>\n                </Tooltip>\n              </FieldLabel>\n              <FieldDescription>\n                Choose how to fine-tune {selectedModel}\n              </FieldDescription>\n            </div>\n            <Select\n              value={trainingMethod}\n              onValueChange={(v) => setTrainingMethod(v as TrainingMethod)}\n            >\n              <SelectTrigger className=\"w-40\">\n                <SelectValue />\n              </SelectTrigger>\n              <SelectContent>\n                <SelectItem value=\"qlora\">QLoRA (4-bit)</SelectItem>\n                <SelectItem value=\"lora\">LoRA (16-bit)</SelectItem>\n                <SelectItem value=\"full\">Full Fine-tune</SelectItem>\n              </SelectContent>\n            </Select>\n          </div>\n        </Field>\n      )}\n    </FieldGroup>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/steps/model-type-step.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Card, CardContent } from \"@/components/ui/card\";\nimport { RadioGroup, RadioGroupItem } from \"@/components/ui/radio-group\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { MODEL_TYPES } from \"@/config/training\";\nimport { cn } from \"@/lib/utils\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport type { ModelType } from \"@/types/training\";\nimport {\n  BubbleChatIcon,\n  Database02Icon,\n  ImageIcon,\n  InformationCircleIcon,\n  TextIcon,\n  VoiceIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nconst TYPE_ICONS: Record<ModelType, typeof ImageIcon> = {\n  vision: ImageIcon,\n  audio: VoiceIcon,\n  embeddings: Database02Icon,\n  text: TextIcon,\n};\n\nconst TYPE_TOOLTIPS: Record<ModelType, string> = {\n  vision: \"Fine-tune models that understand images and text together\",\n  audio: \"Fine-tune text-to-speech and audio models\",\n  embeddings: \"Fine-tune models for semantic search and similarity\",\n  text: \"Fine-tune large language models for text generation\",\n};\n\nconst COMING_SOON: ModelType[] = [];\n\nexport function ModelTypeStep(): ReactElement {\n  const { modelType, setModelType } = useTrainingConfigStore(\n    useShallow((s) => ({\n      modelType: s.modelType,\n      setModelType: s.setModelType,\n    })),\n  );\n  const [chatOnlySelected, setChatOnlySelected] = useState(false);\n\n  return (\n    <div className=\"flex flex-col gap-6\">\n      <div>\n        <h2 className=\"text-lg font-semibold\">Welcome to Unsloth Studio</h2>\n        <p className=\"text-sm text-muted-foreground\">\n          Choose a path - fine-tune LLMs, vision, embedding, audio models or just chat.{\" \"}\n          <a\n            href=\"https://unsloth.ai/docs/new/studio/start\"\n            target=\"_blank\"\n            rel=\"noreferrer\"\n            className=\"text-primary underline\"\n          >\n            Get started with our guide\n          </a>\n        </p>\n      </div>\n      <RadioGroup\n        value={chatOnlySelected ? \"\" : (modelType ?? \"\")}\n        onValueChange={(v) => {\n          if (!COMING_SOON.includes(v as ModelType)) {\n            setChatOnlySelected(false);\n            sessionStorage.removeItem(\"unsloth_chat_only\");\n            setModelType(v as ModelType);\n          }\n        }}\n        className=\"grid grid-cols-2 gap-4\"\n      >\n        {MODEL_TYPES.map((type) => {\n          const Icon = TYPE_ICONS[type.value];\n          const isSelected = !chatOnlySelected && modelType === type.value;\n          const isDisabled = COMING_SOON.includes(type.value);\n          const inputId = `model-type-${type.value}`;\n\n          return (\n            <label\n              key={type.value}\n              htmlFor={inputId}\n              className={cn(\n                isDisabled ? 
\"cursor-not-allowed\" : \"cursor-pointer\",\n              )}\n            >\n              <Card\n                size=\"sm\"\n                className={cn(\n                  \"relative shadow-primary/30 transition-all duration-150 ease-out\",\n                  isDisabled && \"opacity-50 bg-muted/50\",\n                  !isDisabled &&\n                    \"hover:ring-primary/40 hover:-translate-y-0.5 hover:shadow-sm\",\n                  isSelected &&\n                    !isDisabled &&\n                    \"ring-2 ring-primary -translate-y-0.5 shadow-sm\",\n                )}\n              >\n                {isDisabled && (\n                  <Badge\n                    variant=\"secondary\"\n                    className=\"absolute top-2 right-2 text-[10px]\"\n                  >\n                    Coming Soon\n                  </Badge>\n                )}\n                <CardContent className=\"flex items-center gap-4 py-4\">\n                  <RadioGroupItem\n                    id={inputId}\n                    value={type.value}\n                    className=\"sr-only\"\n                    disabled={isDisabled}\n                  />\n                  <div\n                    className={cn(\n                      \"size-10 rounded-xl corner-squircle flex items-center justify-center shrink-0\",\n                      \"transition-all duration-100 ease-out\",\n                      isDisabled && \"bg-muted/50 text-muted-foreground/50\",\n                      !isDisabled &&\n                        isSelected &&\n                        \"bg-primary/10 text-primary scale-105\",\n                      !(isDisabled || isSelected) &&\n                        \"bg-muted text-muted-foreground\",\n                    )}\n                  >\n                    <HugeiconsIcon\n                      icon={Icon}\n                      className={cn(\n                        \"size-5 transition-transform duration-100 ease-out\",\n                        isSelected && !isDisabled && \"scale-110\",\n                      )}\n                      strokeWidth={isSelected && !isDisabled ? 
2.5 : 2}\n                    />\n                  </div>\n                  <div className=\"flex flex-col gap-0.5 flex-1\">\n                    <div className=\"flex items-center gap-1.5\">\n                      <span\n                        className={cn(\n                          \"font-medium\",\n                          isDisabled && \"text-muted-foreground\",\n                        )}\n                      >\n                        {type.label}\n                      </span>\n                      <Tooltip>\n                        <TooltipTrigger asChild={true}>\n                          <button\n                            type=\"button\"\n                            className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                          >\n                            <HugeiconsIcon\n                              icon={InformationCircleIcon}\n                              className=\"size-3.5\"\n                            />\n                          </button>\n                        </TooltipTrigger>\n                        <TooltipContent>\n                          {TYPE_TOOLTIPS[type.value]}\n                        </TooltipContent>\n                      </Tooltip>\n                    </div>\n                    <span className=\"text-xs text-muted-foreground\">\n                      {type.description}\n                    </span>\n                  </div>\n                </CardContent>\n              </Card>\n            </label>\n          );\n        })}\n        <div\n          className=\"cursor-pointer\"\n          onClick={() => {\n            setChatOnlySelected(true);\n            setModelType(\"text\" as ModelType);\n            sessionStorage.setItem(\"unsloth_chat_only\", \"1\");\n          }}\n        >\n          <Card\n            size=\"sm\"\n            className={cn(\n              \"relative shadow-primary/30 transition-all duration-150 ease-out\",\n              \"hover:ring-primary/40 hover:-translate-y-0.5 hover:shadow-sm\",\n              chatOnlySelected && \"ring-2 ring-primary -translate-y-0.5 shadow-sm\",\n            )}\n          >\n            <CardContent className=\"flex items-center gap-4 py-4\">\n              {/* Invisible spacer matching RadioGroupItem (size-4 flex) in other cards */}\n              <div className=\"size-4 shrink-0\" aria-hidden=\"true\" />\n              <div\n                className={cn(\n                  \"size-10 rounded-xl corner-squircle flex items-center justify-center shrink-0\",\n                  \"transition-all duration-100 ease-out\",\n                  chatOnlySelected\n                    ? \"bg-primary/10 text-primary scale-105\"\n                    : \"bg-muted text-muted-foreground\",\n                )}\n              >\n                <HugeiconsIcon\n                  icon={BubbleChatIcon}\n                  className={cn(\n                    \"size-5 transition-transform duration-100 ease-out\",\n                    chatOnlySelected && \"scale-110\",\n                  )}\n                  strokeWidth={chatOnlySelected ? 
2.5 : 2}\n                />\n              </div>\n              <div className=\"flex flex-col gap-0.5 flex-1\">\n                <div className=\"flex items-center gap-1.5\">\n                  <span className=\"font-medium\">Chat</span>\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n                        onClick={(e) => e.stopPropagation()}\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3.5\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent>\n                      Chat with any model. Has tool calling, web search and more.\n                    </TooltipContent>\n                  </Tooltip>\n                </div>\n                <span className=\"text-xs text-muted-foreground\">\n                  Chat with LLMs & vision models + audio generation.\n                </span>\n              </div>\n            </CardContent>\n          </Card>\n        </div>\n      </RadioGroup>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/steps/summary-step.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Card, CardContent, CardHeader, CardTitle } from \"@/components/ui/card\";\nimport { Separator } from \"@/components/ui/separator\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport { useHardwareInfo } from \"@/hooks\";\nimport { isAdapterMethod } from \"@/types/training\";\nimport { ChipIcon, Database02Icon, GpuIcon, Settings04Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nfunction Row({\n  label,\n  value,\n  mono,\n  capitalize,\n  uppercase,\n}: {\n  label: string;\n  value: React.ReactNode;\n  mono?: boolean;\n  capitalize?: boolean;\n  uppercase?: boolean;\n}) {\n  return (\n    <div className=\"flex items-center justify-between\">\n      <span className=\"text-muted-foreground\">{label}</span>\n      <span\n        className={\n          mono\n            ? \"font-mono text-xs\"\n            : capitalize\n              ? \"capitalize\"\n              : uppercase\n                ? \"uppercase\"\n                : undefined\n        }\n      >\n        {value}\n      </span>\n    </div>\n  );\n}\n\nexport function SummaryStep() {\n  const hw = useHardwareInfo();\n  const {\n    modelType,\n    selectedModel,\n    trainingMethod,\n    datasetSource,\n    datasetFormat,\n    dataset,\n    datasetSubset,\n    datasetSplit,\n    uploadedFile,\n    epochs,\n    contextLength,\n    learningRate,\n    loraRank,\n    loraAlpha,\n    loraDropout,\n  } = useTrainingConfigStore(\n    useShallow(\n      ({\n        modelType,\n        selectedModel,\n        trainingMethod,\n        datasetSource,\n        datasetFormat,\n        dataset,\n        datasetSubset,\n        datasetSplit,\n        uploadedFile,\n        epochs,\n        contextLength,\n        learningRate,\n        loraRank,\n        loraAlpha,\n        loraDropout,\n      }) => ({\n        modelType,\n        selectedModel,\n        trainingMethod,\n        datasetSource,\n        datasetFormat,\n        dataset,\n        datasetSubset,\n        datasetSplit,\n        uploadedFile,\n        epochs,\n        contextLength,\n        learningRate,\n        loraRank,\n        loraAlpha,\n        loraDropout,\n      }),\n    ),\n  );\n\n  const showLoraParams = isAdapterMethod(trainingMethod);\n  const datasetName = datasetSource === \"upload\" ? uploadedFile : dataset;\n\n  return (\n    <div className=\"grid grid-cols-2 gap-3\">\n      <Card size=\"sm\" className=\"flex flex-col rounded-2xl\">\n        <CardHeader className=\"pb-2\">\n          <CardTitle className=\"text-sm text-muted-foreground\">\n            System\n          </CardTitle>\n        </CardHeader>\n        <CardContent className=\"flex flex-1 flex-col\">\n          <div className=\"flex items-start gap-3\">\n            <div className=\"mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-emerald-500/10\">\n              <HugeiconsIcon icon={GpuIcon} className=\"size-4 text-emerald-600\" />\n            </div>\n            <div className=\"flex flex-1 flex-col\">\n              <span className=\"text-xs text-muted-foreground\">GPU</span>\n              <div className=\"flex items-center gap-2\">\n                <span className=\"text-sm font-medium\">{hw.gpuName ?? 
\"---\"}</span>\n                <Badge variant=\"secondary\">{hw.vramTotalGb != null ? `${hw.vramTotalGb} GB` : \"---\"}</Badge>\n              </div>\n            </div>\n          </div>\n          <Separator className=\"my-2\" />\n          <div className=\"space-y-1 text-sm\">\n            <Row label=\"unsloth\" value={hw.unsloth ?? \"---\"} mono />\n            <Row label=\"torch\" value={hw.torch ?? \"---\"} mono />\n            <Row label=\"transformers\" value={hw.transformers ?? \"---\"} mono />\n          </div>\n        </CardContent>\n      </Card>\n\n      <Card size=\"sm\" className=\"flex flex-col rounded-2xl\">\n        <CardHeader className=\"pb-2\">\n          <CardTitle className=\"text-sm text-muted-foreground\">Model</CardTitle>\n        </CardHeader>\n        <CardContent className=\"flex flex-1 flex-col\">\n          <div className=\"flex items-start gap-3\">\n            <div className=\"mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-emerald-500/10\">\n              <HugeiconsIcon icon={ChipIcon} className=\"size-4 text-emerald-600\" />\n            </div>\n            <div className=\"flex flex-1 flex-col overflow-hidden\">\n              <span className=\"text-xs text-muted-foreground\">Model</span>\n              <span className=\"truncate text-sm font-medium\">{selectedModel ?? \"---\"}</span>\n            </div>\n          </div>\n          <Separator className=\"my-2\" />\n          <div className=\"space-y-1 text-sm\">\n            <Row label=\"Type\" value={modelType} capitalize />\n            <Row label=\"Method\" value={trainingMethod === \"qlora\" ? \"QLoRA\" : trainingMethod === \"lora\" ? \"LoRA\" : \"Full\"} />\n          </div>\n        </CardContent>\n      </Card>\n\n      <Card size=\"sm\" className=\"flex flex-col rounded-2xl\">\n        <CardHeader className=\"pb-2\">\n          <CardTitle className=\"text-sm text-muted-foreground\">\n            Dataset\n          </CardTitle>\n        </CardHeader>\n        <CardContent className=\"flex flex-1 flex-col\">\n          <div className=\"flex items-start gap-3\">\n            <div className=\"mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-indigo-500/10\">\n              <HugeiconsIcon icon={Database02Icon} className=\"size-4 text-indigo-600\" />\n            </div>\n            <div className=\"flex flex-1 flex-col overflow-hidden\">\n              <span className=\"text-xs text-muted-foreground\">Dataset</span>\n              <span className=\"truncate text-sm font-medium\">{datasetName ?? 
\"---\"}</span>\n            </div>\n          </div>\n          <Separator className=\"my-2\" />\n          <div className=\"space-y-1 text-sm\">\n            <Row label=\"Source\" value={datasetSource} capitalize />\n            {datasetSubset && (\n              <Row label=\"Subset\" value={datasetSubset} mono />\n            )}\n            {datasetSplit && (\n              <Row label=\"Split\" value={datasetSplit} mono />\n            )}\n            <Row label=\"Format\" value={datasetFormat} capitalize />\n          </div>\n        </CardContent>\n      </Card>\n\n      <Card size=\"sm\" className=\"flex flex-col rounded-2xl\">\n        <CardHeader className=\"pb-2\">\n          <CardTitle className=\"text-sm text-muted-foreground\">\n            Hyperparameters\n          </CardTitle>\n        </CardHeader>\n        <CardContent className=\"flex flex-1 flex-col\">\n          <div className=\"flex items-start gap-3\">\n            <div className=\"mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-orange-500/10\">\n              <HugeiconsIcon icon={Settings04Icon} className=\"size-4 text-orange-600\" />\n            </div>\n            <div className=\"flex flex-1 flex-col\">\n              <span className=\"text-xs text-muted-foreground\">Training</span>\n              <span className=\"text-sm font-medium\">\n                {trainingMethod === \"qlora\" ? \"QLoRA\" : trainingMethod === \"lora\" ? \"LoRA\" : \"Full\"}\n              </span>\n            </div>\n          </div>\n          <Separator className=\"my-2\" />\n          <div className=\"grid grid-cols-2 gap-x-6 gap-y-1 text-sm\">\n            <Row label=\"Epochs\" value={epochs} mono />\n            <Row label=\"Context\" value={contextLength.toLocaleString()} mono />\n            <Row label=\"LR\" value={learningRate.toExponential()} mono />\n            {showLoraParams && (\n              <>\n                <Row label=\"Rank\" value={loraRank} mono />\n                <Row label=\"Alpha\" value={loraAlpha} mono />\n                <Row label=\"Dropout\" value={loraDropout} mono />\n              </>\n            )}\n          </div>\n        </CardContent>\n      </Card>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/wizard-content.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { STEPS } from \"@/config/training\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport type { StepNumber } from \"@/types/training\";\nimport { DatasetStep } from \"./steps/dataset-step\";\nimport { HyperparametersStep } from \"./steps/hyperparameters-step\";\nimport { ModelSelectionStep } from \"./steps/model-selection-step\";\nimport { ModelTypeStep } from \"./steps/model-type-step\";\nimport { SummaryStep } from \"./steps/summary-step\";\n\nconst STEP_COMPONENTS = {\n  1: ModelTypeStep,\n  2: ModelSelectionStep,\n  3: DatasetStep,\n  4: HyperparametersStep,\n  5: SummaryStep,\n} as const;\n\nconst STEP_MASCOTS: Record<StepNumber, string> = {\n  1: \"/Sloth emojis/large sloth wave.png\",\n  2: \"/Sloth emojis/sloth magnify final.png\",\n  3: \"/Sloth emojis/sloth huglove large.png\",\n  4: \"/Sloth emojis/large sloth glasses.png\",\n  5: \"/Sloth emojis/large sloth yay.png\",\n};\n\nexport function WizardContent() {\n  const currentStep = useTrainingConfigStore((s) => s.currentStep);\n  const stepConfig = STEPS[currentStep - 1];\n  const StepComponent = STEP_COMPONENTS[currentStep];\n  const mascotSrc = STEP_MASCOTS[currentStep];\n\n  return (\n    <main className=\"flex-1 flex flex-col overflow-y-auto\">\n      <header className=\"flex flex-wrap items-start gap-3 p-4 pb-3 sm:p-6 sm:pb-4\">\n        <img src={mascotSrc} alt=\"Unsloth mascot\" className=\"size-12 sm:size-14\" />\n        <div className=\"flex flex-col min-w-0\">\n          <h1 className=\"text-lg font-semibold sm:text-xl\">{stepConfig.title}</h1>\n          <p className=\"text-sm text-muted-foreground\">\n            {stepConfig.description}\n          </p>\n        </div>\n        <p className=\"ml-auto hidden shrink-0 text-xs text-muted-foreground uppercase tracking-wider md:block\">\n          Step {currentStep} of {STEPS.length}\n        </p>\n      </header>\n      <div className=\"flex-1 p-4 pt-1.5 sm:p-6 sm:pt-2\">\n        <StepComponent />\n      </div>\n    </main>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/wizard-footer.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { STEPS } from \"@/config/training\";\nimport { markOnboardingDone } from \"@/features/auth\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport { ArrowLeft02Icon, ArrowRight02Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useNavigate } from \"@tanstack/react-router\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nexport function WizardFooter({ onBackToSplash }: { onBackToSplash: () => void }) {\n  const { currentStep, prevStep, nextStep, canProceed } = useTrainingConfigStore(\n    useShallow((s) => ({\n      currentStep: s.currentStep,\n      prevStep: s.prevStep,\n      nextStep: s.nextStep,\n      canProceed: s.canProceed(),\n    })),\n  );\n  const navigate = useNavigate();\n  const isFirst = currentStep === 1;\n  const isLast = currentStep === STEPS.length;\n\n  return (\n    <footer>\n      <div className=\"flex items-center justify-between p-6\">\n        <Button\n          variant=\"outline\"\n          className=\"px-4 !pl-4\"\n          onClick={isFirst ? onBackToSplash : prevStep}\n        >\n          <HugeiconsIcon icon={ArrowLeft02Icon} data-icon=\"inline-start\" />\n          Back\n        </Button>\n        <div className=\"flex items-center gap-2\">\n          {!isLast && (\n            <Button\n              variant=\"outline\"\n              className=\"px-4\"\n              onClick={() => {\n                markOnboardingDone();\n                navigate({ to: \"/studio\" });\n              }}\n            >\n              Skip\n            </Button>\n          )}\n          {isLast ? (\n            <Button\n              onClick={() => {\n                markOnboardingDone();\n                navigate({ to: \"/studio\" });\n              }}\n              disabled={!canProceed}\n              className=\"px-4 !pr-4\"\n            >\n              Go to Studio\n              <HugeiconsIcon icon={ArrowRight02Icon} data-icon=\"inline-end\" />\n            </Button>\n          ) : (\n            <Button\n              onClick={() => {\n                if (currentStep === 1 && sessionStorage.getItem(\"unsloth_chat_only\") === \"1\") {\n                  sessionStorage.removeItem(\"unsloth_chat_only\");\n                  markOnboardingDone();\n                  window.location.href = \"/chat\";\n                } else {\n                  nextStep();\n                }\n              }}\n              className=\"px-4 !pl-4\"\n              disabled={!canProceed}\n            >\n              Continue\n              <HugeiconsIcon icon={ArrowRight02Icon} data-icon=\"inline-end\" />\n            </Button>\n          )}\n        </div>\n      </div>\n    </footer>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/wizard-layout.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Card } from \"@/components/ui/card\";\nimport { useNavigate } from \"@tanstack/react-router\";\nimport { motion } from \"motion/react\";\nimport { Suspense, lazy, useEffect, useRef, useState } from \"react\";\n\nimport type { ConfettiRef } from \"@/components/ui/confetti\";\nimport { STEPS } from \"@/config/training\";\nimport { isOnboardingDone, markOnboardingDone } from \"@/features/auth\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport { SplashScreen } from \"./splash-screen\";\nimport { WizardContent } from \"./wizard-content\";\nimport { WizardFooter } from \"./wizard-footer\";\nimport { WizardSidebar } from \"./wizard-sidebar\";\n\nconst Confetti = lazy(() =>\n  import(\"@/components/ui/confetti\").then((m) => ({ default: m.Confetti })),\n);\n\nexport function WizardLayout() {\n  const navigate = useNavigate();\n  const [showSplash, setShowSplash] = useState(true);\n  const currentStep = useTrainingConfigStore((s) => s.currentStep);\n  const confettiRef = useRef<ConfettiRef>(null);\n  const hasFiredRef = useRef(false);\n  const isFinalStep = currentStep === STEPS.length;\n\n  // Only redirect on initial mount — not on re-renders after markOnboardingDone()\n  // which would override explicit /chat navigation from skip buttons.\n  const checkedRef = useRef(false);\n  useEffect(() => {\n    if (!checkedRef.current) {\n      checkedRef.current = true;\n      if (isOnboardingDone()) {\n        navigate({ to: \"/studio\" });\n      }\n    }\n  }, [navigate]);\n\n  useEffect(() => {\n    if (isFinalStep && !hasFiredRef.current) {\n      hasFiredRef.current = true;\n      confettiRef.current?.fire({\n        particleCount: 80,\n        angle: 60,\n        spread: 55,\n        origin: { x: 0, y: 0.6 },\n        colors: [\"#34b482\", \"#26ccff\", \"#a25afd\", \"#88ff5a\"],\n      });\n      confettiRef.current?.fire({\n        particleCount: 80,\n        angle: 120,\n        spread: 55,\n        origin: { x: 1, y: 0.6 },\n        colors: [\"#34b482\", \"#26ccff\", \"#a25afd\", \"#88ff5a\"],\n      });\n    }\n    if (!isFinalStep) {\n      hasFiredRef.current = false;\n    }\n  }, [isFinalStep]);\n\n  return (\n    <div className=\"relative min-h-screen flex items-center justify-center overflow-hidden bg-gradient-to-br from-primary/5 via-background to-primary/3 p-4 sm:p-6 md:p-8\">\n      {showSplash && (\n        <SplashScreen\n          onStartOnboarding={() => setShowSplash(false)}\n          onGoToStudio={() => {\n            markOnboardingDone();\n            window.location.href = \"/studio\";\n          }}\n        />\n      )}\n      <Suspense fallback={null}>\n        <Confetti\n          ref={confettiRef}\n          manualstart={true}\n          className=\"pointer-events-none fixed inset-0 z-50 size-full\"\n        />\n      </Suspense>\n      {!showSplash && (\n        <motion.div\n          className=\"w-full max-w-5xl\"\n          initial={{ opacity: 0, scale: 0.98, y: 10 }}\n          animate={{ opacity: 1, scale: 1, y: 0 }}\n          transition={{\n            duration: 0.4,\n            ease: [0.165, 0.84, 0.44, 1],\n          }}\n        >\n          <Card className=\"relative z-10 w-full !gap-0 !m-0 !p-0 flex min-h-[560px] flex-col overflow-hidden shadow-border ring-1 ring-border md:min-h-[620px] md:flex-row lg:h-[660px]\">\n            <WizardSidebar />\n            <div 
className=\"flex-1 flex flex-col\">\n              <WizardContent />\n              <WizardFooter onBackToSplash={() => setShowSplash(true)} />\n            </div>\n          </Card>\n        </motion.div>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/wizard-sidebar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Progress } from \"@/components/ui/progress\";\nimport { STEPS } from \"@/config/training\";\nimport { markOnboardingDone } from \"@/features/auth\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport { ArrowRight02Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { WizardStepItem } from \"./wizard-step-item\";\n\nexport function WizardSidebar() {\n  const currentStep = useTrainingConfigStore((s) => s.currentStep);\n  const progress = ((currentStep - 1) / (STEPS.length - 1)) * 100;\n\n  return (\n    <aside className=\"w-full shrink-0 bg-muted/70 p-4 md:w-64 md:p-6\">\n      <div className=\"flex items-center gap-3 py-1 md:py-2\">\n        <img\n          src=\"https://unsloth.ai/cgi/image/unsloth_sticker_no_shadow_ldN4V4iydw00qSIIWDCUv.png?width=96&quality=80&format=auto\"\n          alt=\"Unsloth\"\n          className=\"size-12\"\n        />\n        <div className=\"flex flex-col\">\n          <span className=\"font-semibold text-lg leading-tight\">Unsloth</span>\n          <span className=\"text-xs text-muted-foreground\">Studio</span>\n        </div>\n      </div>\n      <div className=\"mt-3 md:mt-0\">\n        <Progress value={progress} className=\"h-1.5\" />\n      </div>\n      <p className=\"mt-2 text-xs text-muted-foreground md:hidden\">\n        Step {currentStep} of {STEPS.length}\n      </p>\n      <Button\n        size=\"sm\"\n        className=\"mt-2 w-full md:hidden\"\n        onClick={() => {\n          markOnboardingDone();\n          window.location.href = \"/chat\";\n        }}\n      >\n        Skip to Chat\n        <HugeiconsIcon icon={ArrowRight02Icon} data-icon=\"inline-end\" />\n      </Button>\n      <nav className=\"mt-3 hidden flex-col gap-1 md:flex\">\n        {STEPS.map((step) => (\n          <WizardStepItem key={step.number} step={step} />\n        ))}\n      </nav>\n      <Button\n        size=\"sm\"\n        className=\"mt-3 hidden w-full md:flex\"\n        onClick={() => {\n          markOnboardingDone();\n          window.location.href = \"/chat\";\n        }}\n      >\n        Skip to Chat\n        <HugeiconsIcon icon={ArrowRight02Icon} data-icon=\"inline-end\" />\n      </Button>\n    </aside>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/components/wizard-step-item.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport { useTrainingConfigStore } from \"@/features/training\";\nimport type { StepConfig, StepNumber } from \"@/types/training\";\nimport { useShallow } from \"zustand/react/shallow\";\n\ninterface WizardStepItemProps {\n  step: StepConfig;\n}\n\nexport function WizardStepItem({ step }: WizardStepItemProps) {\n  const { currentStep, setStep } = useTrainingConfigStore(\n    useShallow((s) => ({ currentStep: s.currentStep, setStep: s.setStep })),\n  );\n  const isActive = currentStep === step.number;\n  const isCompleted = currentStep > step.number;\n  const canClick = isCompleted;\n\n  return (\n    <button\n      type=\"button\"\n      onClick={() => canClick && setStep(step.number as StepNumber)}\n      disabled={!canClick}\n      className={cn(\n        \"flex items-start gap-3 text-left w-full py-2 transition-colors\",\n        canClick && \"cursor-pointer hover:opacity-80\",\n        !(canClick || isActive) && \"opacity-50\",\n      )}\n    >\n      <div\n        className={cn(\n          \"size-5 rounded-full flex items-center justify-center text-xs font-medium shrink-0 mt-0.5 transition-colors\",\n          isActive && \"bg-primary text-primary-foreground\",\n          isCompleted && \"bg-primary/20 text-primary\",\n          !(isActive || isCompleted) && \"bg-muted text-muted-foreground\",\n        )}\n      >\n        {isCompleted ? \"✓\" : step.number}\n      </div>\n      <div className=\"flex flex-col gap-1\">\n        <span\n          className={cn(\n            \"text-sm font-medium\",\n            isActive && \"text-foreground\",\n            !isActive && \"text-muted-foreground\",\n          )}\n        >\n          {step.title}\n        </span>\n        <span className=\"text-xs text-muted-foreground\">{step.subtitle}</span>\n      </div>\n    </button>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/onboarding/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { WizardLayout } from \"./components/wizard-layout\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/api/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\n\nconst DEFAULT_BASE = \"/api/data-recipe\";\n\nexport const DATA_DESIGNER_API_BASE =\n  import.meta.env.VITE_DATA_DESIGNER_API ?? DEFAULT_BASE;\n\nexport type JobCreateResponse = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  job_id: string;\n};\n\nexport type PublishRecipeJobRequest = {\n  repo_id: string;\n  description: string;\n  hf_token?: string | null;\n  private?: boolean;\n  artifact_path?: string | null;\n};\n\nexport type PublishRecipeJobResponse = {\n  success: boolean;\n  url: string;\n  message: string;\n};\n\nexport type JobStatusResponse = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  job_id: string;\n  status: string;\n  stage?: string | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  current_column?: string | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  completed_columns?: string[] | null;\n  batch?: {\n    idx?: number | null;\n    total?: number | null;\n  };\n  progress?: {\n    done?: number | null;\n    total?: number | null;\n    percent?: number | null;\n    // biome-ignore lint/style/useNamingConvention: api schema\n    eta_sec?: number | null;\n    rate?: number | null;\n    ok?: number | null;\n    failed?: number | null;\n  };\n  // biome-ignore lint/style/useNamingConvention: api schema\n  column_progress?: {\n    done?: number | null;\n    total?: number | null;\n    percent?: number | null;\n    // biome-ignore lint/style/useNamingConvention: api schema\n    eta_sec?: number | null;\n    rate?: number | null;\n    ok?: number | null;\n    failed?: number | null;\n  };\n  // biome-ignore lint/style/useNamingConvention: api schema\n  model_usage?: Record<string, unknown>;\n  rows?: number | null;\n  cols?: number | null;\n  error?: string | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  has_analysis?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  dataset_rows?: number | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  artifact_path?: string | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  started_at?: number | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  finished_at?: number | null;\n};\n\nexport type JobDatasetResponse = {\n  dataset?: unknown[];\n  total?: number;\n  limit?: number;\n  offset?: number;\n};\n\nexport type JobEvent = {\n  event: string;\n  id: number | null;\n  payload: Record<string, unknown>;\n};\n\nexport type SeedInspectRequest = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  dataset_name: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  hf_token?: string;\n  subset?: string;\n  split?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  preview_size?: number;\n};\n\nexport type SeedInspectUploadRequest = {\n  filename: string;\n  // base64 payload without data URL prefix\n  content_base64: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  preview_size?: number;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  seed_source_type?: \"local\" | \"unstructured\";\n  // biome-ignore lint/style/useNamingConvention: api schema\n  unstructured_chunk_size?: number;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  
unstructured_chunk_overlap?: number;\n};\n\nexport type SeedInspectResponse = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  dataset_name: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  resolved_path: string;\n  columns: string[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  preview_rows: Record<string, unknown>[];\n  split?: string | null;\n  subset?: string | null;\n};\n\nexport type ValidateError = {\n  message: string;\n  path?: string | null;\n  code?: string | null;\n};\n\nexport type ValidateResponse = {\n  valid: boolean;\n  errors: ValidateError[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  raw_detail?: string | null;\n};\n\nexport type McpToolsListRequest = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  mcp_providers: Record<string, unknown>[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  timeout_sec?: number;\n};\n\nexport type McpToolsProviderResult = {\n  name: string;\n  tools: string[];\n  error?: string | null;\n};\n\nexport type McpToolsListResponse = {\n  providers: McpToolsProviderResult[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  duplicate_tools: Record<string, string[]>;\n};\n\nasync function parseErrorResponse(response: Response): Promise<string> {\n  const text = (await response.text()).trim();\n  if (!text) {\n    return \"Request failed.\";\n  }\n  try {\n    const parsed = JSON.parse(text) as {\n      detail?: string;\n      message?: string;\n      // biome-ignore lint/style/useNamingConvention: api schema\n      raw_detail?: string;\n    };\n    return (\n      parsed.detail ??\n      parsed.message ??\n      parsed.raw_detail ??\n      text\n    );\n  } catch {\n    return text;\n  }\n}\n\nasync function postJson<T>(path: string, payload: unknown): Promise<T> {\n  const response = await authFetch(`${DATA_DESIGNER_API_BASE}${path}`, {\n    method: \"POST\",\n    headers: {\n      \"Content-Type\": \"application/json\",\n    },\n    body: JSON.stringify(payload),\n  });\n\n  if (!response.ok) {\n    throw new Error(await parseErrorResponse(response));\n  }\n\n  return response.json();\n}\n\nasync function getJson<T>(path: string): Promise<T> {\n  const response = await authFetch(`${DATA_DESIGNER_API_BASE}${path}`);\n  if (!response.ok) {\n    throw new Error(await parseErrorResponse(response));\n  }\n  return response.json();\n}\n\nfunction parseJobEvent(rawEvent: string): JobEvent | null {\n  const lines = rawEvent.split(/\\r?\\n/);\n  let eventName = \"message\";\n  let id: number | null = null;\n  const dataLines: string[] = [];\n\n  for (const line of lines) {\n    if (!line) {\n      continue;\n    }\n    if (line.startsWith(\"event:\")) {\n      eventName = line.slice(6).trim() || \"message\";\n      continue;\n    }\n    if (line.startsWith(\"id:\")) {\n      const value = Number(line.slice(3).trim());\n      id = Number.isFinite(value) ? 
value : null;\n      continue;\n    }\n    if (line.startsWith(\"data:\")) {\n      dataLines.push(line.slice(5).trimStart());\n    }\n  }\n\n  if (dataLines.length === 0) {\n    return null;\n  }\n  let payload: Record<string, unknown>;\n  try {\n    payload = JSON.parse(dataLines.join(\"\\n\")) as Record<string, unknown>;\n  } catch {\n    return null;\n  }\n  return {\n    event: eventName,\n    id,\n    payload,\n  };\n}\n\nexport async function validateRecipe(\n  payload: unknown,\n): Promise<ValidateResponse> {\n  return postJson<ValidateResponse>(\"/validate\", payload);\n}\n\nexport async function createRecipeJob(payload: unknown): Promise<JobCreateResponse> {\n  return postJson<JobCreateResponse>(\"/jobs\", payload);\n}\n\nexport async function getRecipeJobStatus(jobId: string): Promise<JobStatusResponse> {\n  return getJson<JobStatusResponse>(`/jobs/${jobId}/status`);\n}\n\nexport async function getRecipeJobAnalysis(\n  jobId: string,\n): Promise<Record<string, unknown>> {\n  return getJson<Record<string, unknown>>(`/jobs/${jobId}/analysis`);\n}\n\nexport async function getRecipeJobDataset(\n  jobId: string,\n  options?: {\n    limit?: number;\n    offset?: number;\n  },\n): Promise<JobDatasetResponse> {\n  const limit = options?.limit ?? 20;\n  const offset = options?.offset ?? 0;\n  return getJson<JobDatasetResponse>(\n    `/jobs/${jobId}/dataset?limit=${limit}&offset=${offset}`,\n  );\n}\n\nexport async function cancelRecipeJob(jobId: string): Promise<JobStatusResponse> {\n  return postJson<JobStatusResponse>(`/jobs/${jobId}/cancel`, {});\n}\n\nexport async function publishRecipeJob(\n  jobId: string,\n  payload: PublishRecipeJobRequest,\n): Promise<PublishRecipeJobResponse> {\n  return postJson<PublishRecipeJobResponse>(`/jobs/${jobId}/publish`, payload);\n}\n\nexport async function inspectSeedDataset(\n  payload: SeedInspectRequest,\n): Promise<SeedInspectResponse> {\n  return postJson<SeedInspectResponse>(\"/seed/inspect\", payload);\n}\n\nexport async function inspectSeedUpload(\n  payload: SeedInspectUploadRequest,\n): Promise<SeedInspectResponse> {\n  return postJson<SeedInspectResponse>(\"/seed/inspect-upload\", payload);\n}\n\nexport async function listMcpTools(\n  payload: McpToolsListRequest,\n): Promise<McpToolsListResponse> {\n  return postJson<McpToolsListResponse>(\"/mcp/tools\", payload);\n}\n\nexport async function streamRecipeJobEvents(options: {\n  jobId: string;\n  signal: AbortSignal;\n  lastEventId?: number | null;\n  onOpen?: () => void;\n  onEvent: (event: JobEvent) => void;\n}): Promise<void> {\n  const headers = new Headers();\n  let query = \"\";\n  if (typeof options.lastEventId === \"number\") {\n    headers.set(\"Last-Event-ID\", String(options.lastEventId));\n    query = `?after=${options.lastEventId}`;\n  }\n\n  const response = await authFetch(\n    `${DATA_DESIGNER_API_BASE}/jobs/${options.jobId}/events${query}`,\n    {\n      method: \"GET\",\n      headers,\n      signal: options.signal,\n    },\n  );\n  if (!response.ok) {\n    throw new Error(await parseErrorResponse(response));\n  }\n  if (!response.body) {\n    throw new Error(\"Job stream unavailable.\");\n  }\n\n  options.onOpen?.();\n\n  const reader = response.body.getReader();\n  const decoder = new TextDecoder();\n  let buffer = \"\";\n\n  while (true) {\n    const { value, done } = await reader.read();\n    if (done) {\n      break;\n    }\n    buffer += decoder.decode(value, { stream: true });\n    let separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    while (separatorIndex >= 
0) {\n      const rawEvent = buffer.slice(0, separatorIndex);\n      const separatorLength = buffer[separatorIndex] === \"\\r\" ? 4 : 2;\n      buffer = buffer.slice(separatorIndex + separatorLength);\n\n      if (rawEvent.startsWith(\"retry:\")) {\n        separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n        continue;\n      }\n\n      const parsed = parseJobEvent(rawEvent);\n      if (parsed) {\n        options.onEvent(parsed);\n      }\n      separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    }\n  }\n}\n\n// NOTE: preview endpoints removed from harness.\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/blocks/definitions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  BalanceScaleIcon,\n  Clock01Icon,\n  CodeIcon,\n  CodeSimpleIcon,\n  DiceFaces03Icon,\n  DocumentAttachmentIcon,\n  DocumentCodeIcon,\n  EqualSignIcon,\n  FingerPrintIcon,\n  FunctionIcon,\n  Plug01Icon,\n  Parabola02Icon,\n  PencilEdit02Icon,\n  Plant01Icon,\n  Shield02Icon,\n  Tag01Icon,\n  TagsIcon,\n  UserAccountIcon,\n} from \"@hugeicons/core-free-icons\";\nimport type {\n  LlmType,\n  NodeConfig,\n  SamplerType,\n  SeedSourceType,\n} from \"../types\";\nimport {\n  makeExpressionConfig,\n  makeLlmConfig,\n  makeMarkdownNoteConfig,\n  makeModelConfig,\n  makeModelProviderConfig,\n  makeToolProfileConfig,\n  makeSamplerConfig,\n  makeSeedConfig,\n  makeValidatorConfig,\n} from \"../utils\";\n\nexport type BlockKind =\n  | \"sampler\"\n  | \"llm\"\n  | \"validator\"\n  | \"expression\"\n  | \"seed\"\n  | \"note\";\nexport type BlockType =\n  | SamplerType\n  | LlmType\n  | \"validator_python\"\n  | \"validator_sql\"\n  | \"validator_oxc\"\n  | \"expression\"\n  | \"markdown_note\"\n  | \"seed\"\n  | \"seed_hf\"\n  | \"seed_local\"\n  | \"seed_unstructured\"\n  | \"model_provider\"\n  | \"model_config\"\n  | \"tool_config\";\n\nexport type SeedBlockType = \"seed_hf\" | \"seed_local\" | \"seed_unstructured\";\n\ntype IconType = typeof CodeIcon;\n\nexport type BlockGroup = {\n  kind: BlockKind;\n  title: string;\n  description: string;\n  icon: IconType;\n};\n\nexport type BlockDialogKey =\n  | \"seed\"\n  | \"markdown_note\"\n  | \"category\"\n  | \"subcategory\"\n  | \"uniform\"\n  | \"gaussian\"\n  | \"bernoulli\"\n  | \"datetime\"\n  | \"timedelta\"\n  | \"uuid\"\n  | \"person\"\n  | \"llm\"\n  | \"validator\"\n  | \"model_provider\"\n  | \"model_config\"\n  | \"tool_config\"\n  | \"expression\";\n\nexport type BlockDefinition = {\n  kind: BlockKind;\n  type: BlockType;\n  title: string;\n  description: string;\n  icon: IconType;\n  dialogKey: BlockDialogKey;\n  createConfig: (id: string, existing: NodeConfig[]) => NodeConfig;\n};\n\nexport const BLOCK_GROUPS: BlockGroup[] = [\n  {\n    kind: \"sampler\",\n    title: \"Generated fields\",\n    description: \"Create fields from lists, ranges, and reusable patterns.\",\n    icon: DiceFaces03Icon,\n  },\n  {\n    kind: \"seed\",\n    title: \"Source data\",\n    description: \"Start from an existing dataset or file.\",\n    icon: Plant01Icon,\n  },\n  {\n    kind: \"llm\",\n    title: \"AI generation\",\n    description: \"Generate content, connect models, and manage tools.\",\n    icon: PencilEdit02Icon,\n  },\n  {\n    kind: \"validator\",\n    title: \"Checks\",\n    description: \"Lint or filter generated code as it moves through the recipe.\",\n    icon: Shield02Icon,\n  },\n  {\n    kind: \"expression\",\n    title: \"Formulas\",\n    description: \"Build a field from other fields.\",\n    icon: FunctionIcon,\n  },\n  {\n    kind: \"note\",\n    title: \"Notes\",\n    description: \"Add markdown notes to document your flow.\",\n    icon: PencilEdit02Icon,\n  },\n];\n\nconst BLOCK_DEFINITIONS: BlockDefinition[] = [\n  {\n    kind: \"seed\",\n    type: \"seed_hf\",\n    title: \"Hugging Face dataset\",\n    description: \"Use rows from a Hugging Face dataset as source data.\",\n    icon: Plant01Icon,\n    dialogKey: \"seed\",\n    createConfig: (id, existing) => makeSeedConfig(id, existing, \"hf\"),\n  },\n  {\n    kind: \"seed\",\n    type: \"seed_local\",\n    
title: \"CSV or JSON file\",\n    description: \"Upload CSV, JSON, or JSONL and use its rows as source data.\",\n    icon: DocumentCodeIcon,\n    dialogKey: \"seed\",\n    createConfig: (id, existing) => makeSeedConfig(id, existing, \"local\"),\n  },\n  {\n    kind: \"seed\",\n    type: \"seed_unstructured\",\n    title: \"Document file\",\n    description: \"Upload PDF, DOCX, or TXT and turn it into source rows.\",\n    icon: DocumentAttachmentIcon,\n    dialogKey: \"seed\",\n    createConfig: (id, existing) => makeSeedConfig(id, existing, \"unstructured\"),\n  },\n  {\n    kind: \"sampler\",\n    type: \"category\",\n    title: \"Category\",\n    description: \"Generate values from a list you define, with optional weights or rules.\",\n    icon: Tag01Icon,\n    dialogKey: \"category\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"category\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"subcategory\",\n    title: \"Subcategory\",\n    description: \"Generate values from groups you define for each category.\",\n    icon: TagsIcon,\n    dialogKey: \"subcategory\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"subcategory\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"uniform\",\n    title: \"Random number\",\n    description: \"Generate a number anywhere between a minimum and maximum.\",\n    icon: EqualSignIcon,\n    dialogKey: \"uniform\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"uniform\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"gaussian\",\n    title: \"Bell-curve number\",\n    description: \"Generate numbers around an average value.\",\n    icon: Parabola02Icon,\n    dialogKey: \"gaussian\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"gaussian\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"bernoulli\",\n    title: \"Yes/no value\",\n    description: \"Generate a binary result from a probability.\",\n    icon: EqualSignIcon,\n    dialogKey: \"bernoulli\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"bernoulli\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"datetime\",\n    title: \"Date and time\",\n    description: \"Generate timestamps inside a date range.\",\n    icon: Clock01Icon,\n    dialogKey: \"datetime\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"datetime\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"timedelta\",\n    title: \"Time offset\",\n    description: \"Generate a time difference from another date field.\",\n    icon: Clock01Icon,\n    dialogKey: \"timedelta\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"timedelta\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"uuid\",\n    title: \"Unique ID\",\n    description: \"Generate unique identifiers.\",\n    icon: FingerPrintIcon,\n    dialogKey: \"uuid\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"uuid\", existing),\n  },\n  {\n    kind: \"sampler\",\n    type: \"person\",\n    title: \"Synthetic person\",\n    description: \"Generate realistic person details.\",\n    icon: UserAccountIcon,\n    dialogKey: \"person\",\n    createConfig: (id, existing) => makeSamplerConfig(id, \"person\", existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"text\",\n    title: \"AI text\",\n    description: \"Generate text from your prompt.\",\n    icon: PencilEdit02Icon,\n    dialogKey: \"llm\",\n    createConfig: (id, existing) => makeLlmConfig(id, \"text\", existing),\n  },\n  {\n    kind: 
\"llm\",\n    type: \"structured\",\n    title: \"AI structured data\",\n    description: \"Generate JSON that follows a response format.\",\n    icon: CodeIcon,\n    dialogKey: \"llm\",\n    createConfig: (id, existing) => makeLlmConfig(id, \"structured\", existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"code\",\n    title: \"AI code\",\n    description: \"Generate code in the language you choose.\",\n    icon: CodeSimpleIcon,\n    dialogKey: \"llm\",\n    createConfig: (id, existing) => makeLlmConfig(id, \"code\", existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"judge\",\n    title: \"AI scorer\",\n    description: \"Score outputs against your criteria.\",\n    icon: BalanceScaleIcon,\n    dialogKey: \"llm\",\n    createConfig: (id, existing) => makeLlmConfig(id, \"judge\", existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"model_provider\",\n    title: \"Provider connection\",\n    description: \"Choose where model requests go and how to sign in.\",\n    icon: Shield02Icon,\n    dialogKey: \"model_provider\",\n    createConfig: (id, existing) => makeModelProviderConfig(id, existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"model_config\",\n    title: \"Model preset\",\n    description: \"Pick a model and save reusable generation settings.\",\n    icon: Plant01Icon,\n    dialogKey: \"model_config\",\n    createConfig: (id, existing) => makeModelConfig(id, existing),\n  },\n  {\n    kind: \"llm\",\n    type: \"tool_config\",\n    title: \"Tool access\",\n    description: \"Choose which tools an AI step can use.\",\n    icon: Plug01Icon,\n    dialogKey: \"tool_config\",\n    createConfig: (id, existing) => makeToolProfileConfig(id, existing),\n  },\n  {\n    kind: \"validator\",\n    type: \"validator_python\",\n    title: \"Python check\",\n    description: \"Lint generated Python and filter out rows that fail.\",\n    icon: Shield02Icon,\n    dialogKey: \"validator\",\n    createConfig: (id, existing) =>\n      makeValidatorConfig(id, \"code\", \"python\", existing),\n  },\n  {\n    kind: \"validator\",\n    type: \"validator_sql\",\n    title: \"SQL check\",\n    description: \"Lint generated SQL and filter out rows that fail.\",\n    icon: Shield02Icon,\n    dialogKey: \"validator\",\n    createConfig: (id, existing) =>\n      makeValidatorConfig(id, \"code\", \"sql:sqlite\", existing),\n  },\n  {\n    kind: \"validator\",\n    type: \"validator_oxc\",\n    title: \"JS/TS check\",\n    description: \"Lint generated JavaScript or TypeScript and filter out rows that fail.\",\n    icon: Shield02Icon,\n    dialogKey: \"validator\",\n    createConfig: (id, existing) =>\n      makeValidatorConfig(id, \"oxc\", \"javascript\", existing),\n  },\n  {\n    kind: \"expression\",\n    type: \"expression\",\n    title: \"Formula\",\n    description: \"Build or transform a field using other fields.\",\n    icon: FunctionIcon,\n    dialogKey: \"expression\",\n    createConfig: (id, existing) => makeExpressionConfig(id, existing),\n  },\n  {\n    kind: \"note\",\n    type: \"markdown_note\",\n    title: \"Note\",\n    description: \"Add a note to the canvas. 
Notes do not affect the run.\",\n    icon: PencilEdit02Icon,\n    dialogKey: \"markdown_note\",\n    createConfig: (id, existing) => makeMarkdownNoteConfig(id, existing),\n  },\n];\n\nexport function getBlocksForKind(kind: BlockKind): BlockDefinition[] {\n  return BLOCK_DEFINITIONS.filter((block) => block.kind === kind);\n}\n\nexport function getBlockDefinition(\n  kind: BlockKind,\n  type: BlockType,\n): BlockDefinition | null {\n  return (\n    BLOCK_DEFINITIONS.find((block) => block.kind === kind && block.type === type) ??\n    null\n  );\n}\n\nexport function getBlockDefinitionForConfig(\n  config: NodeConfig | null,\n): BlockDefinition | null {\n  if (!config) {\n    return null;\n  }\n  if (config.kind === \"seed\") {\n    const seedType: Record<SeedSourceType, SeedBlockType> = {\n      hf: \"seed_hf\",\n      local: \"seed_local\",\n      unstructured: \"seed_unstructured\",\n    };\n    return getBlockDefinition(\"seed\", seedType[config.seed_source_type ?? \"hf\"]);\n  }\n  if (config.kind === \"sampler\") {\n    const samplerType =\n      config.sampler_type === \"person_from_faker\" ? \"person\" : config.sampler_type;\n    return getBlockDefinition(\"sampler\", samplerType);\n  }\n  if (config.kind === \"llm\") {\n    return getBlockDefinition(\"llm\", config.llm_type);\n  }\n  if (config.kind === \"validator\") {\n    if (config.validator_type === \"oxc\") {\n      return getBlockDefinition(\"validator\", \"validator_oxc\");\n    }\n    const isSql = config.code_lang.startsWith(\"sql:\");\n    return getBlockDefinition(\n      \"validator\",\n      isSql ? \"validator_sql\" : \"validator_python\",\n    );\n  }\n  if (config.kind === \"model_provider\") {\n    return getBlockDefinition(\"llm\", \"model_provider\");\n  }\n  if (config.kind === \"model_config\") {\n    return getBlockDefinition(\"llm\", \"model_config\");\n  }\n  if (config.kind === \"tool_config\") {\n    return getBlockDefinition(\"llm\", \"tool_config\");\n  }\n  if (config.kind === \"markdown_note\") {\n    return getBlockDefinition(\"note\", \"markdown_note\");\n  }\n  return getBlockDefinition(\"expression\", \"expression\");\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/blocks/registry.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type {\n  BlockDefinition,\n  BlockDialogKey,\n  BlockGroup,\n  BlockKind,\n  BlockType,\n  SeedBlockType,\n} from \"./definitions\";\nexport {\n  BLOCK_GROUPS,\n  getBlockDefinition,\n  getBlockDefinitionForConfig,\n  getBlocksForKind,\n} from \"./definitions\";\nexport { renderBlockDialog } from \"./render-dialog\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/blocks/render-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport type { NodeConfig, SamplerConfig } from \"../types\";\nimport { getBlockDefinitionForConfig } from \"./definitions\";\nimport { ExpressionDialog } from \"../dialogs/expression/expression-dialog\";\nimport { LlmDialog } from \"../dialogs/llm/llm-dialog\";\nimport { ModelConfigDialog } from \"../dialogs/models/model-config-dialog\";\nimport { ModelProviderDialog } from \"../dialogs/models/model-provider-dialog\";\nimport { SeedDialog } from \"../dialogs/seed/seed-dialog\";\nimport { CategoryDialog } from \"../dialogs/samplers/category-dialog\";\nimport { DatetimeDialog } from \"../dialogs/samplers/datetime-dialog\";\nimport { BernoulliDialog } from \"../dialogs/samplers/bernoulli-dialog\";\nimport { GaussianDialog } from \"../dialogs/samplers/gaussian-dialog\";\nimport { PersonDialog } from \"../dialogs/samplers/person-dialog\";\nimport { SubcategoryDialog } from \"../dialogs/samplers/subcategory-dialog\";\nimport { TimedeltaDialog } from \"../dialogs/samplers/timedelta-dialog\";\nimport { UniformDialog } from \"../dialogs/samplers/uniform-dialog\";\nimport { UuidDialog } from \"../dialogs/samplers/uuid-dialog\";\nimport { MarkdownNoteDialog } from \"../dialogs/markdown-note/markdown-note-dialog\";\nimport { ToolProfileDialog } from \"../dialogs/tool-profile/tool-profile-dialog\";\nimport { ValidatorDialog } from \"../dialogs/validators/validator-dialog\";\n\nexport function renderBlockDialog(\n  config: NodeConfig | null,\n  open: boolean,\n  categoryOptions: SamplerConfig[],\n  modelConfigAliases: string[],\n  modelProviderOptions: string[],\n  toolProfileAliases: string[],\n  datetimeOptions: string[],\n  onUpdate: (id: string, patch: Partial<NodeConfig>) => void,\n): ReactElement | null {\n  const definition = getBlockDefinitionForConfig(config);\n  if (!definition || !config) {\n    return null;\n  }\n\n  const update = (patch: Partial<NodeConfig>) => onUpdate(config.id, patch);\n\n  switch (definition.dialogKey) {\n    case \"seed\":\n      return config.kind === \"seed\" ? (\n        <SeedDialog config={config} onUpdate={update} open={open} />\n      ) : null;\n    case \"category\":\n      return config.kind === \"sampler\" && config.sampler_type === \"category\" ? (\n        <CategoryDialog key={config.id} config={config} onUpdate={update} />\n      ) : null;\n    case \"subcategory\":\n      return config.kind === \"sampler\" && config.sampler_type === \"subcategory\" ? (\n        <SubcategoryDialog\n          config={config}\n          categoryOptions={categoryOptions}\n          onUpdate={update}\n        />\n      ) : null;\n    case \"uniform\":\n      return config.kind === \"sampler\" && config.sampler_type === \"uniform\" ? (\n        <UniformDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"gaussian\":\n      return config.kind === \"sampler\" && config.sampler_type === \"gaussian\" ? (\n        <GaussianDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"bernoulli\":\n      return config.kind === \"sampler\" && config.sampler_type === \"bernoulli\" ? (\n        <BernoulliDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"datetime\":\n      return config.kind === \"sampler\" && config.sampler_type === \"datetime\" ? 
(\n        <DatetimeDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"timedelta\":\n      return config.kind === \"sampler\" && config.sampler_type === \"timedelta\" ? (\n        <TimedeltaDialog\n          config={config}\n          datetimeOptions={datetimeOptions}\n          onUpdate={update}\n        />\n      ) : null;\n    case \"uuid\":\n      return config.kind === \"sampler\" && config.sampler_type === \"uuid\" ? (\n        <UuidDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"person\":\n      return config.kind === \"sampler\" &&\n        (config.sampler_type === \"person\" ||\n          config.sampler_type === \"person_from_faker\") ? (\n        <PersonDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"llm\":\n      return config.kind === \"llm\" ? (\n        <LlmDialog\n          config={config}\n          modelConfigAliases={modelConfigAliases}\n          modelProviderOptions={modelProviderOptions}\n          toolProfileAliases={toolProfileAliases}\n          onUpdate={update}\n        />\n      ) : null;\n    case \"model_provider\":\n      return config.kind === \"model_provider\" ? (\n        <ModelProviderDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"model_config\":\n      return config.kind === \"model_config\" ? (\n        <ModelConfigDialog\n          config={config}\n          providerOptions={modelProviderOptions}\n          onUpdate={update}\n        />\n      ) : null;\n    case \"tool_config\":\n      return config.kind === \"tool_config\" ? (\n        <ToolProfileDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"expression\":\n      return config.kind === \"expression\" ? (\n        <ExpressionDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"validator\":\n      return config.kind === \"validator\" ? (\n        <ValidatorDialog config={config} onUpdate={update} />\n      ) : null;\n    case \"markdown_note\":\n      return config.kind === \"markdown_note\" ? (\n        <MarkdownNoteDialog config={config} onUpdate={update} />\n      ) : null;\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/block-sheet.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Sheet,\n  SheetContent,\n  SheetHeader,\n  SheetTitle,\n  SheetTrigger,\n} from \"@/components/ui/sheet\";\nimport {\n  ArrowLeft02Icon,\n  ArrowRight01Icon,\n  CodeIcon,\n  Copy02Icon,\n  type Database02Icon,\n  DragDropVerticalIcon,\n  DocumentAttachmentIcon,\n  PlusSignIcon,\n  Search01Icon,\n  Tick02Icon,\n  Upload01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  useCallback,\n  type DragEvent as ReactDragEvent,\n  type ReactElement,\n  useMemo,\n  useState,\n} from \"react\";\nimport { RECIPE_FLOATING_ICON_BUTTON_CLASS } from \"./recipe-floating-icon-button-class\";\nimport type { LlmType, SamplerType } from \"../types\";\nimport {\n  BLOCK_GROUPS,\n  getBlocksForKind,\n  type BlockType,\n  type SeedBlockType,\n} from \"../blocks/registry\";\nimport {\n  RECIPE_STUDIO_ONBOARDING_ICON_TONE,\n  RECIPE_STUDIO_ONBOARDING_SURFACE_TONE,\n} from \"../utils/ui-tones\";\n\ntype SheetView =\n  | \"root\"\n  | \"sampler\"\n  | \"seed\"\n  | \"llm\"\n  | \"validator\"\n  | \"expression\"\n  | \"note\"\n  | \"processor\";\ntype SheetKind =\n  | \"sampler\"\n  | \"seed\"\n  | \"llm\"\n  | \"validator\"\n  | \"expression\"\n  | \"note\";\ntype RootSheetView = Exclude<SheetView, \"root\">;\ntype RootGroup = {\n  kind: RootSheetView;\n  title: string;\n  description: string;\n  icon: typeof Database02Icon;\n};\n\ntype BlockSheetProps = {\n  container: HTMLDivElement | null;\n  sheetView: SheetView;\n  onViewChange: (sheetView: SheetView) => void;\n  open?: boolean;\n  onOpenChange?: (open: boolean) => void;\n  onAddSampler: (type: SamplerType) => void;\n  onAddSeed: (type: SeedBlockType) => void;\n  onAddLlm: (type: LlmType) => void;\n  onAddModelProvider: () => void;\n  onAddModelConfig: () => void;\n  onAddToolProfile: () => void;\n  onAddExpression: () => void;\n  onAddValidator: (\n    type: \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n  ) => void;\n  onAddMarkdownNote: () => void;\n  onOpenProcessors: () => void;\n  copied: boolean;\n  onCopy: () => void;\n  onImport: () => void;\n};\n\nexport const RECIPE_BLOCK_DND_MIME = \"application/x-recipe-studio-block\";\nexport type RecipeBlockDragPayload = {\n  kind: SheetKind;\n  type: BlockType;\n};\n\nfunction getSheetTitle(sheetView: SheetView): string {\n  if (sheetView === \"root\") {\n    return \"Add a step\";\n  }\n  if (sheetView === \"sampler\") {\n    return \"Generated fields\";\n  }\n  if (sheetView === \"seed\") {\n    return \"Source data\";\n  }\n  if (sheetView === \"expression\") {\n    return \"Formulas\";\n  }\n  if (sheetView === \"validator\") {\n    return \"Checks\";\n  }\n  if (sheetView === \"note\") {\n    return \"Notes\";\n  }\n  if (sheetView === \"processor\") {\n    return \"Processor blocks\";\n  }\n  return \"AI generation\";\n}\n\nconst VIEW_KIND: Record<SheetView, SheetKind | null> = {\n  root: null,\n  sampler: \"sampler\",\n  seed: \"seed\",\n  llm: \"llm\",\n  validator: \"validator\",\n  expression: \"expression\",\n  note: \"note\",\n  processor: null,\n};\n\nconst ROOT_GROUPS: RootGroup[] = [...BLOCK_GROUPS];\nconst ROOT_GROUPS_WITH_SEED_FIRST: RootGroup[] = [\n  ...ROOT_GROUPS.filter((group) => group.kind === 
\"seed\"),\n  ...ROOT_GROUPS.filter((group) => group.kind !== \"seed\"),\n];\nconst SEARCHABLE_KINDS: SheetKind[] = [\n  \"sampler\",\n  \"seed\",\n  \"llm\",\n  \"validator\",\n  \"expression\",\n  \"note\",\n];\nconst PROCESSOR_TITLE = \"Final dataset shape\";\nconst PROCESSOR_DESCRIPTION = \"Rename, reorder, or reshape the final dataset.\";\nconst SHOW_PROCESSOR_IN_BLOCK_SHEET = false;\nconst LLM_SETUP_TYPES = new Set<BlockType>([\n  \"model_provider\",\n  \"model_config\",\n  \"tool_config\",\n]);\n\nfunction BlockSheetButton({\n  icon,\n  title,\n  description,\n  onClick,\n  isActive = false,\n  draggable = false,\n  onDragStart,\n  trailing = \"chevron\",\n  disabled = false,\n  badge,\n}: {\n  icon: typeof Database02Icon;\n  title: string;\n  description: string;\n  onClick: () => void;\n  isActive?: boolean;\n  draggable?: boolean;\n  onDragStart?: (event: ReactDragEvent<HTMLButtonElement>) => void;\n  trailing?: \"chevron\" | \"drag\" | \"none\";\n  disabled?: boolean;\n  badge?: string;\n}): ReactElement {\n  return (\n    <button\n      type=\"button\"\n      onClick={disabled ? undefined : onClick}\n      disabled={disabled}\n      draggable={disabled ? false : draggable}\n      onDragStart={disabled ? undefined : onDragStart}\n      className={`flex w-full items-center gap-3 border-l-2 bg-background px-3 py-3 text-left transition ${\n        disabled ? \"cursor-not-allowed opacity-60\" : \"hover:bg-muted/35\"\n      } ${\n        isActive\n          ? \"border-emerald-500\"\n          : disabled\n            ? \"border-transparent\"\n            : \"border-transparent hover:border-border/60\"\n      } ${draggable ? \"cursor-grab active:cursor-grabbing\" : \"\"}`}\n    >\n      <div className=\"flex size-9 items-center justify-center rounded-xl text-foreground/70\">\n        <HugeiconsIcon icon={icon} className=\"size-5\" />\n      </div>\n      <div className=\"min-w-0 flex-1\">\n        <div className=\"flex items-center gap-2\">\n          <p className=\"break-words text-sm font-semibold text-foreground\">\n            {title}\n          </p>\n          {badge ? (\n            <Badge variant=\"outline\" className=\"rounded-full text-[10px]\">\n              {badge}\n            </Badge>\n          ) : null}\n        </div>\n        <p className=\"break-words text-[11px] text-muted-foreground\">\n          {description}\n        </p>\n      </div>\n      {trailing === \"chevron\" ? (\n        <HugeiconsIcon\n          icon={ArrowRight01Icon}\n          className=\"size-3.5 text-muted-foreground\"\n        />\n      ) : trailing === \"drag\" ? 
(\n        <HugeiconsIcon\n          icon={DragDropVerticalIcon}\n          strokeWidth={3.5}\n          className=\"size-5 text-foreground\"\n        />\n      ) : null}\n    </button>\n  );\n}\n\nexport function BlockSheet({\n  container,\n  sheetView,\n  onViewChange,\n  open,\n  onOpenChange,\n  onAddSampler,\n  onAddSeed,\n  onAddLlm,\n  onAddModelProvider,\n  onAddModelConfig,\n  onAddToolProfile,\n  onAddExpression,\n  onAddValidator,\n  onAddMarkdownNote,\n  onOpenProcessors,\n  copied,\n  onCopy,\n  onImport,\n}: BlockSheetProps): ReactElement {\n  const sheetTitle = getSheetTitle(sheetView);\n  const [uncontrolledOpen, setUncontrolledOpen] = useState(false);\n  const [search, setSearch] = useState(\"\");\n  const expressionBlocks = useMemo(() => getBlocksForKind(\"expression\"), []);\n  const noteBlocks = useMemo(() => getBlocksForKind(\"note\"), []);\n  const seedBlocks = useMemo(() => getBlocksForKind(\"seed\"), []);\n  const isControlled = typeof open === \"boolean\";\n  const sheetOpen = isControlled ? (open as boolean) : uncontrolledOpen;\n  const normalizedSearch = search.trim().toLowerCase();\n  const hasSearch = normalizedSearch.length > 0;\n  const isProcessorView = sheetView === \"processor\";\n  const isRootView = sheetView === \"root\";\n  const isScopedBlockView = !isRootView && !isProcessorView;\n\n  const setSheetOpen = (nextOpen: boolean) => {\n    if (!isControlled) {\n      setUncontrolledOpen(nextOpen);\n    }\n    onOpenChange?.(nextOpen);\n  };\n  const matchesSearch = useCallback(\n    (title: string, description: string) =>\n      title.toLowerCase().includes(normalizedSearch) ||\n      description.toLowerCase().includes(normalizedSearch),\n    [normalizedSearch],\n  );\n\n  const searchableBlocks = useMemo(\n    () => SEARCHABLE_KINDS.flatMap((kind) => getBlocksForKind(kind)),\n    [],\n  );\n  const rootSearchBlocks = useMemo(() => {\n    if (!hasSearch) {\n      return [];\n    }\n    return searchableBlocks.filter((item) =>\n      matchesSearch(item.title, item.description),\n    );\n  }, [hasSearch, matchesSearch, searchableBlocks]);\n\n  const scopedBlocks = useMemo(() => {\n    if (!isScopedBlockView) {\n      return [];\n    }\n    const blocks = getBlocksForKind(VIEW_KIND[sheetView] ?? \"sampler\");\n    if (!hasSearch) {\n      return blocks;\n    }\n    return blocks.filter((item) => matchesSearch(item.title, item.description));\n  }, [hasSearch, isScopedBlockView, matchesSearch, sheetView]);\n  const llmCreateBlocks =\n    sheetView === \"llm\"\n      ? scopedBlocks.filter((item) => !LLM_SETUP_TYPES.has(item.type))\n      : [];\n  const llmSetupBlocks =\n    sheetView === \"llm\"\n      ? scopedBlocks.filter((item) => LLM_SETUP_TYPES.has(item.type))\n      : [];\n  const featuredSeedBlock =\n    sheetView === \"seed\" && !hasSearch\n      ? scopedBlocks.find((item) => item.type === \"seed_unstructured\") ?? null\n      : null;\n  const otherSeedBlocks =\n    sheetView === \"seed\" && !hasSearch\n      ? 
scopedBlocks.filter((item) => item.type !== \"seed_unstructured\")\n      : scopedBlocks;\n\n  const rootGroups = useMemo(() => {\n    if (!hasSearch) {\n      return ROOT_GROUPS_WITH_SEED_FIRST;\n    }\n    return ROOT_GROUPS.filter((group) => {\n      if (matchesSearch(group.title, group.description)) {\n        return true;\n      }\n      if (group.kind === \"processor\") {\n        return matchesSearch(PROCESSOR_TITLE, PROCESSOR_DESCRIPTION);\n      }\n      return getBlocksForKind(group.kind).some((item) =>\n        matchesSearch(item.title, item.description),\n      );\n    });\n  }, [hasSearch, matchesSearch]);\n  const showNoMatches =\n    (isRootView && hasSearch && rootSearchBlocks.length === 0) ||\n    (isScopedBlockView && scopedBlocks.length === 0) ||\n    (isProcessorView &&\n      hasSearch &&\n      !matchesSearch(PROCESSOR_TITLE, PROCESSOR_DESCRIPTION));\n\n  const buildDragStart =\n    (kind: SheetKind, type: BlockType) =>\n    (event: ReactDragEvent<HTMLButtonElement>) => {\n      const payload: RecipeBlockDragPayload = { kind, type };\n      const serialized = JSON.stringify(payload);\n      event.dataTransfer.setData(RECIPE_BLOCK_DND_MIME, serialized);\n      event.dataTransfer.setData(\"text/plain\", serialized);\n      event.dataTransfer.effectAllowed = \"copy\";\n    };\n  const getTrailing = (): \"drag\" => \"drag\";\n  const onBlockClick = (kind: SheetKind, type: BlockType) => {\n    setSheetOpen(false);\n    if (kind === \"sampler\") {\n      onAddSampler(type as SamplerType);\n      return;\n    }\n    if (kind === \"seed\") {\n      onAddSeed(type as SeedBlockType);\n      return;\n    }\n    if (kind === \"llm\") {\n      if (type === \"model_provider\") {\n        onAddModelProvider();\n        return;\n      }\n      if (type === \"model_config\") {\n        onAddModelConfig();\n        return;\n      }\n      if (type === \"tool_config\") {\n        onAddToolProfile();\n        return;\n      }\n      onAddLlm(type as LlmType);\n      return;\n    }\n    if (kind === \"validator\") {\n      onAddValidator(\n        type as \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n      );\n      return;\n    }\n    if (kind === \"expression\") {\n      onAddExpression();\n      return;\n    }\n    onAddMarkdownNote();\n  };\n\n  return (\n    <div className=\"flex flex-col items-end gap-2\">\n      <Sheet\n        open={sheetOpen}\n        onOpenChange={(nextOpen) => {\n          setSheetOpen(nextOpen);\n          if (nextOpen) {\n            onViewChange(\"root\");\n            setSearch(\"\");\n          }\n        }}\n      >\n        <SheetTrigger asChild={true}>\n          <Button\n            size=\"icon\"\n            className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n            variant=\"ghost\"\n            aria-label=\"Add a step\"\n            title=\"Add a step\"\n          >\n            <HugeiconsIcon\n              icon={PlusSignIcon}\n              className=\"size-5 text-muted-foreground group-hover:text-primary\"\n            />\n          </Button>\n        </SheetTrigger>\n        <SheetContent\n          side=\"right\"\n          container={container}\n          position=\"absolute\"\n          overlayPosition=\"absolute\"\n          className=\"absolute gap-0 p-0 shadow-none\"\n          overlayClassName=\"bg-transparent pointer-events-none backdrop-blur-none supports-backdrop-filter:backdrop-blur-none\"\n        >\n          <SheetHeader className=\"px-6 py-5\">\n            <div className=\"flex items-center gap-2\">\n             
 {sheetView !== \"root\" && (\n                <Button\n                  type=\"button\"\n                  variant=\"ghost\"\n                  size=\"icon-sm\"\n                  onClick={() => onViewChange(\"root\")}\n                  aria-label=\"Back to step groups\"\n                  title=\"Back to step groups\"\n                >\n                  <HugeiconsIcon icon={ArrowLeft02Icon} className=\"size-4\" />\n                </Button>\n              )}\n              <SheetTitle>{sheetTitle}</SheetTitle>\n            </div>\n            <div className=\"relative mt-3\">\n              <HugeiconsIcon\n                icon={Search01Icon}\n                className=\"pointer-events-none absolute left-2.5 top-1/2 size-4 -translate-y-1/2 text-muted-foreground\"\n              />\n              <Input\n                value={search}\n                onChange={(event) => setSearch(event.target.value)}\n                placeholder=\"Search steps...\"\n                className=\"corner-squircle h-9 pl-8\"\n                aria-label=\"Search steps\"\n              />\n            </div>\n          </SheetHeader>\n          <div className=\"flex-1 min-h-0 overflow-y-auto py-4\">\n            <div className=\"mt-4 flex flex-col gap-2\">\n              {isRootView && !hasSearch && (\n                <div className={`mx-3 mb-2 rounded-2xl border px-4 py-4 ${RECIPE_STUDIO_ONBOARDING_SURFACE_TONE}`}>\n                  <div className=\"flex items-start gap-3\">\n                    <div className={`mt-0.5 flex size-9 shrink-0 items-center justify-center rounded-xl ${RECIPE_STUDIO_ONBOARDING_ICON_TONE}`}>\n                      <HugeiconsIcon\n                        icon={DocumentAttachmentIcon}\n                        className=\"size-4\"\n                      />\n                    </div>\n                    <div className=\"min-w-0 flex-1 space-y-2\">\n                      <div>\n                        <p className=\"text-sm font-semibold text-foreground\">\n                          Need a place to start?\n                        </p>\n                        <p className=\"text-xs text-muted-foreground\">\n                          Open Source data first, then add generation and checks\n                          on top of it.\n                        </p>\n                      </div>\n                      <Button\n                        type=\"button\"\n                        size=\"sm\"\n                        variant=\"ghost\"\n                        className=\"corner-squircle justify-start px-0 text-primary hover:bg-transparent hover:text-primary/80\"\n                        onClick={() => onViewChange(\"seed\")}\n                      >\n                        Start with source data\n                      </Button>\n                    </div>\n                  </div>\n                </div>\n              )}\n              {isRootView &&\n                hasSearch &&\n                rootSearchBlocks.map((item) => (\n                  <BlockSheetButton\n                    key={`${item.kind}:${item.type}`}\n                    icon={item.icon}\n                    title={item.title}\n                    description={item.description}\n                    draggable={true}\n                    onDragStart={buildDragStart(item.kind, item.type)}\n                    trailing={getTrailing()}\n                    onClick={() => onBlockClick(item.kind, item.type)}\n                  />\n                ))}\n              {isRootView &&\n                !hasSearch &&\n       
         rootGroups.map((item) => (\n                  <BlockSheetButton\n                    key={item.kind}\n                    icon={item.icon}\n                    title={item.title}\n                    description={item.description}\n                    draggable={item.kind === \"expression\" || item.kind === \"note\"}\n                    onDragStart={\n                      item.kind === \"expression\" && expressionBlocks[0]\n                        ? buildDragStart(\"expression\", expressionBlocks[0].type)\n                        : item.kind === \"note\" && noteBlocks[0]\n                          ? buildDragStart(\"note\", noteBlocks[0].type)\n                          : undefined\n                    }\n                    trailing={\n                      item.kind === \"expression\" || item.kind === \"note\"\n                        ? \"drag\"\n                        : \"chevron\"\n                    }\n                    onClick={() => {\n                      if (item.kind === \"seed\" && seedBlocks.length === 1) {\n                        setSheetOpen(false);\n                        onAddSeed(seedBlocks[0].type as SeedBlockType);\n                        return;\n                      }\n                      if (item.kind === \"expression\" && expressionBlocks.length === 1) {\n                        setSheetOpen(false);\n                        onAddExpression();\n                        return;\n                      }\n                      if (item.kind === \"note\" && noteBlocks.length === 1) {\n                        setSheetOpen(false);\n                        onAddMarkdownNote();\n                        return;\n                      }\n                      onViewChange(item.kind);\n                    }}\n                  />\n                ))}\n              {SHOW_PROCESSOR_IN_BLOCK_SHEET && isProcessorView && (\n                (!hasSearch ||\n                  matchesSearch(PROCESSOR_TITLE, PROCESSOR_DESCRIPTION)) && (\n                  <BlockSheetButton\n                    icon={CodeIcon}\n                    title={PROCESSOR_TITLE}\n                    description={PROCESSOR_DESCRIPTION}\n                    onClick={onOpenProcessors}\n                  />\n                )\n              )}\n              {isScopedBlockView &&\n                sheetView === \"seed\" &&\n                featuredSeedBlock && (\n                  <div className=\"pb-2\">\n                    <div className=\"px-3 pb-2\">\n                      <p className=\"text-xs font-semibold uppercase tracking-wide text-muted-foreground\">\n                        Recommended first step\n                      </p>\n                      <p className=\"text-xs text-muted-foreground\">\n                        Best when you want to turn PDFs, DOCX files, or text\n                        files into source rows.\n                      </p>\n                    </div>\n                    <BlockSheetButton\n                      icon={featuredSeedBlock.icon}\n                      title={featuredSeedBlock.title}\n                      description={featuredSeedBlock.description}\n                      draggable={true}\n                      onDragStart={buildDragStart(\n                        featuredSeedBlock.kind,\n                        featuredSeedBlock.type,\n                      )}\n                      trailing={getTrailing()}\n                      badge=\"Start here\"\n                      onClick={() =>\n                        onBlockClick(\n                        
  featuredSeedBlock.kind,\n                          featuredSeedBlock.type,\n                        )\n                      }\n                    />\n                  </div>\n                )}\n              {isScopedBlockView &&\n                sheetView === \"seed\" &&\n                !hasSearch &&\n                otherSeedBlocks.length > 0 && (\n                  <div className=\"px-3 pt-2 pb-2\">\n                    <p className=\"text-xs font-semibold uppercase tracking-wide text-muted-foreground\">\n                      Other source options\n                    </p>\n                    <p className=\"text-xs text-muted-foreground\">\n                      Use a dataset or structured file when your source is\n                      already tabular.\n                    </p>\n                  </div>\n                )}\n              {isScopedBlockView &&\n                sheetView === \"llm\" &&\n                llmCreateBlocks.length > 0 && (\n                  <div className=\"px-3 pb-2\">\n                    <p className=\"text-xs font-semibold uppercase tracking-wide text-muted-foreground\">\n                      Create\n                    </p>\n                    <p className=\"text-xs text-muted-foreground\">\n                      Start with the kind of output you want to generate.\n                    </p>\n                  </div>\n                )}\n              {isScopedBlockView &&\n                sheetView === \"llm\" &&\n                llmCreateBlocks.map((item) => (\n                  <BlockSheetButton\n                    key={item.type}\n                    icon={item.icon}\n                    title={item.title}\n                    description={item.description}\n                    draggable={true}\n                    onDragStart={buildDragStart(item.kind, item.type)}\n                    trailing={getTrailing()}\n                    onClick={() => onBlockClick(item.kind, item.type)}\n                  />\n                ))}\n              {isScopedBlockView &&\n                sheetView === \"llm\" &&\n                llmSetupBlocks.length > 0 && (\n                  <div className=\"px-3 pt-4 pb-2\">\n                    <p className=\"text-xs font-semibold uppercase tracking-wide text-muted-foreground\">\n                      Setup\n                    </p>\n                    <p className=\"text-xs text-muted-foreground\">\n                      Add these only when you need a new model or tool setup.\n                    </p>\n                  </div>\n                )}\n              {isScopedBlockView &&\n                sheetView === \"llm\" &&\n                llmSetupBlocks.map((item) => (\n                  <BlockSheetButton\n                    key={item.type}\n                    icon={item.icon}\n                    title={item.title}\n                    description={item.description}\n                    draggable={true}\n                    onDragStart={buildDragStart(item.kind, item.type)}\n                    trailing={getTrailing()}\n                    onClick={() => onBlockClick(item.kind, item.type)}\n                  />\n                ))}\n              {isScopedBlockView &&\n                sheetView === \"seed\" &&\n                otherSeedBlocks.map((item) => (\n                  <BlockSheetButton\n                    key={item.type}\n                    icon={item.icon}\n                    title={item.title}\n                    description={item.description}\n                    draggable={true}\n            
        onDragStart={buildDragStart(item.kind, item.type)}\n                    trailing={getTrailing()}\n                    onClick={() => onBlockClick(item.kind, item.type)}\n                  />\n                ))}\n              {isScopedBlockView &&\n                sheetView !== \"llm\" &&\n                sheetView !== \"seed\" &&\n                scopedBlocks.map(\n                  (item) => (\n                    <BlockSheetButton\n                      key={item.type}\n                      icon={item.icon}\n                      title={item.title}\n                      description={item.description}\n                      draggable={true}\n                      onDragStart={buildDragStart(item.kind, item.type)}\n                      trailing={getTrailing()}\n                      onClick={() => onBlockClick(item.kind, item.type)}\n                    />\n                  ),\n                )}\n              {SHOW_PROCESSOR_IN_BLOCK_SHEET && isRootView && !hasSearch && (\n                <div className=\"px-3 pt-3\">\n                  <button\n                    type=\"button\"\n                    onClick={() => {\n                      setSheetOpen(false);\n                      onOpenProcessors();\n                    }}\n                    className=\"flex w-full items-center justify-between gap-3 rounded-xl border border-border/60 px-3 py-3 text-left transition hover:bg-muted/25\"\n                  >\n                    <div className=\"min-w-0\">\n                      <p className=\"text-sm font-medium text-foreground\">\n                        Edit final dataset shape\n                      </p>\n                      <p className=\"break-words text-xs text-muted-foreground\">\n                        Rename, reorder, or reshape your final output.\n                      </p>\n                    </div>\n                    <HugeiconsIcon\n                      icon={CodeIcon}\n                      className=\"size-4 text-muted-foreground\"\n                    />\n                  </button>\n                </div>\n              )}\n              {showNoMatches && (\n                <p className=\"px-3 py-2 text-xs text-muted-foreground\">\n                  No matching steps.\n                </p>\n              )}\n            </div>\n          </div>\n        </SheetContent>\n      </Sheet>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        onClick={onImport}\n        aria-label=\"Paste recipe JSON\"\n        title=\"Paste recipe JSON\"\n      >\n        <HugeiconsIcon\n          icon={Upload01Icon}\n          className=\"size-5 text-muted-foreground group-hover:text-primary\"\n        />\n      </Button>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        onClick={onCopy}\n        aria-label={copied ? \"Recipe JSON copied\" : \"Copy recipe JSON\"}\n        title={copied ? \"Recipe JSON copied\" : \"Copy recipe JSON\"}\n      >\n        <HugeiconsIcon\n          icon={copied ? Tick02Icon : Copy02Icon}\n          className=\"size-5 text-muted-foreground group-hover:text-primary\"\n        />\n      </Button>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/chip-input.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Cancel01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  type KeyboardEvent,\n  type ReactElement,\n  useEffect,\n  useId,\n  useMemo,\n  useRef,\n  useState,\n} from \"react\";\n\ntype ChipInputProps = {\n  values: string[];\n  onAdd: (value: string) => void;\n  onRemove: (index: number) => void;\n  placeholder?: string;\n  suggestions?: string[];\n};\n\nexport function ChipInput({\n  values,\n  onAdd,\n  onRemove,\n  placeholder = \"Type and press Enter\",\n  suggestions,\n}: ChipInputProps): ReactElement {\n  const [draft, setDraft] = useState(\"\");\n  const [isWrapped, setIsWrapped] = useState(false);\n  const containerRef = useRef<HTMLDivElement | null>(null);\n  const listId = useId();\n  const suggestionSet = useMemo(\n    () => new Set((suggestions ?? []).map((value) => value.trim())),\n    [suggestions],\n  );\n\n  useEffect(() => {\n    const element = containerRef.current;\n    if (!element) {\n      return;\n    }\n    const syncWrapped = () => {\n      setIsWrapped(element.clientHeight > 44);\n    };\n    syncWrapped();\n    const observer = new ResizeObserver(syncWrapped);\n    observer.observe(element);\n    return () => observer.disconnect();\n  }, [values.length, draft]);\n\n  function addValue(rawValue: string, allowAny: boolean): void {\n    const trimmed = rawValue.trim();\n    if (!trimmed) {\n      return;\n    }\n    if (!allowAny && !suggestionSet.has(trimmed)) {\n      return;\n    }\n    onAdd(trimmed);\n    setDraft(\"\");\n  }\n\n  const handleKeyDown = (event: KeyboardEvent<HTMLInputElement>) => {\n    if (event.key === \"Enter\") {\n      event.preventDefault();\n      addValue(draft, true);\n    }\n    if (event.key === \"Backspace\" && !draft && values.length > 0) {\n      onRemove(values.length - 1);\n    }\n  };\n\n  function handleChange(nextDraft: string): void {\n    setDraft(nextDraft);\n    if (suggestionSet.has(nextDraft.trim())) {\n      addValue(nextDraft, false);\n    }\n  }\n\n  return (\n    <div\n      ref={containerRef}\n      className={`bg-input/30 border-input focus-within:border-ring focus-within:ring-ring/50 flex min-h-9 flex-wrap items-center gap-1.5 border bg-clip-padding px-1.5 py-1.5 text-sm transition-colors focus-within:ring-[3px] ${isWrapped ? \"corner-squircle rounded-xl\" : \"rounded-4xl\"}`}\n    >\n      {values.map((value, index) => (\n        <span\n          key={`${value}-${index}`}\n          className=\"bg-muted-foreground/10 text-foreground flex h-[calc(--spacing(5.5))] w-fit items-center justify-center gap-1 rounded-4xl pr-0 pl-2 text-xs font-medium whitespace-nowrap\"\n        >\n          {value}\n          <Button\n            type=\"button\"\n            variant=\"ghost\"\n            size=\"icon-xs\"\n            className=\"-ml-1 opacity-50 hover:opacity-100\"\n            onClick={() => onRemove(index)}\n          >\n            <HugeiconsIcon\n              icon={Cancel01Icon}\n              strokeWidth={2}\n              className=\"pointer-events-none\"\n            />\n          </Button>\n        </span>\n      ))}\n      <input\n        className=\"nodrag min-w-16 flex-1 bg-transparent text-sm outline-none placeholder:text-muted-foreground\"\n        placeholder={values.length === 0 ? 
placeholder : \"\"}\n        value={draft}\n        list={suggestions && suggestions.length > 0 ? listId : undefined}\n        onChange={(event) => handleChange(event.target.value)}\n        onBlur={() => addValue(draft, false)}\n        onKeyDown={handleKeyDown}\n      />\n      {suggestions && suggestions.length > 0 && (\n        <datalist id={listId}>\n          {suggestions.map((value) => (\n            <option key={value} value={value} />\n          ))}\n        </datalist>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/controls/layout-controls.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ReactElement, useCallback } from \"react\";\nimport {\n  Panel,\n  useReactFlow,\n  useUpdateNodeInternals,\n} from \"@xyflow/react\";\nimport { Button } from \"@/components/ui/button\";\nimport { getFitNodeIdsIgnoringNotes } from \"../../utils/graph/fit-view\";\n\ntype LayoutControlsProps = {\n  direction: \"LR\" | \"TB\";\n  onLayout: () => void;\n  onToggleDirection: () => void;\n};\n\nexport function LayoutControls({\n  direction,\n  onLayout,\n  onToggleDirection,\n}: LayoutControlsProps): ReactElement {\n  const { fitView, getNodes } = useReactFlow();\n  const updateNodeInternals = useUpdateNodeInternals();\n\n  const refreshNodeInternals = useCallback(() => {\n    const nodeIds = getNodes().map((node) => node.id);\n    if (nodeIds.length > 0) {\n      updateNodeInternals(nodeIds);\n    }\n  }, [getNodes, updateNodeInternals]);\n\n  const handleLayout = useCallback(() => {\n    onLayout();\n    requestAnimationFrame(() => {\n      refreshNodeInternals();\n      requestAnimationFrame(() => {\n        fitView({\n          duration: 250,\n          nodes: getFitNodeIdsIgnoringNotes(getNodes()),\n        });\n      });\n    });\n  }, [fitView, getNodes, onLayout, refreshNodeInternals]);\n\n  const handleToggleDirection = useCallback(() => {\n    onToggleDirection();\n    requestAnimationFrame(() => {\n      onLayout();\n      requestAnimationFrame(() => {\n        refreshNodeInternals();\n        requestAnimationFrame(() => {\n          fitView({\n            duration: 250,\n            nodes: getFitNodeIdsIgnoringNotes(getNodes()),\n          });\n        });\n      });\n    });\n  }, [fitView, getNodes, onLayout, onToggleDirection, refreshNodeInternals]);\n\n  return (\n    <Panel position=\"top-left\" className=\"m-3 flex items-center gap-2\">\n      <Button size=\"sm\" className=\"corner-squircle\" variant=\"secondary\" onClick={handleLayout}>\n        Auto layout\n      </Button>\n      <Button size=\"sm\" className=\"corner-squircle\" variant=\"outline\" onClick={handleToggleDirection}>\n        {direction}\n      </Button>\n    </Panel>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/controls/run-validate-floating-controls.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { CookBookIcon, TestTube01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\nimport { Button } from \"@/components/ui/button\";\nimport type { RecipeExecutionKind } from \"../../execution-types\";\n\ntype RunValidateFloatingControlsProps = {\n  runBusy: boolean;\n  runDialogKind: RecipeExecutionKind;\n  validateLoading: boolean;\n  executionLocked: boolean;\n  onOpenRunDialog: (kind: RecipeExecutionKind) => void;\n  onValidate: () => void;\n};\n\nexport function RunValidateFloatingControls({\n  runBusy,\n  runDialogKind,\n  validateLoading,\n  executionLocked,\n  onOpenRunDialog,\n  onValidate,\n}: RunValidateFloatingControlsProps): ReactElement {\n  return (\n    <div className=\"pointer-events-none absolute inset-x-0 bottom-3 z-20 flex justify-center\">\n      <div className=\"pointer-events-auto flex items-center gap-2\">\n        <Button\n          type=\"button\"\n          className=\"corner-squircle h-11 px-5\"\n          onClick={() => onOpenRunDialog(runDialogKind)}\n          disabled={runBusy}\n        >\n          <HugeiconsIcon icon={CookBookIcon} className=\"size-4\" />\n          {runBusy ? \"Running...\" : \"Run\"}\n        </Button>\n        <Button\n          type=\"button\"\n          variant=\"outline\"\n          className=\"corner-squircle h-11 px-5\"\n          onClick={onValidate}\n          disabled={validateLoading || executionLocked}\n        >\n          <HugeiconsIcon icon={TestTube01Icon} className=\"size-4\" />\n          {validateLoading ? \"Checking...\" : \"Check\"}\n        </Button>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/controls/viewport-controls.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ReactElement, useCallback } from \"react\";\nimport { Lock, LockOpen, Maximize2, Minus, Plus } from \"lucide-react\";\nimport { Panel, useReactFlow } from \"@xyflow/react\";\nimport { Button } from \"@/components/ui/button\";\nimport { getFitNodeIdsIgnoringNotes } from \"../../utils/graph/fit-view\";\nimport { RECIPE_FLOATING_ICON_BUTTON_CLASS } from \"../recipe-floating-icon-button-class\";\n\ntype ViewportControlsProps = {\n  interactive: boolean;\n  lockDisabled?: boolean;\n  onToggleInteractive: () => void;\n};\n\nexport function ViewportControls({\n  interactive,\n  lockDisabled = false,\n  onToggleInteractive,\n}: ViewportControlsProps): ReactElement {\n  const { zoomIn, zoomOut, fitView, getNodes } = useReactFlow();\n\n  const handleZoomIn = useCallback(() => {\n    zoomIn({ duration: 150 });\n  }, [zoomIn]);\n\n  const handleZoomOut = useCallback(() => {\n    zoomOut({ duration: 150 });\n  }, [zoomOut]);\n\n  const handleFitView = useCallback(() => {\n    fitView({\n      duration: 250,\n      nodes: getFitNodeIdsIgnoringNotes(getNodes()),\n    });\n  }, [fitView, getNodes]);\n\n  return (\n    <Panel position=\"bottom-left\" className=\"m-3 flex items-center gap-2\">\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        onClick={handleZoomIn}\n        aria-label=\"Zoom in\"\n      >\n        <Plus className=\"size-4\" />\n      </Button>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        onClick={handleZoomOut}\n        aria-label=\"Zoom out\"\n      >\n        <Minus className=\"size-4\" />\n      </Button>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        onClick={handleFitView}\n        aria-label=\"Fit view\"\n      >\n        <Maximize2 className=\"size-4\" />\n      </Button>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon\"\n        className={RECIPE_FLOATING_ICON_BUTTON_CLASS}\n        disabled={lockDisabled}\n        onClick={onToggleInteractive}\n        aria-label={interactive ? \"Lock interaction\" : \"Unlock interaction\"}\n      >\n        {interactive ? <LockOpen className=\"size-4\" /> : <Lock className=\"size-4\" />}\n      </Button>\n    </Panel>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/execution-columns-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport {\n  Table,\n  TableBody,\n  TableCell,\n  TableHead,\n  TableHeader,\n  TableRow,\n} from \"@/components/ui/table\";\nimport type { AnalysisColumnStat } from \"./executions-view-helpers\";\n\ntype ExecutionColumnsTabProps = {\n  analysisColumns: AnalysisColumnStat[];\n};\n\nexport function ExecutionColumnsTab({\n  analysisColumns,\n}: ExecutionColumnsTabProps): ReactElement {\n  return (\n    <div className=\"mt-3 rounded-xl border p-3\">\n      <p className=\"mb-2 text-sm font-semibold\">Column statistics</p>\n      {analysisColumns.length === 0 ? (\n        <p className=\"text-xs text-muted-foreground\">No column statistics yet.</p>\n      ) : (\n        <Table>\n          <TableHeader>\n            <TableRow>\n              <TableHead>Column</TableHead>\n              <TableHead>Type</TableHead>\n              <TableHead>Data type</TableHead>\n              <TableHead>Unique</TableHead>\n              <TableHead>Nulls</TableHead>\n              <TableHead>Input tok avg</TableHead>\n              <TableHead>Output tok avg</TableHead>\n            </TableRow>\n          </TableHeader>\n          <TableBody>\n            {analysisColumns.map((column) => (\n              <TableRow key={column.column_name}>\n                <TableCell>{column.column_name}</TableCell>\n                <TableCell>{column.column_type}</TableCell>\n                <TableCell>{column.simple_dtype}</TableCell>\n                <TableCell>{column.num_unique ?? \"--\"}</TableCell>\n                <TableCell>{column.num_null ?? \"--\"}</TableCell>\n                <TableCell>{column.input_tokens_mean ?? \"--\"}</TableCell>\n                <TableCell>{column.output_tokens_mean ?? \"--\"}</TableCell>\n              </TableRow>\n            ))}\n          </TableBody>\n        </Table>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/execution-data-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport type { ColumnDef } from \"@tanstack/react-table\";\nimport { Button } from \"@/components/ui/button\";\nimport { DataTable } from \"@/components/ui/data-table\";\nimport {\n  DropdownMenu,\n  DropdownMenuCheckboxItem,\n  DropdownMenuContent,\n  DropdownMenuLabel,\n  DropdownMenuTrigger,\n} from \"@/components/ui/dropdown-menu\";\nimport { cn } from \"@/lib/utils\";\nimport { isExecutionInProgress } from \"../../executions/execution-helpers\";\nimport type { RecipeExecutionRecord } from \"../../execution-types\";\nimport { hasExpandableTextCell } from \"./executions-view-helpers\";\n\ntype ExecutionDataTabProps = {\n  execution: RecipeExecutionRecord;\n  datasetColumnNames: string[];\n  hiddenDatasetColumns: string[];\n  canPageDataset: boolean;\n  currentDatasetPage: number;\n  totalPages: number;\n  tableColumns: ColumnDef<Record<string, unknown>>[];\n  datasetRowsForTable: Record<string, unknown>[];\n  visibleDatasetColumnNames: string[];\n  expandedDatasetRows: Record<string, boolean>;\n  selectedExecutionIdSafe: string | null;\n  onSetHiddenColumns: (updater: (current: string[]) => string[]) => void;\n  onPrevPage: () => void;\n  onNextPage: () => void;\n  onToggleRowExpanded: (rowId: string) => void;\n};\n\nexport function ExecutionDataTab({\n  execution,\n  datasetColumnNames,\n  hiddenDatasetColumns,\n  canPageDataset,\n  currentDatasetPage,\n  totalPages,\n  tableColumns,\n  datasetRowsForTable,\n  visibleDatasetColumnNames,\n  expandedDatasetRows,\n  selectedExecutionIdSafe,\n  onSetHiddenColumns,\n  onPrevPage,\n  onNextPage,\n  onToggleRowExpanded,\n}: ExecutionDataTabProps): ReactElement {\n  return (\n    <div className=\"mt-3\">\n      <div className=\"mb-2 flex flex-wrap items-center justify-between gap-2\">\n        <p className=\"text-sm font-semibold\">Dataset sample</p>\n        <div className=\"flex items-center gap-2 text-xs text-muted-foreground\">\n          {datasetColumnNames.length > 0 && (\n            <DropdownMenu>\n              <DropdownMenuTrigger asChild>\n                <Button type=\"button\" size=\"sm\" variant=\"outline\">\n                  Columns\n                </Button>\n              </DropdownMenuTrigger>\n              <DropdownMenuContent align=\"end\">\n                <DropdownMenuLabel>Visible columns</DropdownMenuLabel>\n                {datasetColumnNames.map((columnName) => (\n                  <DropdownMenuCheckboxItem\n                    key={columnName}\n                    checked={!hiddenDatasetColumns.includes(columnName)}\n                    onSelect={(event) => {\n                      event.preventDefault();\n                    }}\n                    onCheckedChange={(checked) => {\n                      onSetHiddenColumns((currentColumns) => {\n                        if (checked) {\n                          return currentColumns.filter((name) => name !== columnName);\n                        }\n                        return [...currentColumns, columnName];\n                      });\n                    }}\n                  >\n                    {columnName}\n                  </DropdownMenuCheckboxItem>\n                ))}\n              </DropdownMenuContent>\n            </DropdownMenu>\n          )}\n          {canPageDataset && (\n            <>\n              <span>\n                Page 
{currentDatasetPage}/{totalPages}\n              </span>\n              <Button\n                type=\"button\"\n                size=\"sm\"\n                variant=\"outline\"\n                disabled={\n                  isExecutionInProgress(execution.status) || currentDatasetPage <= 1\n                }\n                onClick={onPrevPage}\n              >\n                Prev\n              </Button>\n              <Button\n                type=\"button\"\n                size=\"sm\"\n                variant=\"outline\"\n                disabled={\n                  isExecutionInProgress(execution.status) ||\n                  currentDatasetPage >= totalPages\n                }\n                onClick={onNextPage}\n              >\n                Next\n              </Button>\n            </>\n          )}\n        </div>\n      </div>\n      {execution.dataset.length === 0 ? (\n        <p className=\"text-xs text-muted-foreground\">No rows returned.</p>\n      ) : tableColumns.length === 0 ? (\n        <p className=\"text-xs text-muted-foreground\">\n          All columns hidden. Use Columns to show at least one.\n        </p>\n      ) : (\n        <div className=\"max-h-[55vh] overflow-auto\">\n          <DataTable\n            columns={tableColumns}\n            data={datasetRowsForTable}\n            getRowClassName={(row, _rowIndex, rowId) => {\n              const canExpand = hasExpandableTextCell(row, visibleDatasetColumnNames);\n              if (!canExpand) {\n                return undefined;\n              }\n              return cn(\n                \"cursor-pointer\",\n                expandedDatasetRows[rowId] ? \"bg-primary/[0.05]\" : \"hover:bg-primary/[0.06]\",\n              );\n            }}\n            onRowClick={(row, _rowIndex, rowId) => {\n              const canExpand = hasExpandableTextCell(row, visibleDatasetColumnNames);\n              if (!canExpand || !selectedExecutionIdSafe) {\n                return;\n              }\n              onToggleRowExpanded(rowId);\n            }}\n          />\n        </div>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/execution-overview-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement, RefObject, UIEvent } from \"react\";\nimport {\n  Database01Icon,\n  Database02Icon,\n  Flag02Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Table,\n  TableBody,\n  TableCell,\n  TableHead,\n  TableHeader,\n  TableRow,\n} from \"@/components/ui/table\";\nimport { isExecutionInProgress } from \"../../executions/execution-helpers\";\nimport type { RecipeExecutionRecord } from \"../../execution-types\";\nimport type { ModelUsageRow } from \"./executions-view-helpers\";\nimport { formatMetricValue } from \"./executions-view-helpers\";\n\ntype ExecutionOverviewTabProps = {\n  execution: RecipeExecutionRecord;\n  showSummaryCards: boolean;\n  recordsMetric: number | null;\n  totalMetric: number | null;\n  runDuration: string;\n  columnCount: number;\n  llmColumnCount: number;\n  nullRate: number | null;\n  sideEffects: string[];\n  lowUniquenessColumns: string[];\n  modelUsageRows: ModelUsageRow[];\n  terminalLines: string[];\n  terminalRef: RefObject<HTMLDivElement | null>;\n  onTerminalScroll: (event: UIEvent<HTMLDivElement>) => void;\n  canPublish: boolean;\n  onOpenPublish: () => void;\n};\n\nexport function ExecutionOverviewTab({\n  execution,\n  showSummaryCards,\n  recordsMetric,\n  totalMetric,\n  runDuration,\n  columnCount,\n  llmColumnCount,\n  nullRate,\n  sideEffects,\n  lowUniquenessColumns,\n  modelUsageRows,\n  terminalLines,\n  terminalRef,\n  onTerminalScroll,\n  canPublish,\n  onOpenPublish,\n}: ExecutionOverviewTabProps): ReactElement {\n  return (\n    <div className=\"mt-3 space-y-3\">\n      {showSummaryCards && (\n        <div className=\"space-y-3\">\n          {canPublish && (\n            <div className=\"flex flex-col gap-3 rounded-xl border border-border/60 bg-card/55 p-3 sm:flex-row sm:items-center sm:justify-between\">\n              <div className=\"space-y-1\">\n                <p className=\"text-sm font-medium text-foreground\">Next step</p>\n                <p className=\"text-xs text-muted-foreground\">\n                  This run is complete. 
Publish the generated dataset to Hugging Face.\n                </p>\n              </div>\n              <Button type=\"button\" variant=\"outline\" size=\"sm\" onClick={onOpenPublish}>\n                Publish to Hugging Face\n              </Button>\n            </div>\n          )}\n          <div className=\"grid gap-3 md:grid-cols-2\">\n            <div className=\"h-full rounded-xl border border-border/60 bg-card/55 p-3\">\n              <div className=\"mb-2 flex items-center justify-between\">\n                <p className=\"text-xs text-muted-foreground\">Run summary</p>\n                <HugeiconsIcon\n                  icon={Database01Icon}\n                  className=\"size-4 text-muted-foreground\"\n                />\n              </div>\n              <div className=\"space-y-1.5 text-xs\">\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Records</span>\n                  <span className=\"font-semibold\">\n                    {formatMetricValue(recordsMetric)} / {formatMetricValue(totalMetric)}\n                  </span>\n                </p>\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Duration</span>\n                  <span className=\"font-semibold\">{runDuration}</span>\n                </p>\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Columns analyzed</span>\n                  <span className=\"font-semibold\">{formatMetricValue(columnCount)}</span>\n                </p>\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Final stage</span>\n                  <span className=\"truncate font-semibold\">{execution.stage ?? \"--\"}</span>\n                </p>\n              </div>\n            </div>\n            <div className=\"h-full rounded-xl border border-border/60 bg-card/55 p-3\">\n              <div className=\"mb-2 flex items-center justify-between\">\n                <p className=\"text-xs text-muted-foreground\">Insights</p>\n                <HugeiconsIcon\n                  icon={Database02Icon}\n                  className=\"size-4 text-muted-foreground\"\n                />\n              </div>\n              <div className=\"space-y-1.5 text-xs\">\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">LLM columns</span>\n                  <span className=\"font-semibold\">{formatMetricValue(llmColumnCount)}</span>\n                </p>\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Null rate</span>\n                  <span className=\"font-semibold\">{nullRate?.toFixed(1) ?? 
\"--\"}%</span>\n                </p>\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Side-effect columns</span>\n                  <span className=\"font-semibold\">{formatMetricValue(sideEffects.length)}</span>\n                </p>\n                {sideEffects.length > 0 && (\n                  <div className=\"pt-0.5\">\n                    <div className=\"flex flex-wrap gap-1.5\">\n                      {sideEffects.map((name) => (\n                        <Badge key={name} variant=\"outline\">\n                          {name}\n                        </Badge>\n                      ))}\n                    </div>\n                  </div>\n                )}\n                <p className=\"flex items-center justify-between gap-3\">\n                  <span className=\"text-muted-foreground\">Low uniqueness flags</span>\n                  <span className=\"font-semibold\">\n                    {formatMetricValue(lowUniquenessColumns.length)}\n                  </span>\n                </p>\n                {lowUniquenessColumns.length > 0 && (\n                  <div className=\"pt-0.5\">\n                    <div className=\"flex flex-wrap gap-1.5\">\n                      {lowUniquenessColumns.slice(0, 3).map((name) => (\n                        <Badge key={name} variant=\"outline\">\n                          {name}\n                        </Badge>\n                      ))}\n                      {lowUniquenessColumns.length > 3 && (\n                        <Badge variant=\"outline\">\n                          +{lowUniquenessColumns.length - 3} more\n                        </Badge>\n                      )}\n                    </div>\n                  </div>\n                )}\n              </div>\n            </div>\n          </div>\n          <div className=\"rounded-xl border border-border/60 bg-card/55 p-3\">\n            <div className=\"mb-2 flex items-center justify-between\">\n              <p className=\"text-xs text-muted-foreground\">Model usage</p>\n              <HugeiconsIcon icon={Flag02Icon} className=\"size-4 text-muted-foreground\" />\n            </div>\n            {modelUsageRows.length === 0 ? 
(\n              <p className=\"text-xs text-muted-foreground\">No model usage yet.</p>\n            ) : (\n              <div className=\"overflow-hidden rounded-lg border border-border/60 bg-card/50\">\n                <Table>\n                  <TableHeader>\n                    <TableRow>\n                      <TableHead>Model</TableHead>\n                      <TableHead className=\"text-right\">Input</TableHead>\n                      <TableHead className=\"text-right\">Output</TableHead>\n                    </TableRow>\n                  </TableHeader>\n                  <TableBody>\n                    {modelUsageRows.map((usage) => (\n                      <TableRow key={usage.model}>\n                        <TableCell className=\"max-w-[320px] truncate\">{usage.model}</TableCell>\n                        <TableCell className=\"text-right\">\n                          {formatMetricValue(usage.input)}\n                        </TableCell>\n                        <TableCell className=\"text-right\">\n                          {formatMetricValue(usage.output)}\n                        </TableCell>\n                      </TableRow>\n                    ))}\n                  </TableBody>\n                </Table>\n              </div>\n            )}\n          </div>\n        </div>\n      )}\n      <div className=\"overflow-hidden rounded-xl corner-squircle border\">\n        <div className=\"flex items-center justify-between border-b px-3 py-2\">\n          <p className=\"text-sm font-semibold\">Terminal output</p>\n          <p className=\"text-xs text-muted-foreground\">{terminalLines.length} lines</p>\n        </div>\n        <div\n          ref={terminalRef}\n          className=\"max-h-72 overflow-auto bg-zinc-900/80 px-3 py-2 font-mono text-xs text-zinc-200\"\n          onScroll={onTerminalScroll}\n        >\n          {terminalLines.length === 0 ? (\n            <p className=\"text-zinc-400\">\n              {isExecutionInProgress(execution.status)\n                ? \"Waiting for logs...\"\n                : \"No logs captured.\"}\n            </p>\n          ) : (\n            terminalLines.map((line, index) => (\n              <p\n                key={`${index}-${line.slice(0, 24)}`}\n                className=\"whitespace-pre-wrap break-words leading-relaxed\"\n              >\n                {line}\n              </p>\n            ))\n          )}\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/execution-raw-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\n\ntype ExecutionRawTabProps = {\n  rawExecution: Record<string, unknown> | null;\n};\n\nexport function ExecutionRawTab({\n  rawExecution,\n}: ExecutionRawTabProps): ReactElement {\n  return (\n    <div className=\"mt-3 rounded-xl border p-3\">\n      <p className=\"mb-2 text-sm font-semibold\">Raw execution</p>\n      <pre className=\"max-h-96 overflow-auto rounded-md bg-muted/40 p-3 text-xs\">\n        {JSON.stringify(rawExecution, null, 2)}\n      </pre>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/execution-sidebar.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { cn } from \"@/lib/utils\";\nimport type { RecipeExecutionRecord } from \"../../execution-types\";\nimport {\n  executionLabel,\n  isExecutionInProgress,\n  normalizeRunName,\n} from \"../../executions/execution-helpers\";\nimport {\n  formatStatus,\n  formatTimestamp,\n  statusRightBorder,\n  statusTone,\n} from \"./executions-view-helpers\";\n\ntype ExecutionSidebarProps = {\n  executions: RecipeExecutionRecord[];\n  selectedExecutionId: string | null;\n  onSelectExecution: (id: string) => void;\n};\n\nexport function ExecutionSidebar({\n  executions,\n  selectedExecutionId,\n  onSelectExecution,\n}: ExecutionSidebarProps): ReactElement {\n  return (\n    <aside className=\"w-72 shrink-0 border-r border-border/60 bg-card/20\">\n      <div className=\"flex items-center justify-between  border-border/60 px-3 py-2\">\n        <p className=\"text-xs font-semibold uppercase text-muted-foreground\">\n          Runs\n        </p>\n      </div>\n      <div className=\"h-[calc(100%-45px)] space-y-2 overflow-auto p-2\">\n        {executions.length === 0 ? (\n          <div className=\"rounded-xl border border-dashed border-border/60 p-3 text-xs text-muted-foreground\">\n            No runs yet.\n          </div>\n        ) : (\n          executions.map((execution) => {\n            const title =\n              execution.kind === \"full\"\n                ? (normalizeRunName(execution.run_name) ??\n                  executionLabel(execution.kind))\n                : executionLabel(execution.kind);\n            return (\n              <button\n                key={execution.id}\n                type=\"button\"\n                onClick={() => onSelectExecution(execution.id)}\n                className={cn(\n                  \"w-full rounded-xl corner-squircle border border-r-2 border-border/60 bg-card/60 p-3 text-left transition-colors\",\n                  selectedExecutionId === execution.id\n                    ? \"border-primary/35 bg-primary/[0.045]\"\n                    : \"hover:bg-muted/25\",\n                  statusRightBorder(execution.status),\n                )}\n              >\n                <div className=\"mb-2 flex items-center justify-between gap-2\">\n                  <p className=\"truncate text-sm font-medium\">\n                    {title}\n                  </p>\n                  <Badge\n                    variant=\"outline\"\n                    className={cn(\"capitalize text-[11px]\", statusTone(execution.status))}\n                  >\n                    {formatStatus(execution.status)}\n                  </Badge>\n                </div>\n                <p className=\"text-xs text-muted-foreground\">{execution.rows} rows</p>\n                {isExecutionInProgress(execution.status) &&\n                  typeof execution.batch?.total === \"number\" &&\n                  execution.batch.total > 1 && (\n                    <p className=\"text-xs text-muted-foreground\">\n                      Batch {execution.batch.idx ?? 
\"--\"}/{execution.batch.total}\n                    </p>\n                  )}\n                <p className=\"text-xs text-muted-foreground\">\n                  {formatTimestamp(execution.createdAt)}\n                </p>\n              </button>\n            );\n          })\n        )}\n      </div>\n    </aside>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/executions-view-helpers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  RecipeExecutionAnalysis,\n  RecipeExecutionStatus,\n} from \"../../execution-types\";\nimport { isExecutionInProgress } from \"../../executions/execution-helpers\";\nimport { resolveImagePreview } from \"../../utils/image-preview\";\n\nexport type AnalysisColumnStat = {\n  column_name: string;\n  column_type: string;\n  simple_dtype: string;\n  num_unique: number | null;\n  num_null: number | null;\n  input_tokens_mean: number | null;\n  output_tokens_mean: number | null;\n};\n\nexport type ModelUsageRow = {\n  model: string;\n  input: number | null;\n  output: number | null;\n};\n\nexport const PREVIEW_DATASET_PAGE_SIZE = 20;\nexport const TERMINAL_STICKY_BOTTOM_THRESHOLD_PX = 24;\n\nexport function formatTimestamp(value: number): string {\n  return new Date(value).toLocaleString();\n}\n\nexport function formatCellValue(value: unknown): string {\n  if (value === null || value === undefined) {\n    return \"--\";\n  }\n  if (typeof value === \"string\") {\n    return value;\n  }\n  if (typeof value === \"number\" || typeof value === \"boolean\") {\n    return String(value);\n  }\n  try {\n    return JSON.stringify(value);\n  } catch {\n    return String(value);\n  }\n}\n\nexport function isExpandableCellValue(value: string): boolean {\n  return value.length > 180;\n}\n\nexport function truncateCellValue(value: string): string {\n  if (value.length <= 180) {\n    return value;\n  }\n  return `${value.slice(0, 180).trimEnd()}...`;\n}\n\nexport function hasExpandableTextCell(\n  row: Record<string, unknown>,\n  visibleColumnNames: string[],\n): boolean {\n  return visibleColumnNames.some((columnName) => {\n    if (resolveImagePreview(row[columnName])) {\n      return false;\n    }\n    return isExpandableCellValue(formatCellValue(row[columnName]));\n  });\n}\n\nfunction parseNumber(value: unknown): number | null {\n  return typeof value === \"number\" && Number.isFinite(value) ? value : null;\n}\n\nfunction parseString(value: unknown): string {\n  return typeof value === \"string\" && value.length > 0 ? value : \"--\";\n}\n\nexport function parseAnalysisColumns(\n  analysis: RecipeExecutionAnalysis | null,\n): AnalysisColumnStat[] {\n  const items = Array.isArray(analysis?.column_statistics)\n    ? 
analysis.column_statistics\n    : [];\n  return items\n    .map((item) => {\n      if (!item || typeof item !== \"object\" || Array.isArray(item)) {\n        return null;\n      }\n      const row = item as Record<string, unknown>;\n      return {\n        column_name: parseString(row.column_name),\n        column_type: parseString(row.column_type),\n        simple_dtype: parseString(row.simple_dtype),\n        num_unique: parseNumber(row.num_unique),\n        num_null: parseNumber(row.num_null),\n        input_tokens_mean: parseNumber(row.input_tokens_mean),\n        output_tokens_mean: parseNumber(row.output_tokens_mean),\n      };\n    })\n    .filter((item): item is AnalysisColumnStat => item !== null);\n}\n\nexport function statusTone(status: RecipeExecutionStatus): string {\n  if (status === \"completed\") {\n    return \"border-emerald-500/30 text-emerald-700 dark:text-emerald-300\";\n  }\n  if (status === \"error\" || status === \"cancelled\") {\n    return \"border-red-500/30 text-red-700 dark:text-red-300\";\n  }\n  if (isExecutionInProgress(status)) {\n    return \"border-amber-500/30 text-amber-700 dark:text-amber-300\";\n  }\n  return \"border-border/60 text-muted-foreground\";\n}\n\nexport function statusRightBorder(status: RecipeExecutionStatus): string {\n  if (status === \"completed\") {\n    return \"border-r-emerald-500/40\";\n  }\n  if (status === \"error\" || status === \"cancelled\") {\n    return \"border-r-red-500/40\";\n  }\n  if (isExecutionInProgress(status)) {\n    return \"border-r-amber-500/40\";\n  }\n  return \"border-r-border/50\";\n}\n\nexport function formatStatus(status: RecipeExecutionStatus): string {\n  if (status === \"cancelled\") {\n    return \"cancelled\";\n  }\n  return status;\n}\n\nexport function formatPercent(value: number | null | undefined): string {\n  if (typeof value !== \"number\" || Number.isNaN(value)) {\n    return \"--\";\n  }\n  return `${value.toFixed(1)}%`;\n}\n\nexport function formatDuration(startedAt: number, finishedAt: number | null): string {\n  if (!finishedAt || finishedAt <= startedAt) {\n    return \"--\";\n  }\n  const seconds = Math.round((finishedAt - startedAt) / 1000);\n  return `${seconds}s`;\n}\n\nexport function formatMetricValue(value: number | null | undefined): string {\n  if (typeof value !== \"number\" || Number.isNaN(value)) {\n    return \"--\";\n  }\n  return value.toLocaleString();\n}\n\nexport function parseModelUsageRows(\n  value: Record<string, unknown> | null,\n): ModelUsageRow[] {\n  if (!value) {\n    return [];\n  }\n  return Object.entries(value)\n    .map(([name, data]) => {\n      if (!data || typeof data !== \"object\" || Array.isArray(data)) {\n        return null;\n      }\n      const modelObj = data as Record<string, unknown>;\n      const tokens =\n        modelObj.tokens &&\n        typeof modelObj.tokens === \"object\" &&\n        !Array.isArray(modelObj.tokens)\n          ? (modelObj.tokens as Record<string, unknown>)\n          : null;\n      const modelName =\n        typeof modelObj.model === \"string\" && modelObj.model.length > 0\n          ? modelObj.model\n          : name;\n      return {\n        model: modelName,\n        input: parseNumber(tokens?.input),\n        output: parseNumber(tokens?.output),\n      };\n    })\n    .filter((item): item is ModelUsageRow => item !== null);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/executions-view.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useMemo, useRef, useState, type ReactElement } from \"react\";\nimport type { ColumnDef } from \"@tanstack/react-table\";\nimport {\n  CheckmarkCircle02Icon,\n  Flag02Icon,\n  Share08Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { publishRecipeJob } from \"../../api\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport { Progress } from \"@/components/ui/progress\";\nimport { Tabs, TabsContent, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport { cn } from \"@/lib/utils\";\nimport { resolveImagePreview } from \"../../utils/image-preview\";\nimport type {\n  RecipeExecutionRecord,\n} from \"../../execution-types\";\nimport { isExecutionInProgress } from \"../../executions/execution-helpers\";\nimport { ExecutionColumnsTab } from \"./execution-columns-tab\";\nimport { ExecutionDataTab } from \"./execution-data-tab\";\nimport { ExecutionOverviewTab } from \"./execution-overview-tab\";\nimport { ExecutionRawTab } from \"./execution-raw-tab\";\nimport { ExecutionSidebar } from \"./execution-sidebar\";\nimport { PublishExecutionDialog } from \"./publish-execution-dialog\";\nimport {\n  PREVIEW_DATASET_PAGE_SIZE,\n  TERMINAL_STICKY_BOTTOM_THRESHOLD_PX,\n  formatCellValue,\n  formatDuration,\n  formatPercent,\n  hasExpandableTextCell,\n  parseAnalysisColumns,\n  parseModelUsageRows,\n  truncateCellValue,\n} from \"./executions-view-helpers\";\n\ntype ExecutionsViewProps = {\n  executions: RecipeExecutionRecord[];\n  selectedExecutionId: string | null;\n  currentSignature: string;\n  onSelectExecution: (id: string) => void;\n  onCancelExecution: (id: string) => void;\n  onLoadDatasetPage: (id: string, page: number) => void;\n};\n\nexport function ExecutionsView({\n  executions,\n  selectedExecutionId,\n  currentSignature,\n  onSelectExecution,\n  onCancelExecution,\n  onLoadDatasetPage,\n}: ExecutionsViewProps): ReactElement {\n  const formatEta = (value: number | null | undefined): string =>\n    typeof value === \"number\" && Number.isFinite(value)\n      ? `${value.toLocaleString()} s`\n      : \"--\";\n  const [detailTab, setDetailTab] = useState(\"overview\");\n  const [hiddenDatasetColumnsByExecution, setHiddenDatasetColumnsByExecution] = useState<\n    Record<string, string[]>\n  >({});\n  const [expandedDatasetRowsByExecution, setExpandedDatasetRowsByExecution] = useState<\n    Record<string, Record<string, boolean>>\n  >({});\n  const [previewDatasetPageByExecution, setPreviewDatasetPageByExecution] = useState<\n    Record<string, number>\n  >({});\n  const [publishDialogOpen, setPublishDialogOpen] = useState(false);\n  const terminalRef = useRef<HTMLDivElement | null>(null);\n  const shouldStickTerminalToBottomRef = useRef(true);\n  const selectedExecution = useMemo(\n    () =>\n      executions.find((execution) => execution.id === selectedExecutionId) ??\n      null,\n    [executions, selectedExecutionId],\n  );\n  const isStale = Boolean(\n    selectedExecution &&\n      selectedExecution.recipeSignature.length > 0 &&\n      selectedExecution.recipeSignature !== currentSignature,\n  );\n\n  const selectedExecutionIdSafe = selectedExecution?.id ?? 
null;\n  const hiddenDatasetColumns = useMemo(() => {\n    if (!selectedExecutionIdSafe) {\n      return [];\n    }\n    return hiddenDatasetColumnsByExecution[selectedExecutionIdSafe] ?? [];\n  }, [hiddenDatasetColumnsByExecution, selectedExecutionIdSafe]);\n  const expandedDatasetRows = useMemo(() => {\n    if (!selectedExecutionIdSafe) {\n      return {};\n    }\n    return expandedDatasetRowsByExecution[selectedExecutionIdSafe] ?? {};\n  }, [expandedDatasetRowsByExecution, selectedExecutionIdSafe]);\n\n  const datasetColumnNames = useMemo(() => {\n    if (!selectedExecution) {\n      return [];\n    }\n    const names = new Set<string>();\n    for (const row of selectedExecution.dataset) {\n      for (const key of Object.keys(row)) {\n        names.add(key);\n      }\n    }\n    return Array.from(names);\n  }, [selectedExecution]);\n\n  const visibleDatasetColumnNames = useMemo(\n    () =>\n      datasetColumnNames.filter(\n        (name) => !hiddenDatasetColumns.includes(name),\n      ),\n    [datasetColumnNames, hiddenDatasetColumns],\n  );\n\n  const tableColumns = useMemo<ColumnDef<Record<string, unknown>>[]>(() => {\n    if (!selectedExecution) {\n      return [];\n    }\n    return visibleDatasetColumnNames.map((name) => ({\n      accessorKey: name,\n      header: name,\n      cell: ({ getValue, row }) => {\n        const rawValue = getValue();\n        const imagePreview = resolveImagePreview(rawValue);\n        if (imagePreview?.kind === \"ready\") {\n          return (\n            <div className=\"max-w-[32rem]\">\n              <img\n                src={imagePreview.src}\n                alt={`${name} preview`}\n                loading=\"lazy\"\n                className=\"h-24 w-auto max-w-[260px] rounded-md border border-border/60 bg-muted/20 object-contain\"\n              />\n            </div>\n          );\n        }\n        if (imagePreview?.kind === \"too_large\") {\n          return (\n            <div className=\"max-w-[32rem]\">\n              <p className=\"text-xs text-muted-foreground\">\n                Image too large to preview\n              </p>\n            </div>\n          );\n        }\n        const value = formatCellValue(rawValue);\n        const rowExpanded = Boolean(expandedDatasetRows[row.id]);\n        const rowHasExpandableCell = hasExpandableTextCell(\n          row.original,\n          visibleDatasetColumnNames,\n        );\n        const showTruncated = rowHasExpandableCell && !rowExpanded;\n\n        return (\n          <div className=\"max-w-[32rem]\">\n            <p className=\"whitespace-pre-wrap break-all\">\n              {showTruncated ? truncateCellValue(value) : value}\n            </p>\n          </div>\n        );\n      },\n    }));\n  }, [expandedDatasetRows, selectedExecution, visibleDatasetColumnNames]);\n\n  const analysisColumns = useMemo(\n    () => parseAnalysisColumns(selectedExecution?.analysis ?? null),\n    [selectedExecution?.analysis],\n  );\n  const modelUsageRows = useMemo(\n    () => parseModelUsageRows(selectedExecution?.model_usage ?? null),\n    [selectedExecution?.model_usage],\n  );\n  const sideEffects = useMemo(() => {\n    const values = selectedExecution?.analysis?.side_effect_column_names;\n    return Array.isArray(values)\n      ? 
values.filter((value): value is string => typeof value === \"string\")\n      : [];\n  }, [selectedExecution?.analysis?.side_effect_column_names]);\n\n  const canCancel = Boolean(\n    selectedExecution?.jobId && isExecutionInProgress(selectedExecution.status),\n  );\n  const canPublish = Boolean(\n    selectedExecution &&\n      selectedExecution.kind === \"full\" &&\n      selectedExecution.status === \"completed\" &&\n      selectedExecution.jobId &&\n      selectedExecution.artifact_path,\n  );\n  const datasetPage = selectedExecution?.datasetPage ?? 1;\n  const datasetPageSize = selectedExecution?.datasetPageSize ?? 20;\n  const datasetTotal = selectedExecution?.datasetTotal ?? 0;\n  const previewPageRaw = selectedExecutionIdSafe\n    ? previewDatasetPageByExecution[selectedExecutionIdSafe] ?? 1\n    : 1;\n  const previewTotalPages = useMemo(() => {\n    if (!selectedExecution || selectedExecution.kind !== \"preview\") {\n      return 1;\n    }\n    return Math.max(\n      1,\n      Math.ceil(selectedExecution.dataset.length / PREVIEW_DATASET_PAGE_SIZE),\n    );\n  }, [selectedExecution]);\n  const previewPage = Math.min(previewPageRaw, previewTotalPages);\n  const totalPages =\n    selectedExecution?.kind === \"preview\"\n      ? previewTotalPages\n      : Math.max(1, Math.ceil(datasetTotal / datasetPageSize));\n  const canPageDataset =\n    selectedExecution?.kind === \"preview\" ||\n    (selectedExecution?.kind === \"full\" && Boolean(selectedExecution.jobId));\n  const datasetRowsForTable = useMemo(() => {\n    if (!selectedExecution) {\n      return [];\n    }\n    if (selectedExecution.kind !== \"preview\") {\n      return selectedExecution.dataset;\n    }\n    const start = (previewPage - 1) * PREVIEW_DATASET_PAGE_SIZE;\n    return selectedExecution.dataset.slice(start, start + PREVIEW_DATASET_PAGE_SIZE);\n  }, [previewPage, selectedExecution]);\n  const currentDatasetPage = selectedExecution?.kind === \"preview\" ? previewPage : datasetPage;\n  const recordsMetric = useMemo(() => {\n    if (!selectedExecution || selectedExecution.status !== \"completed\") {\n      return null;\n    }\n    if (typeof selectedExecution.analysis?.num_records === \"number\") {\n      return selectedExecution.analysis.num_records;\n    }\n    if (selectedExecution.datasetTotal > 0) {\n      return selectedExecution.datasetTotal;\n    }\n    if (selectedExecution.dataset.length > 0) {\n      return selectedExecution.dataset.length;\n    }\n    return null;\n  }, [selectedExecution]);\n  const totalMetric = useMemo(() => {\n    if (!selectedExecution || selectedExecution.status !== \"completed\") {\n      return null;\n    }\n    if (typeof selectedExecution.analysis?.target_num_records === \"number\") {\n      return selectedExecution.analysis.target_num_records;\n    }\n    return selectedExecution.rows > 0 ? selectedExecution.rows : null;\n  }, [selectedExecution]);\n  const columnCount = analysisColumns.length;\n  const llmColumnCount = useMemo(\n    () =>\n      analysisColumns.reduce(\n        (acc, column) => (column.column_type.startsWith(\"llm\") ? acc + 1 : acc),\n        0,\n      ),\n    [analysisColumns],\n  );\n  const totalNulls = useMemo(\n    () =>\n      analysisColumns.reduce(\n        (acc, column) => acc + (typeof column.num_null === \"number\" ? 
column.num_null : 0),\n        0,\n      ),\n    [analysisColumns],\n  );\n  const nullRate = useMemo(() => {\n    if (\n      typeof recordsMetric !== \"number\" ||\n      recordsMetric <= 0 ||\n      columnCount <= 0\n    ) {\n      return null;\n    }\n    return (totalNulls / (recordsMetric * columnCount)) * 100;\n  }, [columnCount, recordsMetric, totalNulls]);\n  const lowUniquenessColumns = useMemo(() => {\n    if (typeof recordsMetric !== \"number\" || recordsMetric <= 0) {\n      return [];\n    }\n    return analysisColumns\n      .filter(\n        (column) =>\n          typeof column.num_unique === \"number\" &&\n          column.num_unique / recordsMetric < 0.5,\n      )\n      .map((column) => column.column_name);\n  }, [analysisColumns, recordsMetric]);\n  const runDuration = useMemo(() => {\n    if (!selectedExecution) {\n      return \"--\";\n    }\n    return formatDuration(selectedExecution.createdAt, selectedExecution.finishedAt);\n  }, [selectedExecution]);\n  const showSummaryCards = selectedExecution?.status === \"completed\";\n  const hasProgressSnapshot = Boolean(\n    selectedExecution?.progress &&\n      (typeof selectedExecution.progress.done === \"number\" ||\n        typeof selectedExecution.progress.total === \"number\" ||\n        typeof selectedExecution.progress.percent === \"number\" ||\n        typeof selectedExecution.progress.rate === \"number\" ||\n        typeof selectedExecution.progress.eta_sec === \"number\"),\n  ) || Boolean(\n    selectedExecution?.column_progress &&\n      (typeof selectedExecution.column_progress.done === \"number\" ||\n        typeof selectedExecution.column_progress.total === \"number\" ||\n        typeof selectedExecution.column_progress.percent === \"number\"),\n  ) || Boolean(\n    selectedExecution?.batch &&\n      (typeof selectedExecution.batch.idx === \"number\" ||\n        typeof selectedExecution.batch.total === \"number\"),\n  );\n  const selectedStatus = selectedExecution?.status ?? null;\n  const isSelectedExecutionInProgress = selectedStatus\n    ? isExecutionInProgress(selectedStatus)\n    : false;\n  const showProgressPanel = Boolean(selectedExecution) && (\n    selectedStatus === \"completed\" ||\n    isSelectedExecutionInProgress ||\n    hasProgressSnapshot\n  );\n  const progressComplete = selectedExecution?.status === \"completed\";\n  const progressPercent = selectedExecution?.progress?.percent ?? (progressComplete ? 100 : 0);\n  const batchTotal = selectedExecution?.batch?.total ?? null;\n  const batchIdx = selectedExecution?.batch?.idx ?? null;\n  const showBatchProgress = typeof batchTotal === \"number\" && batchTotal > 1;\n  const terminalLines = selectedExecution?.log_lines ?? 
[];\n  const rawExecution = useMemo(() => {\n    if (!selectedExecution) {\n      return null;\n    }\n    const next = { ...selectedExecution } as Record<string, unknown>;\n    delete next.dataset;\n    delete next.log_lines;\n    return next;\n  }, [selectedExecution]);\n\n  useEffect(() => {\n    if (!terminalRef.current) {\n      return;\n    }\n    shouldStickTerminalToBottomRef.current = true;\n    terminalRef.current.scrollTop = terminalRef.current.scrollHeight;\n  }, [selectedExecution?.id]);\n\n  useEffect(() => {\n    if (!terminalRef.current) {\n      return;\n    }\n    if (!shouldStickTerminalToBottomRef.current) {\n      return;\n    }\n    terminalRef.current.scrollTop = terminalRef.current.scrollHeight;\n  }, [terminalLines.length]);\n\n  return (\n    <div className=\"flex h-full min-h-0\">\n      <ExecutionSidebar\n        executions={executions}\n        selectedExecutionId={selectedExecutionId}\n        onSelectExecution={onSelectExecution}\n      />\n      <section className=\"min-w-0 flex-1 overflow-auto p-4\">\n        {!selectedExecution ? (\n          <div className=\"rounded-xl border border-dashed border-border/60 p-4 text-sm text-muted-foreground\">\n            Select an execution.\n          </div>\n        ) : (\n          <div className=\"space-y-4\">\n            {showProgressPanel && (\n              <div className=\"space-y-3 rounded-2xl border shadow-border border-border/60 bg-card/55 p-3\">\n                <div className=\"flex items-center justify-between\">\n                  <div className=\"flex items-center gap-2\">\n                    <HugeiconsIcon\n                      icon={progressComplete ? CheckmarkCircle02Icon : Flag02Icon}\n                      className={cn(\n                        \"size-4\",\n                        progressComplete\n                          ? \"text-emerald-700 dark:text-emerald-300\"\n                          : \"text-amber-700 dark:text-amber-300\",\n                      )}\n                    />\n                    <p className=\"text-sm font-semibold text-foreground\">\n                      Progress\n                    </p>\n                  </div>\n                  <p className=\"text-xs text-muted-foreground\">{formatPercent(progressPercent)}</p>\n                </div>\n                <Progress value={progressPercent} className=\"h-1\" />\n                <div className=\"grid gap-2 text-xs md:grid-cols-4\">\n                  <p className=\"text-muted-foreground\">\n                    Done: <span className=\"text-foreground\">{selectedExecution.progress?.done ?? \"--\"}</span>\n                  </p>\n                  <p className=\"text-muted-foreground\">\n                    Total: <span className=\"text-foreground\">{selectedExecution.progress?.total ?? \"--\"}</span>\n                  </p>\n                  <p className=\"text-muted-foreground\">\n                    Rate: <span className=\"text-foreground\">{selectedExecution.progress?.rate ?? 
\"--\"} rec/s</span>\n                  </p>\n                  <p className=\"text-muted-foreground\">\n                    ETA: <span className=\"text-foreground\">{formatEta(selectedExecution.progress?.eta_sec)}</span>\n                  </p>\n                </div>\n                {selectedExecution.current_column && selectedExecution.column_progress && (\n                  <p className=\"text-xs text-muted-foreground\">\n                    Column {selectedExecution.current_column}:{\" \"}\n                    {selectedExecution.column_progress.done ?? \"--\"}/\n                    {selectedExecution.column_progress.total ?? \"--\"} (\n                    {formatPercent(selectedExecution.column_progress.percent)})\n                  </p>\n                )}\n                {showBatchProgress && (\n                  <p className=\"text-xs text-muted-foreground\">\n                    Processed batch: {batchIdx ?? \"--\"}/{batchTotal}\n                  </p>\n                )}\n                {isStale && <Badge variant=\"outline\">Recipe changed since this run</Badge>}\n              </div>\n            )}\n\n            {(selectedExecution.status === \"error\" ||\n              selectedExecution.status === \"cancelled\") && (\n              <div className=\"rounded-xl border border-destructive/40 bg-destructive/5 p-3\">\n                <p className=\"text-sm font-semibold text-destructive\">\n                  {selectedExecution.status === \"cancelled\"\n                    ? \"Execution cancelled\"\n                    : \"Execution failed\"}\n                </p>\n                <p className=\"text-xs text-destructive\">\n                  {selectedExecution.error ?? \"Unknown error.\"}\n                </p>\n              </div>\n            )}\n\n            <Tabs value={detailTab} onValueChange={setDetailTab}>\n              <div className=\"flex items-center justify-between gap-2\">\n                <TabsList className=\"border border-border/60 bg-card/40\">\n                  <TabsTrigger value=\"overview\">Overview</TabsTrigger>\n                  <TabsTrigger value=\"columns\">Columns</TabsTrigger>\n                  <TabsTrigger value=\"data\">Data</TabsTrigger>\n                  <TabsTrigger value=\"raw\">Raw</TabsTrigger>\n                </TabsList>\n                <div className=\"flex items-center gap-2\">\n                  {canPublish && (\n                    <Button\n                      type=\"button\"\n                      size=\"sm\"\n                      variant=\"outline\"\n                      onClick={() => setPublishDialogOpen(true)}\n                    >\n                      <HugeiconsIcon icon={Share08Icon} className=\"mr-2 size-4\" />\n                      Publish to Hugging Face\n                    </Button>\n                  )}\n                  {canCancel && (\n                    <Button\n                      type=\"button\"\n                      size=\"sm\"\n                      variant=\"outline\"\n                      onClick={() => onCancelExecution(selectedExecution.id)}\n                    >\n                      Cancel\n                    </Button>\n                  )}\n                </div>\n              </div>\n              <TabsContent value=\"overview\">\n                <ExecutionOverviewTab\n                  execution={selectedExecution}\n                  showSummaryCards={showSummaryCards}\n                  recordsMetric={recordsMetric}\n                  totalMetric={totalMetric}\n                  
runDuration={runDuration}\n                  columnCount={columnCount}\n                  llmColumnCount={llmColumnCount}\n                  nullRate={nullRate}\n                  sideEffects={sideEffects}\n                  lowUniquenessColumns={lowUniquenessColumns}\n                  modelUsageRows={modelUsageRows}\n                  terminalLines={terminalLines}\n                  terminalRef={terminalRef}\n                  canPublish={canPublish}\n                  onOpenPublish={() => setPublishDialogOpen(true)}\n                  onTerminalScroll={(event) => {\n                    const element = event.currentTarget;\n                    const distanceFromBottom =\n                      element.scrollHeight - element.scrollTop - element.clientHeight;\n                    shouldStickTerminalToBottomRef.current =\n                      distanceFromBottom <= TERMINAL_STICKY_BOTTOM_THRESHOLD_PX;\n                  }}\n                />\n              </TabsContent>\n              <TabsContent value=\"columns\">\n                <ExecutionColumnsTab analysisColumns={analysisColumns} />\n              </TabsContent>\n              <TabsContent value=\"data\">\n                <ExecutionDataTab\n                  execution={selectedExecution}\n                  datasetColumnNames={datasetColumnNames}\n                  hiddenDatasetColumns={hiddenDatasetColumns}\n                  canPageDataset={canPageDataset}\n                  currentDatasetPage={currentDatasetPage}\n                  totalPages={totalPages}\n                  tableColumns={tableColumns}\n                  datasetRowsForTable={datasetRowsForTable}\n                  visibleDatasetColumnNames={visibleDatasetColumnNames}\n                  expandedDatasetRows={expandedDatasetRows}\n                  selectedExecutionIdSafe={selectedExecutionIdSafe}\n                  onSetHiddenColumns={(updater) => {\n                    const selectedId = selectedExecution.id;\n                    setHiddenDatasetColumnsByExecution((current) => {\n                      const currentColumns = current[selectedId] ?? 
[];\n                      return {\n                        ...current,\n                        [selectedId]: updater(currentColumns),\n                      };\n                    });\n                  }}\n                  onPrevPage={() => {\n                    if (selectedExecution.kind === \"preview\") {\n                      const selectedId = selectedExecution.id;\n                      setPreviewDatasetPageByExecution((current) => ({\n                        ...current,\n                        [selectedId]: Math.max(1, currentDatasetPage - 1),\n                      }));\n                      return;\n                    }\n                    onLoadDatasetPage(selectedExecution.id, currentDatasetPage - 1);\n                  }}\n                  onNextPage={() => {\n                    if (selectedExecution.kind === \"preview\") {\n                      const selectedId = selectedExecution.id;\n                      setPreviewDatasetPageByExecution((current) => ({\n                        ...current,\n                        [selectedId]: Math.min(totalPages, currentDatasetPage + 1),\n                      }));\n                      return;\n                    }\n                    onLoadDatasetPage(selectedExecution.id, currentDatasetPage + 1);\n                  }}\n                  onToggleRowExpanded={(rowId) => {\n                    setExpandedDatasetRowsByExecution((current) => {\n                      const rows = current[selectedExecution.id] ?? {};\n                      return {\n                        ...current,\n                        [selectedExecution.id]: {\n                          ...rows,\n                          [rowId]: !rows[rowId],\n                        },\n                      };\n                    });\n                  }}\n                />\n              </TabsContent>\n              <TabsContent value=\"raw\">\n                <ExecutionRawTab rawExecution={rawExecution} />\n              </TabsContent>\n            </Tabs>\n          </div>\n        )}\n      </section>\n      <PublishExecutionDialog\n        open={publishDialogOpen}\n        onOpenChange={setPublishDialogOpen}\n        execution={canPublish ? selectedExecution : null}\n        onPublish={async (payload) => {\n          if (!selectedExecution?.jobId) {\n            throw new Error(\"This run is missing a job id.\");\n          }\n          const response = await publishRecipeJob(selectedExecution.jobId, payload);\n          return { url: response.url };\n        }}\n      />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/executions/publish-execution-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useMemo, useState, type ReactElement } from \"react\";\nimport { ArrowRight01Icon, CheckmarkCircle02Icon, Copy01Icon, Key01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogDescription,\n  DialogFooter,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { Input } from \"@/components/ui/input\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { toastError, toastSuccess } from \"@/shared/toast\";\nimport type { RecipeExecutionRecord } from \"../../execution-types\";\nimport { copyTextToClipboard } from \"../../executions/execution-helpers\";\n\ntype PublishExecutionDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  execution: RecipeExecutionRecord | null;\n  onPublish: (payload: {\n    repo_id: string;\n    description: string;\n    hf_token?: string | null;\n    private: boolean;\n    artifact_path?: string | null;\n  }) => Promise<{ url: string }>;\n};\n\nfunction getExecutionRecordCount(execution: RecipeExecutionRecord | null): number | null {\n  if (!execution) {\n    return null;\n  }\n  if (typeof execution.analysis?.num_records === \"number\") {\n    return execution.analysis.num_records;\n  }\n  if (execution.datasetTotal > 0) {\n    return execution.datasetTotal;\n  }\n  if (execution.rows > 0) {\n    return execution.rows;\n  }\n  return null;\n}\n\nfunction buildDefaultDescription(execution: RecipeExecutionRecord | null): string {\n  if (!execution) {\n    return \"\";\n  }\n  const runName = execution.run_name?.trim() || \"This dataset\";\n  const records = getExecutionRecordCount(execution);\n  const recordPart =\n    typeof records === \"number\" && records > 0\n      ? ` It contains ${records.toLocaleString()} generated records.`\n      : \"\";\n  return `${runName} was generated with Unsloth Recipe Studio.${recordPart}`;\n}\n\nexport function PublishExecutionDialog({\n  open,\n  onOpenChange,\n  execution,\n  onPublish,\n}: PublishExecutionDialogProps): ReactElement {\n  const [repoId, setRepoId] = useState(\"\");\n  const [description, setDescription] = useState(\"\");\n  const [hfToken, setHfToken] = useState(\"\");\n  const [privateRepo, setPrivateRepo] = useState(false);\n  const [publishing, setPublishing] = useState(false);\n  const [publishError, setPublishError] = useState<string | null>(null);\n  const [publishedUrl, setPublishedUrl] = useState<string | null>(null);\n\n  const defaultDescription = useMemo(\n    () => buildDefaultDescription(execution),\n    [execution],\n  );\n  const runLabel = execution?.run_name?.trim() || \"Completed run\";\n  const recordCount = getExecutionRecordCount(execution);\n  const recordLabel =\n    typeof recordCount === \"number\" ? 
recordCount.toLocaleString() : \"--\";\n\n  useEffect(() => {\n    if (!open) {\n      setPublishing(false);\n      setPublishError(null);\n      setPublishedUrl(null);\n      setRepoId(\"\");\n      setDescription(\"\");\n      setHfToken(\"\");\n      setPrivateRepo(false);\n      return;\n    }\n    setPublishError(null);\n    setPublishedUrl(null);\n    setDescription(buildDefaultDescription(execution));\n  }, [execution, open]);\n\n  const canSubmit =\n    !publishing &&\n    Boolean(execution?.jobId) &&\n    Boolean(execution?.artifact_path) &&\n    repoId.trim().length > 0 &&\n    description.trim().length > 0;\n\n  const handleCopyUrl = async (): Promise<void> => {\n    if (!publishedUrl) {\n      return;\n    }\n    const ok = await copyTextToClipboard(publishedUrl);\n    if (ok) {\n      toastSuccess(\"Dataset link copied\");\n      return;\n    }\n    toastError(\"Copy failed\", \"Could not copy the dataset link.\");\n  };\n\n  const handlePublish = async (): Promise<void> => {\n    if (!execution?.jobId) {\n      setPublishError(\"This run is missing a job id, so it cannot be published.\");\n      return;\n    }\n    setPublishing(true);\n    setPublishError(null);\n    try {\n      const result = await onPublish({\n        repo_id: repoId.trim(),\n        description: description.trim(),\n        hf_token: hfToken.trim() || null,\n        private: privateRepo,\n        artifact_path: execution.artifact_path,\n      });\n      setPublishedUrl(result.url);\n      toastSuccess(\"Dataset published\");\n    } catch (error) {\n      const message =\n        error instanceof Error ? error.message : \"Could not publish this dataset.\";\n      setPublishError(message);\n      toastError(\"Publish failed\", message);\n    } finally {\n      setPublishing(false);\n    }\n  };\n\n  return (\n    <Dialog\n      open={open}\n      onOpenChange={(nextOpen) => {\n        if (publishing) {\n          return;\n        }\n        onOpenChange(nextOpen);\n      }}\n    >\n      <DialogContent\n        className=\"sm:max-w-xl\"\n        overlayClassName=\"bg-black/55\"\n        onInteractOutside={(event) => {\n          if (publishing) {\n            event.preventDefault();\n          }\n        }}\n      >\n        {publishedUrl ? 
(\n          <>\n            <div className=\"flex flex-col items-center gap-3 py-4\">\n              <div className=\"flex size-12 items-center justify-center rounded-full bg-emerald-500/10\">\n                <HugeiconsIcon\n                  icon={CheckmarkCircle02Icon}\n                  className=\"size-6 text-emerald-600 dark:text-emerald-400\"\n                />\n              </div>\n              <div className=\"space-y-1 text-center\">\n                <DialogTitle>Published</DialogTitle>\n                <DialogDescription>\n                  Your dataset is live on Hugging Face.\n                </DialogDescription>\n              </div>\n            </div>\n            <div className=\"rounded-2xl border border-border/60 bg-card/55 p-3 text-xs\">\n              <p className=\"mb-1 text-muted-foreground\">Dataset URL</p>\n              <p className=\"break-all font-medium text-foreground\">{publishedUrl}</p>\n            </div>\n            <DialogFooter>\n              <Button variant=\"outline\" onClick={handleCopyUrl}>\n                <HugeiconsIcon icon={Copy01Icon} className=\"mr-2 size-4\" />\n                Copy link\n              </Button>\n              <Button asChild={true}>\n                <a href={publishedUrl} target=\"_blank\" rel=\"noreferrer\">\n                  Open repo\n                  <HugeiconsIcon icon={ArrowRight01Icon} className=\"ml-2 size-4\" />\n                </a>\n              </Button>\n              <Button variant=\"ghost\" onClick={() => onOpenChange(false)}>\n                Done\n              </Button>\n            </DialogFooter>\n          </>\n        ) : (\n          <>\n            <DialogHeader>\n              <DialogTitle>Publish to Hugging Face</DialogTitle>\n              <DialogDescription>\n                Create or update a dataset repo from this completed run.\n              </DialogDescription>\n            </DialogHeader>\n\n            <div className=\"space-y-4\">\n              <div className=\"rounded-2xl border border-border/60 bg-card/55 p-3 text-xs\">\n                <p className=\"font-medium text-foreground\">From this run</p>\n                <div className=\"mt-2 grid gap-1.5 text-muted-foreground sm:grid-cols-2\">\n                  <p>\n                    Run: <span className=\"text-foreground\">{runLabel}</span>\n                  </p>\n                  <p>\n                    Records: <span className=\"text-foreground\">{recordLabel}</span>\n                  </p>\n                </div>\n                <p className=\"mt-2 text-muted-foreground\">\n                  We’ll upload the generated dataset, dataset card, images, and any processor\n                  outputs from this execution.\n                </p>\n              </div>\n\n              <div className=\"space-y-1.5\">\n                <label className=\"text-sm font-medium text-foreground\" htmlFor=\"publish-repo-id\">\n                  Repository\n                </label>\n                <Input\n                  id=\"publish-repo-id\"\n                  placeholder=\"your-name/customer-support-synth\"\n                  value={repoId}\n                  onChange={(event) => setRepoId(event.target.value)}\n                  disabled={publishing}\n                />\n                <p className=\"text-xs text-muted-foreground\">\n                  Use the format <span className=\"font-mono\">username-or-org/dataset-name</span>.\n                </p>\n              </div>\n\n              <div className=\"space-y-1.5\">\n             
   <label className=\"text-sm font-medium text-foreground\" htmlFor=\"publish-description\">\n                  About this dataset\n                </label>\n                <Textarea\n                  id=\"publish-description\"\n                  className=\"corner-squircle\"\n                  value={description}\n                  onChange={(event) => setDescription(event.target.value)}\n                  disabled={publishing}\n                  rows={4}\n                  placeholder={defaultDescription || \"What is this dataset for?\"}\n                />\n                <p className=\"text-xs text-muted-foreground\">\n                  This short summary is used in the dataset card on Hugging Face.\n                </p>\n              </div>\n\n              <div className=\"space-y-1.5\">\n                <div className=\"flex items-center justify-between gap-3\">\n                  <label className=\"text-sm font-medium text-foreground\" htmlFor=\"publish-hf-token\">\n                    HF write token\n                  </label>\n                  <a\n                    href=\"https://huggingface.co/settings/tokens\"\n                    target=\"_blank\"\n                    rel=\"noreferrer\"\n                    className=\"text-xs text-muted-foreground underline underline-offset-3 hover:text-foreground\"\n                  >\n                    Manage tokens\n                  </a>\n                </div>\n                <div className=\"relative\">\n                  <HugeiconsIcon\n                    icon={Key01Icon}\n                    className=\"pointer-events-none absolute top-1/2 left-3 size-4 -translate-y-1/2 text-muted-foreground\"\n                  />\n                  <Input\n                    id=\"publish-hf-token\"\n                    type=\"password\"\n                    autoComplete=\"new-password\"\n                    className=\"pl-9\"\n                    placeholder=\"hf_...\"\n                    value={hfToken}\n                    onChange={(event) => setHfToken(event.target.value)}\n                    disabled={publishing}\n                  />\n                </div>\n                <p className=\"text-xs text-muted-foreground\">\n                  Leave empty if you're already logged in via CLI.\n                </p>\n              </div>\n\n              <div className=\"corner-squircle flex items-start gap-3 rounded-2xl border border-border/60 bg-card/35 p-3\">\n                <Switch\n                  id=\"publish-private\"\n                  size=\"sm\"\n                  checked={privateRepo}\n                  onCheckedChange={setPrivateRepo}\n                  disabled={publishing}\n                />\n                <div className=\"space-y-1\">\n                  <label\n                    htmlFor=\"publish-private\"\n                    className=\"text-sm font-medium text-foreground\"\n                  >\n                    Private dataset\n                  </label>\n                  <p className=\"text-xs text-muted-foreground\">\n                    Only people with access can view or download the repo.\n                  </p>\n                </div>\n              </div>\n\n              {publishError ? 
(\n                <div className=\"rounded-2xl border border-destructive/40 bg-destructive/5 p-3 text-sm text-destructive\">\n                  {publishError}\n                </div>\n              ) : null}\n            </div>\n\n            <DialogFooter>\n              <Button\n                variant=\"outline\"\n                onClick={() => onOpenChange(false)}\n                disabled={publishing}\n              >\n                Cancel\n              </Button>\n              <Button onClick={() => void handlePublish()} disabled={!canSubmit}>\n                {publishing ? \"Publishing...\" : \"Publish to Hugging Face\"}\n              </Button>\n            </DialogFooter>\n          </>\n        )}\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/graph/internals-sync.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useUpdateNodeInternals } from \"@xyflow/react\";\nimport { useEffect, useMemo, useRef } from \"react\";\n\ntype InternalsSyncProps = {\n  nodeIds: string[];\n};\n\nexport function InternalsSync({ nodeIds }: InternalsSyncProps): null {\n  const updateNodeInternals = useUpdateNodeInternals();\n  const idsKey = useMemo(() => nodeIds.join(\"|\"), [nodeIds]);\n  const nodeIdsRef = useRef(nodeIds);\n  nodeIdsRef.current = nodeIds;\n\n  useEffect(() => {\n    if (!idsKey) {\n      return;\n    }\n    const raf = requestAnimationFrame(() => {\n      updateNodeInternals(nodeIdsRef.current);\n    });\n    return () => cancelAnimationFrame(raf);\n  }, [idsKey, updateNodeInternals]);\n\n  return null;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-category-badges.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { type ReactElement, useLayoutEffect, useRef, useState } from \"react\";\n\r\ntype InlineCategoryBadgesProps = {\r\n  values: string[];\r\n};\r\n\r\nexport function InlineCategoryBadges({\r\n  values,\r\n}: InlineCategoryBadgesProps): ReactElement {\r\n  const containerRef = useRef<HTMLDivElement>(null);\r\n  const [visibleCount, setVisibleCount] = useState(values.length);\r\n\r\n  useLayoutEffect(() => {\n    const container = containerRef.current;\n    if (!container) return;\n\n    const badges = Array.from(container.children) as HTMLElement[];\n    if (badges.length === 0) {\n      const id = requestAnimationFrame(() => setVisibleCount(0));\n      return () => cancelAnimationFrame(id);\n    }\n\r\n    const containerWidth = container.clientWidth;\r\n    // Reserve space for the \"+N\" badge (~36px)\r\n    const overflowBadgeWidth = 36;\r\n    let count = 0;\r\n    let usedWidth = 0;\r\n\r\n    for (const badge of badges) {\r\n      const badgeWidth = badge.scrollWidth + 4; // 4px for gap\r\n      if (usedWidth + badgeWidth > containerWidth - overflowBadgeWidth && count < badges.length - 1) {\r\n        break;\r\n      }\r\n      if (usedWidth + badgeWidth > containerWidth) {\r\n        break;\r\n      }\r\n      usedWidth += badgeWidth;\r\n      count++;\r\n    }\r\n\r\n    const id = requestAnimationFrame(() => setVisibleCount(count || 1));\n    return () => cancelAnimationFrame(id);\n  }, [values]);\n\r\n  if (values.length === 0) {\r\n    return <p className=\"text-xs text-muted-foreground\">No values</p>;\r\n  }\r\n\r\n  const overflow = values.length - visibleCount;\r\n\r\n  return (\r\n    <div className=\"relative\">\r\n      {/* Hidden measurer */}\r\n      <div\r\n        ref={containerRef}\r\n        className=\"pointer-events-none invisible absolute inset-x-0 top-0 flex flex-nowrap gap-1\"\r\n        aria-hidden\r\n      >\r\n        {values.map((v, i) => (\r\n          <Badge\r\n            key={`m-${v}-${i}`}\r\n            variant=\"secondary\"\r\n            className=\"corner-squircle h-4 shrink-0 px-1.5 text-[10px]\"\r\n          >\r\n            {v}\r\n          </Badge>\r\n        ))}\r\n      </div>\r\n      {/* Visible badges */}\r\n      <div className=\"flex flex-wrap gap-1\">\r\n        {values.slice(0, visibleCount).map((v, i) => (\r\n          <Badge\r\n            key={`${v}-${i}`}\r\n            variant=\"secondary\"\r\n            className=\"corner-squircle h-4 px-1.5 text-[10px]\"\r\n          >\r\n            {v}\r\n          </Badge>\r\n        ))}\r\n        {overflow > 0 && (\r\n          <Badge variant=\"outline\" className=\"corner-squircle h-4 px-1.5 text-[10px]\">\r\n            +{overflow}\r\n          </Badge>\r\n        )}\r\n      </div>\r\n    </div>\r\n  );\r\n}\r\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-expression.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport type { ExpressionConfig, ExpressionDtype } from \"../../types\";\nimport { findInvalidJinjaReferences } from \"../../utils/refs\";\nimport { getAvailableVariableEntries } from \"../../utils/variables\";\nimport { AvailableReferencesInline } from \"../shared/available-references-inline\";\nimport { InlineField } from \"./inline-field\";\n\ntype InlineExpressionProps = {\n  config: ExpressionConfig;\n  onUpdate: (patch: Partial<ExpressionConfig>) => void;\n};\n\nconst DTYPE_OPTIONS: ExpressionDtype[] = [\"str\", \"int\", \"float\", \"bool\"];\n\nexport function InlineExpression({\n  config,\n  onUpdate,\n}: InlineExpressionProps): ReactElement {\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const vars = getAvailableVariableEntries(configs, config.id);\n  const invalidRefs = findInvalidJinjaReferences(\n    config.expr,\n    vars.map((entry) => entry.name),\n  );\n\n  return (\n    <div className=\"space-y-3\">\n      <div className=\"grid gap-3 sm:grid-cols-[130px_1fr]\">\n        <InlineField label=\"Output type\">\n          <Select\n            value={config.dtype}\n            onValueChange={(value) =>\n              onUpdate({ dtype: value as ExpressionDtype })\n            }\n          >\n            <SelectTrigger className=\"nodrag h-8 w-full text-xs\">\n              <SelectValue placeholder=\"dtype\" />\n            </SelectTrigger>\n            <SelectContent>\n              {DTYPE_OPTIONS.map((dtype) => (\n                <SelectItem key={dtype} value={dtype}>\n                  {dtype}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n        </InlineField>\n        <InlineField label=\"Expression\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            aria-invalid={invalidRefs.length > 0}\n            placeholder=\"{{ column_name }}\"\n            value={config.expr}\n            onChange={(event) => onUpdate({ expr: event.target.value })}\n          />\n        </InlineField>\n      </div>\n      <AvailableReferencesInline entries={vars} />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-field.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport type { ReactElement, ReactNode } from \"react\";\n\ntype InlineFieldProps = {\n  label: string;\n  className?: string;\n  children: ReactNode;\n};\n\nexport function InlineField({\n  label,\n  className,\n  children,\n}: InlineFieldProps): ReactElement {\n  return (\n    <div className={cn(\"grid gap-1.5\", className)}>\n      <p className=\"text-[11px] font-semibold uppercase tracking-wide text-muted-foreground\">\n        {label}\n      </p>\n      {children}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-llm.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { type ReactElement, useMemo, useRef } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport type { LlmConfig } from \"../../types\";\nimport { InlineField } from \"./inline-field\";\n\ntype InlineLlmProps = {\n  config: LlmConfig;\n  onUpdate: (patch: Partial<LlmConfig>) => void;\n};\n\nconst CODE_LANG_OPTIONS = [\n  \"python\",\n  \"javascript\",\n  \"typescript\",\n  \"java\",\n  \"kotlin\",\n  \"go\",\n  \"rust\",\n  \"ruby\",\n  \"scala\",\n  \"swift\",\n  \"sql:sqlite\",\n  \"sql:postgres\",\n  \"sql:mysql\",\n  \"sql:tsql\",\n  \"sql:bigquery\",\n  \"sql:ansi\",\n] as const;\n\nexport function InlineLlm({ config, onUpdate }: InlineLlmProps): ReactElement {\n  const isCode = config.llm_type === \"code\";\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const modelConfigAliases = useMemo(\n    () =>\n      Object.values(configs)\n        .filter((c) => c.kind === \"model_config\")\n        .map((c) => c.name),\n    [configs],\n  );\n  const toolProfileAliases = useMemo(\n    () =>\n      Object.values(configs)\n        .filter((c) => c.kind === \"tool_config\")\n        .map((c) => c.name),\n    [configs],\n  );\n  const aliasInputRef = useRef(config.model_alias);\n  const lastAliasRef = useRef(config.model_alias);\n  const anchorRef = useRef<HTMLDivElement>(null);\n  const toolAnchorRef = useRef<HTMLDivElement>(null);\n  if (lastAliasRef.current !== config.model_alias) {\n    lastAliasRef.current = config.model_alias;\n    aliasInputRef.current = config.model_alias;\n  }\n\n  return (\n    <div className=\"space-y-3\">\n      <InlineField label=\"Model alias\">\n        <div ref={anchorRef}>\n          <Combobox\n            items={modelConfigAliases}\n            filteredItems={modelConfigAliases}\n            filter={null}\n            value={config.model_alias || null}\n            onValueChange={(value) =>\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                model_alias: value ?? 
\"\",\n              })\n            }\n            onInputValueChange={(value) => {\n              aliasInputRef.current = value;\n            }}\n            itemToStringValue={(value) => value}\n            autoHighlight={true}\n          >\n            <ComboboxInput\n              className=\"nodrag h-8 w-full text-xs\"\n              placeholder=\"Model alias\"\n              onBlur={() => {\n                const next = aliasInputRef.current;\n                if (next !== config.model_alias) {\n                  onUpdate({\n                    // biome-ignore lint/style/useNamingConvention: api schema\n                    model_alias: next,\n                  });\n                }\n              }}\n            />\n            <ComboboxContent anchor={anchorRef}>\n              <ComboboxEmpty>No model configs found</ComboboxEmpty>\n              <ComboboxList>\n                {(alias: string) => (\n                  <ComboboxItem key={alias} value={alias}>\n                    {alias}\n                  </ComboboxItem>\n                )}\n              </ComboboxList>\n            </ComboboxContent>\n          </Combobox>\n        </div>\n      </InlineField>\n      <InlineField label=\"Tool profile\">\n        <div ref={toolAnchorRef}>\n          <Combobox\n            items={toolProfileAliases}\n            filteredItems={toolProfileAliases}\n            filter={null}\n            value={config.tool_alias || null}\n            onValueChange={(value) =>\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                tool_alias: value ?? \"\",\n              })\n            }\n            itemToStringValue={(value) => value}\n            autoHighlight={true}\n          >\n            <ComboboxInput\n              className=\"nodrag h-8 w-full text-xs\"\n              placeholder=\"Tool profile\"\n              onBlur={(event) => {\n                const next = event.target.value;\n                if (next !== (config.tool_alias ?? 
\"\")) {\n                  onUpdate({\n                    // biome-ignore lint/style/useNamingConvention: api schema\n                    tool_alias: next,\n                  });\n                }\n              }}\n            />\n            <ComboboxContent anchor={toolAnchorRef}>\n              <ComboboxEmpty>No tool profiles found</ComboboxEmpty>\n              <ComboboxList>\n                {(alias: string) => (\n                  <ComboboxItem key={alias} value={alias}>\n                    {alias}\n                  </ComboboxItem>\n                )}\n              </ComboboxList>\n            </ComboboxContent>\n          </Combobox>\n        </div>\n      </InlineField>\n      {isCode && (\n        <InlineField label=\"Code language\">\n          <Select\n            value={config.code_lang?.trim() || \"python\"}\n            onValueChange={(value) =>\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                code_lang: value,\n              })\n            }\n          >\n            <SelectTrigger className=\"nodrag h-8 w-full text-xs\">\n              <SelectValue placeholder=\"Language\" />\n            </SelectTrigger>\n            <SelectContent>\n              {CODE_LANG_OPTIONS.map((lang) => (\n                <SelectItem key={lang} value={lang}>\n                  {lang}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n        </InlineField>\n      )}\n      <p className=\"text-[11px] text-muted-foreground\">\n        Prompt/system edited on aux nodes.\n      </p>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-model.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport type { ReactElement } from \"react\";\nimport type { ModelConfig, ModelProviderConfig } from \"../../types\";\nimport { InlineField } from \"./inline-field\";\n\ntype InlineModelPatch = Partial<ModelProviderConfig> | Partial<ModelConfig>;\n\ntype InlineModelProps = {\n  config: ModelProviderConfig | ModelConfig;\n  onUpdate: (patch: InlineModelPatch) => void;\n};\n\nexport function InlineModel(props: InlineModelProps): ReactElement {\n  if (props.config.kind === \"model_provider\") {\n    return (\n      <div className=\"grid gap-3 sm:grid-cols-2\">\n        <InlineField label=\"Endpoint\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            placeholder=\"https://api.example.com/v1\"\n            value={props.config.endpoint}\n            onChange={(event) => props.onUpdate({ endpoint: event.target.value })}\n          />\n        </InlineField>\n        <InlineField label=\"API key\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            placeholder=\"Optional\"\n            value={props.config.api_key ?? \"\"}\n            onChange={(event) =>\n              props.onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                api_key: event.target.value,\n              })\n            }\n          />\n        </InlineField>\n      </div>\n    );\n  }\n\n  return (\n    <div className=\"grid gap-3 sm:grid-cols-2\">\n      <InlineField label=\"Provider\">\n        <Input\n          className=\"nodrag h-8 w-full text-xs\"\n          placeholder=\"provider alias\"\n          value={props.config.provider}\n          onChange={(event) => props.onUpdate({ provider: event.target.value })}\n        />\n      </InlineField>\n      <InlineField label=\"Model\">\n        <Input\n          className=\"nodrag h-8 w-full text-xs\"\n          placeholder=\"gpt-4o-mini\"\n          value={props.config.model}\n          onChange={(event) => props.onUpdate({ model: event.target.value })}\n        />\n      </InlineField>\n      <InlineField label=\"Temperature\" className=\"sm:col-span-2\">\n        <Input\n          className=\"nodrag h-8 w-full text-xs\"\n          type=\"number\"\n          placeholder=\"0.7\"\n          value={props.config.inference_temperature ?? \"\"}\n          onChange={(event) =>\n            props.onUpdate({\n              // biome-ignore lint/style/useNamingConvention: api schema\n              inference_temperature: event.target.value,\n            })\n          }\n        />\n      </InlineField>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-policy.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig, SamplerType } from \"../../types\";\n\nexport type ConfigUiMode = \"inline\" | \"dialog\";\n\nconst INLINE_SAMPLERS = new Set<SamplerType>([\n  \"uniform\",\n  \"gaussian\",\n  \"bernoulli\",\n  \"uuid\",\n]);\n\nexport function getConfigUiMode(\n  config: NodeConfig | null | undefined,\n): ConfigUiMode {\n  if (!config) {\n    return \"dialog\";\n  }\n  if (config.kind === \"sampler\") {\n    return INLINE_SAMPLERS.has(config.sampler_type) ? \"inline\" : \"dialog\";\n  }\n  if (config.kind === \"model_provider\" || config.kind === \"model_config\") {\n    return \"inline\";\n  }\n  if (config.kind === \"tool_config\") {\n    return \"dialog\";\n  }\n  if (config.kind === \"llm\") {\n    if (config.llm_type === \"text\" || config.llm_type === \"code\") {\n      return \"inline\";\n    }\n    return \"dialog\";\n  }\n  if (config.kind === \"seed\") {\n    return \"inline\";\n  }\n  if (config.kind === \"expression\") {\n    return \"inline\";\n  }\n  return \"dialog\";\n}\n\nexport function isInlineConfig(\n  config: NodeConfig | null | undefined,\n): boolean {\n  return getConfigUiMode(config) === \"inline\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-sampler.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { InlineField } from \"./inline-field\";\n\ntype InlineSamplerProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\ntype ConvertTo = \"int\" | \"float\" | \"str\";\n\nfunction ConvertToField({\n  value,\n  onValueChange,\n}: {\n  value: SamplerConfig[\"convert_to\"];\n  onValueChange: (value: ConvertTo | undefined) => void;\n}): ReactElement {\n  return (\n    <Select\n      value={value ?? \"none\"}\n      onValueChange={(next) =>\n        onValueChange(next === \"none\" ? undefined : (next as ConvertTo))\n      }\n    >\n      <SelectTrigger className=\"nodrag h-8 w-full text-xs\">\n        <SelectValue placeholder=\"Convert\" />\n      </SelectTrigger>\n      <SelectContent>\n        <SelectItem value=\"none\">None</SelectItem>\n        <SelectItem value=\"int\">int</SelectItem>\n        <SelectItem value=\"float\">float</SelectItem>\n        <SelectItem value=\"str\">str</SelectItem>\n      </SelectContent>\n    </Select>\n  );\n}\n\nexport function InlineSampler({\n  config,\n  onUpdate,\n}: InlineSamplerProps): ReactElement | null {\n  if (config.sampler_type === \"uniform\") {\n    return (\n      <div className=\"grid gap-3 sm:grid-cols-3\">\n        <InlineField label=\"Low\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            type=\"number\"\n            placeholder=\"0\"\n            value={config.low ?? \"\"}\n            onChange={(event) => onUpdate({ low: event.target.value })}\n          />\n        </InlineField>\n        <InlineField label=\"High\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            type=\"number\"\n            placeholder=\"100\"\n            value={config.high ?? \"\"}\n            onChange={(event) => onUpdate({ high: event.target.value })}\n          />\n        </InlineField>\n        <InlineField label=\"Convert to\">\n          <ConvertToField\n            value={config.convert_to}\n            onValueChange={(value) =>\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                convert_to: value,\n              })\n            }\n          />\n        </InlineField>\n      </div>\n    );\n  }\n\n  if (config.sampler_type === \"gaussian\") {\n    return (\n      <div className=\"grid gap-3 sm:grid-cols-3\">\n        <InlineField label=\"Mean\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            type=\"number\"\n            placeholder=\"0\"\n            value={config.mean ?? \"\"}\n            onChange={(event) => onUpdate({ mean: event.target.value })}\n          />\n        </InlineField>\n        <InlineField label=\"Std dev\">\n          <Input\n            className=\"nodrag h-8 w-full text-xs\"\n            type=\"number\"\n            placeholder=\"1\"\n            value={config.std ?? 
\"\"}\n            onChange={(event) => onUpdate({ std: event.target.value })}\n          />\n        </InlineField>\n        <InlineField label=\"Convert to\">\n          <ConvertToField\n            value={config.convert_to}\n            onValueChange={(value) =>\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                convert_to: value,\n              })\n            }\n          />\n        </InlineField>\n      </div>\n    );\n  }\n\n  if (config.sampler_type === \"bernoulli\") {\n    return (\n      <InlineField label=\"Probability (p)\">\n        <Input\n          className=\"nodrag h-8 w-full text-xs\"\n          type=\"number\"\n          min=\"0\"\n          max=\"1\"\n          step=\"0.01\"\n          placeholder=\"0.5\"\n          value={config.p ?? \"\"}\n          onChange={(event) => onUpdate({ p: event.target.value })}\n        />\n      </InlineField>\n    );\n  }\n\n  if (config.sampler_type === \"uuid\") {\n    return (\n      <InlineField label=\"UUID format\">\n        <Input\n          className=\"nodrag h-8 w-full text-xs\"\n          placeholder=\"uuid4\"\n          value={config.uuid_format ?? \"\"}\n          onChange={(event) =>\n            onUpdate({\n              // biome-ignore lint/style/useNamingConvention: api schema\n              uuid_format: event.target.value,\n            })\n          }\n        />\n      </InlineField>\n    );\n  }\n\n  return null;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/inline/inline-seed.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { DocumentAttachmentIcon, DocumentCodeIcon, Plant01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\nimport type { SeedConfig } from \"../../types\";\nimport { HfDatasetCombobox } from \"../shared/hf-dataset-combobox\";\nimport { InlineField } from \"./inline-field\";\n\ntype InlineSeedProps = {\n  config: SeedConfig;\n  onUpdate: (patch: Partial<SeedConfig>) => void;\n};\n\nexport function InlineSeed({ config, onUpdate }: InlineSeedProps): ReactElement {\n  const mode = config.seed_source_type ?? \"hf\";\n\n  if (mode === \"hf\") {\n    return (\n      <div className=\"space-y-2\">\n        <InlineField label=\"Dataset\">\n          <HfDatasetCombobox\n            value={config.hf_repo_id}\n            accessToken={config.hf_token?.trim() || undefined}\n            onValueChange={(next) =>\n              onUpdate({\n                hf_repo_id: next,\n                hf_path: \"\",\n                seed_columns: [],\n                seed_drop_columns: [],\n                seed_preview_rows: [],\n              })\n            }\n            placeholder=\"org/repo\"\n          />\n        </InlineField>\n        <p className=\"text-[11px] text-muted-foreground\">\n          Load columns in dialog.\n        </p>\n      </div>\n    );\n  }\n\n  const isLocal = mode === \"local\";\n  const fileName = isLocal\n    ? config.local_file_name?.trim()\n    : config.unstructured_file_name?.trim();\n\n  return (\n    <div className=\"corner-squircle flex items-center gap-2 rounded-md border border-border/60 bg-muted/30 px-2 py-2\">\n      <div className=\"corner-squircle rounded-md bg-primary/10 p-1.5 text-primary\">\n        <HugeiconsIcon\n          icon={isLocal ? DocumentCodeIcon : DocumentAttachmentIcon}\n          className=\"size-3.5\"\n        />\n      </div>\n      <div className=\"min-w-0\">\n        <p className=\"truncate text-xs font-medium\">\n          {fileName || \"No file selected\"}\n        </p>\n        <p className=\"text-[11px] text-muted-foreground\">\n          {isLocal ? \"Structured file\" : \"Unstructured document\"} · configure in dialog\n        </p>\n      </div>\n      <HugeiconsIcon icon={Plant01Icon} className=\"ml-auto size-3.5 text-muted-foreground/60\" />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/recipe-floating-icon-button-class.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const RECIPE_FLOATING_ICON_BUTTON_CLASS =\n  \"corner-squircle group h-11 w-11 rounded-xl border border-border/60 bg-transparent p-0 text-muted-foreground hover:bg-transparent hover:text-primary hover:border-primary/60\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/recipe-graph-aux-node.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Input } from \"@/components/ui/input\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport {\n  Handle,\n  Position,\n  type Node,\n  type NodeProps,\n  useUpdateNodeInternals,\n} from \"@xyflow/react\";\nimport { memo, type ReactElement, useEffect } from \"react\";\nimport { useRecipeStudioStore } from \"../stores/recipe-studio\";\nimport type { LlmConfig, Score, ScoreOption } from \"../types\";\nimport { AUX_HANDLE_CLASS } from \"../utils/handle-layout\";\nimport { HANDLE_IDS } from \"../utils/handles\";\nimport { findInvalidJinjaReferences } from \"../utils/refs\";\nimport { getAvailableVariableEntries } from \"../utils/variables\";\nimport { AvailableReferencesInline } from \"./shared/available-references-inline\";\nimport { BaseNode, BaseNodeContent, BaseNodeHeader, BaseNodeHeaderTitle } from \"./rf-ui/base-node\";\n\ntype PromptField = \"prompt\" | \"system_prompt\";\n\ntype PromptInputNodeData = {\n  kind: \"llm-prompt-input\";\n  llmId: string;\n  field: PromptField;\n  title: string;\n  executionLocked?: boolean;\n};\n\ntype JudgeScoreNodeData = {\n  kind: \"llm-judge-score\";\n  llmId: string;\n  scoreIndex: number;\n  executionLocked?: boolean;\n};\n\nexport type RecipeGraphAuxNodeData = PromptInputNodeData | JudgeScoreNodeData;\nexport type RecipeGraphAuxNodeType = Node<RecipeGraphAuxNodeData, \"aux\">;\n\nfunction updateScoreAt(\n  config: LlmConfig,\n  scoreIndex: number,\n  patch: Partial<Score>,\n): Score[] {\n  const scores = config.scores ?? [];\n  return scores.map((score, index) =>\n    index === scoreIndex ? { ...score, ...patch } : score,\n  );\n}\n\nfunction updateOptionAt(\n  score: Score,\n  optionIndex: number,\n  patch: Partial<ScoreOption>,\n): ScoreOption[] {\n  return score.options.map((option, index) =>\n    index === optionIndex ? 
{ ...option, ...patch } : option,\n  );\n}\n\nfunction AuxVariableBadges({\n  entries,\n}: {\n  entries: ReturnType<typeof getAvailableVariableEntries>;\n}): ReactElement | null {\n  return <AvailableReferencesInline entries={entries} />;\n}\n\nfunction AuxNodeBase({\n  id,\n  data,\n}: NodeProps<RecipeGraphAuxNodeType>): ReactElement | null {\n  const config = useRecipeStudioStore((state) => state.configs[data.llmId]);\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const updateConfig = useRecipeStudioStore((state) => state.updateConfig);\n  const updateNodeInternals = useUpdateNodeInternals();\n\n  useEffect(() => {\n    updateNodeInternals(id);\n  }, [id, updateNodeInternals]);\n\n  if (!(config && config.kind === \"llm\")) {\n    return null;\n  }\n  const executionLocked = Boolean(data.executionLocked);\n\n  const sourceHandles = (\n    <>\n      <Handle\n        id={HANDLE_IDS.llmInputOutLeft}\n        type=\"source\"\n        position={Position.Left}\n        isConnectable={false}\n        isConnectableStart={false}\n        className={AUX_HANDLE_CLASS}\n      />\n      <Handle\n        id={HANDLE_IDS.llmInputOutRight}\n        type=\"source\"\n        position={Position.Right}\n        isConnectable={false}\n        isConnectableStart={false}\n        className={AUX_HANDLE_CLASS}\n      />\n      <Handle\n        id={HANDLE_IDS.llmInputOutTop}\n        type=\"source\"\n        position={Position.Top}\n        isConnectable={false}\n        isConnectableStart={false}\n        className={AUX_HANDLE_CLASS}\n      />\n      <Handle\n        id={HANDLE_IDS.llmInputOutBottom}\n        type=\"source\"\n        position={Position.Bottom}\n        isConnectable={false}\n        isConnectableStart={false}\n        className={AUX_HANDLE_CLASS}\n      />\n    </>\n  );\n\n  if (data.kind === \"llm-prompt-input\") {\n    const value = data.field === \"prompt\" ? config.prompt : config.system_prompt;\n    const variableEntries = getAvailableVariableEntries(configs, data.llmId);\n    const availableRefs = variableEntries.map((entry) => entry.name);\n    const hasInvalidRefs =\n      findInvalidJinjaReferences(value, availableRefs).length > 0;\n    return (\n      <BaseNode className=\"corner-squircle w-full min-w-0 rounded-lg border-border/60 bg-card shadow-sm\">\n        <BaseNodeHeader className=\"border-b border-border/50 px-3 py-2\">\n          <BaseNodeHeaderTitle className=\"text-xs\">{data.title}</BaseNodeHeaderTitle>\n        </BaseNodeHeader>\n        <BaseNodeContent className=\"gap-2 px-3 py-2\">\n          <Textarea\n            className=\"corner-squircle nodrag nowheel max-h-40 min-h-[88px] w-full resize-none overflow-y-auto text-xs\"\n            aria-invalid={hasInvalidRefs}\n            value={value}\n            disabled={executionLocked}\n            onChange={(event) =>\n              updateConfig(data.llmId, {\n                [data.field]: event.target.value,\n              } as Partial<LlmConfig>)\n            }\n          />\n          <AuxVariableBadges entries={variableEntries} />\n        </BaseNodeContent>\n        {sourceHandles}\n      </BaseNode>\n    );\n  }\n\n  const score = config.scores?.[data.scoreIndex];\n  if (!score) {\n    return null;\n  }\n\n  const updateScore = (patch: Partial<Score>): void => {\n    updateConfig(data.llmId, {\n      scores: updateScoreAt(config, data.scoreIndex, patch),\n    });\n  };\n\n  const removeScore = (): void => {\n    const nextScores = (config.scores ?? 
[]).filter(\n      (_score, index) => index !== data.scoreIndex,\n    );\n    updateConfig(data.llmId, { scores: nextScores });\n  };\n\n  const addOption = (): void => {\n    updateScore({\n      options: [...score.options, { value: \"\", description: \"\" }],\n    });\n  };\n\n  const removeOption = (optionIndex: number): void => {\n    updateScore({\n      options: score.options.filter((_option, index) => index !== optionIndex),\n    });\n  };\n\n  const updateOption = (\n    optionIndex: number,\n    patch: Partial<ScoreOption>,\n  ): void => {\n    updateScore({\n      options: updateOptionAt(score, optionIndex, patch),\n    });\n  };\n\n  return (\n    <BaseNode className=\"corner-squircle w-full min-w-0 rounded-lg border-border/60 bg-card shadow-sm\">\n      <BaseNodeHeader className=\"border-b border-border/50 px-3 py-2\">\n        <BaseNodeHeaderTitle className=\"text-xs\">\n          {score.name.trim() || `Scorer ${data.scoreIndex + 1}`}\n        </BaseNodeHeaderTitle>\n        <Button\n          type=\"button\"\n          size=\"xs\"\n          variant=\"ghost\"\n          className=\"nodrag\"\n          disabled={executionLocked}\n          onClick={removeScore}\n        >\n          Remove\n        </Button>\n      </BaseNodeHeader>\n      <BaseNodeContent className=\"gap-2 px-3 py-2\">\n        <Input\n          className=\"nodrag h-7 w-full text-xs\"\n          placeholder=\"Score name\"\n          value={score.name}\n          disabled={executionLocked}\n          onChange={(event) => updateScore({ name: event.target.value })}\n        />\n        <Textarea\n          className=\"corner-squircle nodrag nowheel max-h-32 min-h-[56px] w-full resize-none overflow-y-auto text-xs\"\n          placeholder=\"Score description\"\n          value={score.description}\n          disabled={executionLocked}\n          onChange={(event) => updateScore({ description: event.target.value })}\n        />\n        <div className=\"space-y-1\">\n          {score.options.map((option, optionIndex) => (\n            <div key={`${data.llmId}-score-${data.scoreIndex}-opt-${optionIndex}`} className=\"grid grid-cols-[74px_1fr_auto] gap-1\">\n              <Input\n                className=\"nodrag h-7 text-xs\"\n                placeholder=\"Value\"\n                value={option.value}\n                disabled={executionLocked}\n                onChange={(event) =>\n                  updateOption(optionIndex, { value: event.target.value })\n                }\n              />\n              <Input\n                className=\"nodrag h-7 text-xs\"\n                placeholder=\"Description\"\n                value={option.description}\n                disabled={executionLocked}\n                onChange={(event) =>\n                  updateOption(optionIndex, {\n                    description: event.target.value,\n                  })\n                }\n              />\n              <Button\n                type=\"button\"\n                size=\"xs\"\n                variant=\"ghost\"\n                className=\"nodrag\"\n                disabled={executionLocked}\n                onClick={() => removeOption(optionIndex)}\n              >\n                x\n              </Button>\n            </div>\n          ))}\n          <Button\n            type=\"button\"\n            size=\"xs\"\n            variant=\"outline\"\n            className=\"nodrag mt-1\"\n            disabled={executionLocked}\n            onClick={addOption}\n          >\n            Add option\n          </Button>\n       
 </div>\n      </BaseNodeContent>\n      {sourceHandles}\n    </BaseNode>\n  );\n}\n\nexport const RecipeGraphAuxNode = memo(AuxNodeBase);\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/recipe-graph-node.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { MarkdownPreview } from \"@/components/markdown/markdown-preview\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  BalanceScaleIcon,\n  Clock01Icon,\n  CodeIcon,\n  CodeSimpleIcon,\n  DiceFaces03Icon,\n  EqualSignIcon,\n  FingerPrintIcon,\n  FunctionIcon,\n  Parabola02Icon,\n  PencilEdit02Icon,\n  Plant01Icon,\n  Plug01Icon,\n  Shield02Icon,\n  Tag01Icon,\n  TagsIcon,\n  UserAccountIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  type NodeProps,\n  NodeResizer,\n  Position,\n  useUpdateNodeInternals,\n} from \"@xyflow/react\";\nimport { type ReactElement, memo, useEffect } from \"react\";\nimport {\n  MAX_NODE_WIDTH,\n  MAX_NOTE_NODE_WIDTH,\n  MIN_NODE_WIDTH,\n} from \"../constants\";\nimport { useNodeConnectionStatus } from \"../hooks/use-node-connection-status\";\nimport { useRecipeStudioStore } from \"../stores/recipe-studio\";\nimport type {\n  LlmType,\n  NodeConfig,\n  RecipeNode as RecipeGraphNodeType,\n  SamplerType,\n} from \"../types\";\nimport { NODE_HANDLE_CLASS } from \"../utils/handle-layout\";\nimport { HANDLE_IDS } from \"../utils/handles\";\nimport {\n  RECIPE_STUDIO_NODE_TONES,\n  RECIPE_STUDIO_USER_NODE_TONE,\n} from \"../utils/ui-tones\";\nimport { InlineCategoryBadges } from \"./inline/inline-category-badges\";\nimport { InlineExpression } from \"./inline/inline-expression\";\nimport { InlineLlm } from \"./inline/inline-llm\";\nimport { InlineModel } from \"./inline/inline-model\";\nimport { isInlineConfig } from \"./inline/inline-policy\";\nimport { InlineSampler } from \"./inline/inline-sampler\";\nimport { InlineSeed } from \"./inline/inline-seed\";\nimport {\n  BaseNode,\n  BaseNodeContent,\n  BaseNodeHeader,\n  BaseNodeHeaderTitle,\n} from \"./rf-ui/base-node\";\nimport { LabeledHandle } from \"./rf-ui/labeled-handle\";\n\ntype IconType = typeof CodeIcon;\n\nfunction hexToRgb(hex: string): { r: number; g: number; b: number } | null {\n  const normalized = hex.trim().replace(\"#\", \"\");\n  if (!/^[0-9a-fA-F]{6}$/.test(normalized)) {\n    return null;\n  }\n  const int = Number.parseInt(normalized, 16);\n  return {\n    r: (int >> 16) & 255,\n    g: (int >> 8) & 255,\n    b: int & 255,\n  };\n}\n\nfunction parseNoteOpacity(value: string | undefined): number {\n  const parsed = Number.parseInt(value ?? 
\"\", 10);\n  if (!Number.isFinite(parsed)) {\n    return 0.35;\n  }\n  return Math.max(0.05, Math.min(1, parsed / 100));\n}\n\nconst NODE_META = {\n  sampler: {\n    tone: RECIPE_STUDIO_NODE_TONES.sampler,\n  },\n  llm: {\n    tone: RECIPE_STUDIO_NODE_TONES.llm,\n  },\n  validator: {\n    tone: RECIPE_STUDIO_NODE_TONES.validator,\n  },\n  expression: {\n    tone: RECIPE_STUDIO_NODE_TONES.expression,\n  },\n  note: {\n    tone: RECIPE_STUDIO_NODE_TONES.note,\n  },\n  seed: {\n    tone: RECIPE_STUDIO_NODE_TONES.seed,\n  },\n  model_provider: {\n    tone: RECIPE_STUDIO_NODE_TONES.model_provider,\n  },\n  model_config: {\n    tone: RECIPE_STUDIO_NODE_TONES.model_config,\n  },\n  tool_config: {\n    tone: RECIPE_STUDIO_NODE_TONES.tool_config,\n  },\n} as const;\nconst SAMPLER_ICONS: Record<SamplerType, IconType> = {\n  category: Tag01Icon,\n  subcategory: TagsIcon,\n  uniform: EqualSignIcon,\n  gaussian: Parabola02Icon,\n  bernoulli: EqualSignIcon,\n  datetime: Clock01Icon,\n  timedelta: Clock01Icon,\n  uuid: FingerPrintIcon,\n  person: UserAccountIcon,\n  person_from_faker: UserAccountIcon,\n};\n\nconst LLM_ICONS: Record<LlmType, IconType> = {\n  text: PencilEdit02Icon,\n  structured: CodeIcon,\n  code: CodeSimpleIcon,\n  judge: BalanceScaleIcon,\n};\n\nfunction resolveNodeIcon(\n  kind: RecipeGraphNodeType[\"data\"][\"kind\"],\n  blockType: RecipeGraphNodeType[\"data\"][\"blockType\"],\n): IconType {\n  if (kind === \"sampler\" && blockType in SAMPLER_ICONS) {\n    return SAMPLER_ICONS[blockType as SamplerType];\n  }\n  if (kind === \"llm\" && blockType in LLM_ICONS) {\n    return LLM_ICONS[blockType as LlmType];\n  }\n  if (kind === \"validator\") {\n    return Shield02Icon;\n  }\n  if (kind === \"expression\") {\n    return FunctionIcon;\n  }\n  if (kind === \"note\") {\n    return PencilEdit02Icon;\n  }\n  if (kind === \"model_provider\") {\n    return Shield02Icon;\n  }\n  if (kind === \"model_config\") {\n    return Plant01Icon;\n  }\n  if (kind === \"tool_config\") {\n    return Plug01Icon;\n  }\n  if (kind === \"seed\") {\n    return Plant01Icon;\n  }\n  return DiceFaces03Icon;\n}\n\nfunction getConfigSummary(config: NodeConfig | undefined): string {\n  if (!config) {\n    return \"Open settings\";\n  }\n\n  if (config.kind === \"sampler\") {\n    if (config.sampler_type === \"category\") {\n      const count = config.values?.length ?? 
0;\n      return `${count} options`;\n    }\n    if (config.sampler_type === \"subcategory\") {\n      if (config.subcategory_parent?.trim()) {\n        return `Based on ${config.subcategory_parent}`;\n      }\n      return \"Choose the main field\";\n    }\n    if (config.sampler_type === \"datetime\") {\n      const start = config.datetime_start?.trim() || \"?\";\n      const end = config.datetime_end?.trim() || \"?\";\n      return `${start} -> ${end}`;\n    }\n    if (config.sampler_type === \"timedelta\") {\n      if (config.reference_column_name?.trim()) {\n        return `From ${config.reference_column_name}`;\n      }\n      return \"Choose a date field\";\n    }\n    if (\n      config.sampler_type === \"person\" ||\n      config.sampler_type === \"person_from_faker\"\n    ) {\n      const locale = config.person_locale?.trim() || \"any locale\";\n      const city = config.person_city?.trim();\n      if (city) {\n        return `${locale} · ${city}`;\n      }\n      return locale;\n    }\n    return \"Open settings\";\n  }\n\n  if (config.kind === \"llm\") {\n    if (config.llm_type === \"structured\") {\n      return \"Set the response format in settings\";\n    }\n    if (config.llm_type === \"judge\") {\n      const scoreCount = config.scores?.length ?? 0;\n      return `${scoreCount} criteria`;\n    }\n    if (config.tool_alias?.trim()) {\n      return `Tools: ${config.tool_alias.trim()}`;\n    }\n    return \"Add your prompt in settings\";\n  }\n\n  if (config.kind === \"tool_config\") {\n    const providerCount = config.mcp_providers.length;\n    const allowCount =\n      config.allow_tools?.filter((value) => value.trim()).length ?? 0;\n    const providerLabel =\n      providerCount === 1 ? \"1 server\" : `${providerCount} servers`;\n    if (allowCount === 0) {\n      return `${providerLabel} · all tools allowed`;\n    }\n    return `${providerLabel} · ${allowCount} selected tools`;\n  }\n\n  if (config.kind === \"validator\") {\n    const target = config.target_columns[0]?.trim();\n    if (target) {\n      return `Checks ${target}`;\n    }\n    return \"Choose code to check\";\n  }\n\n  if (config.kind === \"seed\") {\n    const seedSourceType = config.seed_source_type ?? 
\"hf\";\n    if (seedSourceType === \"hf\" && config.hf_repo_id.trim()) {\n      return config.hf_repo_id.trim();\n    }\n    if (seedSourceType === \"local\" && config.local_file_name?.trim()) {\n      return config.local_file_name.trim();\n    }\n    if (\n      seedSourceType === \"unstructured\" &&\n      config.unstructured_file_name?.trim()\n    ) {\n      return config.unstructured_file_name.trim();\n    }\n    if (config.hf_path.trim()) {\n      return config.hf_path.trim();\n    }\n    if (seedSourceType === \"hf\") {\n      return \"Choose a dataset\";\n    }\n    if (seedSourceType === \"local\") {\n      return \"Upload a table file\";\n    }\n    return \"Upload a document\";\n  }\n\n  if (config.kind === \"markdown_note\") {\n    if (config.markdown.trim()) {\n      return \"Note preview\";\n    }\n    return \"Add note text\";\n  }\n\n  return \"Open settings\";\n}\n\nfunction renderNodeBody(\n  config: NodeConfig | undefined,\n  summary: string,\n  updateConfig: (id: string, patch: Partial<NodeConfig>) => void,\n): ReactElement {\n  if (config?.kind === \"markdown_note\") {\n    return <MarkdownPreview markdown={config.markdown} />;\n  }\n\n  if (config && isInlineConfig(config)) {\n    const onUpdate = (patch: Partial<NodeConfig>) =>\n      updateConfig(config.id, patch);\n\n    if (config.kind === \"sampler\") {\n      return <InlineSampler config={config} onUpdate={onUpdate} />;\n    }\n    if (config.kind === \"model_provider\" || config.kind === \"model_config\") {\n      return <InlineModel config={config} onUpdate={onUpdate} />;\n    }\n    if (config.kind === \"llm\") {\n      return <InlineLlm config={config} onUpdate={onUpdate} />;\n    }\n    if (config.kind === \"expression\") {\n      return <InlineExpression config={config} onUpdate={onUpdate} />;\n    }\n    if (config.kind === \"seed\") {\n      return <InlineSeed config={config} onUpdate={onUpdate} />;\n    }\n  }\n\n  if (config?.kind === \"sampler\" && config.sampler_type === \"category\") {\n    return <InlineCategoryBadges values={config.values ?? []} />;\n  }\n\n  if (config?.kind === \"tool_config\") {\n    const providerNames = config.mcp_providers\n      .map((provider) => provider.name.trim())\n      .filter(Boolean);\n    return (\n      <div className=\"space-y-2\">\n        <p className=\"text-xs text-muted-foreground\">{summary}</p>\n        {providerNames.length > 0 && (\n          <div className=\"flex flex-wrap gap-1.5\">\n            {providerNames.map((providerName) => (\n              <Badge\n                key={providerName}\n                variant=\"secondary\"\n                className=\"corner-squircle font-mono text-[11px]\"\n              >\n                {providerName}\n              </Badge>\n            ))}\n          </div>\n        )}\n      </div>\n    );\n  }\n\n  return <p className=\"text-xs text-muted-foreground\">{summary}</p>;\n}\n\nfunction RecipeGraphNodeBase({\n  id,\n  data,\n  selected,\n}: NodeProps<RecipeGraphNodeType>): ReactElement {\n  const meta = NODE_META[data.kind];\n  const icon = resolveNodeIcon(data.kind, data.blockType);\n  const layoutDirection = data.layoutDirection ?? \"LR\";\n  const config = useRecipeStudioStore((state) => state.configs[id]);\n  const openConfig = useRecipeStudioStore((state) => state.openConfig);\n  const updateConfig = useRecipeStudioStore((state) => state.updateConfig);\n  const llmAuxVisible = useRecipeStudioStore(\n    (state) => state.llmAuxVisibility[id] ?? 
false,\n  );\n  const setLlmAuxVisibility = useRecipeStudioStore(\n    (state) => state.setLlmAuxVisibility,\n  );\n  const updateNodeInternals = useUpdateNodeInternals();\n  const executionLocked = Boolean(data.executionLocked);\n  const runtimeState = data.runtimeState ?? \"idle\";\n  const connectionStatus = useNodeConnectionStatus(id, config);\n\n  useEffect(() => {\n    updateNodeInternals(id);\n  }, [id, layoutDirection, config, updateNodeInternals]);\n\n  if (config?.kind === \"markdown_note\") {\n    const rgb = hexToRgb(config.note_color ?? \"#FDE68A\");\n    const alpha = parseNoteOpacity(config.note_opacity);\n    const noteStyle = rgb\n      ? {\n          backgroundColor: `rgba(${rgb.r}, ${rgb.g}, ${rgb.b}, ${alpha})`,\n          borderColor: `rgba(${rgb.r}, ${rgb.g}, ${rgb.b}, ${Math.min(1, Math.max(alpha + 0.15, 0.3))})`,\n        }\n      : undefined;\n\n    return (\n      <BaseNode\n        className=\"corner-squircle relative w-full min-w-0 overflow-visible rounded-lg border-border/60 shadow-sm\"\n        style={noteStyle}\n      >\n        <NodeResizer\n          isVisible={selected}\n          minWidth={MIN_NODE_WIDTH}\n          minHeight={80}\n          maxWidth={MAX_NOTE_NODE_WIDTH}\n          maxHeight={520}\n          color=\"var(--primary)\"\n          lineClassName=\"!border-transparent !shadow-none\"\n          lineStyle={{ opacity: 0 }}\n          handleClassName=\"!h-3 !w-3 !border-transparent !bg-transparent\"\n          handleStyle={{ opacity: 0 }}\n        />\n        <BaseNodeContent className=\"px-3 py-2\">\n          <MarkdownPreview markdown={config.markdown} plain={true} />\n        </BaseNodeContent>\n      </BaseNode>\n    );\n  }\n\n  const showDataHandles =\n    data.kind === \"llm\" ||\n    data.kind === \"validator\" ||\n    data.kind === \"expression\" ||\n    data.kind === \"sampler\" ||\n    data.kind === \"seed\";\n  const showSemanticIn =\n    data.kind === \"model_config\" || data.kind === \"validator\";\n  const showSemanticOut =\n    data.kind === \"model_config\" ||\n    data.kind === \"model_provider\" ||\n    data.kind === \"tool_config\" ||\n    data.kind === \"validator\";\n  const summary = getConfigSummary(config);\n  const nodeBody = renderNodeBody(config, summary, updateConfig);\n  const canShowLlmAux =\n    config?.kind === \"llm\" &&\n    (Boolean(config.prompt.trim()) ||\n      Boolean(config.system_prompt.trim()) ||\n      Boolean((config.scores?.length ?? 0) > 0));\n  const iconTone =\n    config?.kind === \"sampler\" &&\n    (config.sampler_type === \"person\" ||\n      config.sampler_type === \"person_from_faker\")\n      ? RECIPE_STUDIO_USER_NODE_TONE\n      : meta.tone;\n  const runtimeNodeTone =\n    runtimeState === \"running\"\n      ? \"border-primary/70 ring-2 ring-primary/20 shadow-md\"\n      : runtimeState === \"done\"\n        ? 
\"border-emerald-500/60 ring-1 ring-emerald-500/20\"\n        : \"\";\n  const hasConnectionIssue =\n    connectionStatus.isDisconnected ||\n    connectionStatus.missingDataInput;\n\n  return (\n    <BaseNode\n      className={cn(\n        \"corner-squircle relative w-full min-w-0 overflow-visible rounded-lg border-border/60 shadow-sm\",\n        runtimeNodeTone,\n        hasConnectionIssue &&\n          runtimeState === \"idle\" &&\n          \"opacity-80 border-dashed border-amber-400/70\",\n      )}\n    >\n      {runtimeState === \"running\" && config?.kind === \"llm\" && (\n        <div className=\"pointer-events-none absolute -top-7 right-2 z-20\">\n          <span\n            className=\"block size-6 animate-spin rounded-full border-[3px] border-primary/90 border-t-transparent bg-background\"\n            aria-label=\"Running\"\n          />\n        </div>\n      )}\n      <NodeResizer\n        isVisible={selected}\n        minWidth={MIN_NODE_WIDTH}\n        minHeight={120}\n        maxWidth={MAX_NODE_WIDTH}\n        maxHeight={520}\n        color=\"var(--primary)\"\n        lineClassName=\"!border-transparent !shadow-none\"\n        lineStyle={{ opacity: 0 }}\n        handleClassName=\"!h-3 !w-3 !border-transparent !bg-transparent\"\n        handleStyle={{ opacity: 0 }}\n      />\n      <BaseNodeHeader className=\"border-b border-border/50 px-3 py-2\">\n        <div className=\"flex min-w-0 items-center gap-2\">\n          <div\n            className={cn(\n              \"corner-squircle flex size-7 items-center justify-center rounded-md border\",\n              iconTone,\n            )}\n          >\n            <HugeiconsIcon icon={icon} className=\"size-3.5\" />\n          </div>\n          <div className=\"min-w-0\">\n            <BaseNodeHeaderTitle className=\"truncate text-sm\">\n              {data.name}\n            </BaseNodeHeaderTitle>\n            <p className=\"truncate text-[11px] text-muted-foreground\">\n              {data.subtype} · {data.title}\n            </p>\n          </div>\n        </div>\n        <div className=\"flex items-center gap-1\">\n          {canShowLlmAux && (\n            <Button\n              type=\"button\"\n              size=\"xs\"\n              variant=\"ghost\"\n              className=\"nodrag\"\n              disabled={executionLocked}\n              onClick={(event) => {\n                event.preventDefault();\n                event.stopPropagation();\n                setLlmAuxVisibility(id, !llmAuxVisible);\n              }}\n            >\n              {llmAuxVisible ? 
\"Hide inputs\" : \"Show inputs\"}\n            </Button>\n          )}\n          <Button\n            type=\"button\"\n            size=\"xs\"\n            variant=\"ghost\"\n            className=\"nodrag\"\n            disabled={executionLocked}\n            onClick={(event) => {\n              event.preventDefault();\n              event.stopPropagation();\n              openConfig(id);\n            }}\n          >\n            Configure\n          </Button>\n        </div>\n      </BaseNodeHeader>\n\n      <BaseNodeContent\n        className={cn(\n          \"gap-2 px-3 py-2\",\n          executionLocked && \"pointer-events-none opacity-85\",\n        )}\n      >\n        {nodeBody}\n      </BaseNodeContent>\n\n      {showDataHandles && (\n        <>\n          <LabeledHandle\n            id={HANDLE_IDS.dataIn}\n            title=\"Data input\"\n            type=\"target\"\n            position={Position.Left}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataOutLeft}\n            title=\"Data output\"\n            type=\"source\"\n            position={Position.Left}\n            className=\"absolute inset-0 pointer-events-none opacity-0\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataInTop}\n            title=\"Data input\"\n            type=\"target\"\n            position={Position.Top}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataOutTop}\n            title=\"Data output\"\n            type=\"source\"\n            position={Position.Top}\n            className=\"absolute inset-0 pointer-events-none opacity-0\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataOut}\n            title=\"Data output\"\n            type=\"source\"\n            position={Position.Right}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataInRight}\n            title=\"Data input\"\n            type=\"target\"\n            position={Position.Right}\n            className=\"absolute inset-0 pointer-events-none opacity-0\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataOutBottom}\n            title=\"Data output\"\n            type=\"source\"\n            position={Position.Bottom}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.dataInBottom}\n            title=\"Data input\"\n            type=\"target\"\n            position={Position.Bottom}\n            className=\"absolute inset-0 pointer-events-none opacity-0\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n        </>\n      )}\n\n      {showSemanticIn && (\n        <>\n   
       <LabeledHandle\n            id={HANDLE_IDS.semanticIn}\n            title=\"Semantic input\"\n            type=\"target\"\n            position={Position.Left}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.semanticInTop}\n            title=\"Semantic input\"\n            type=\"target\"\n            position={Position.Top}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n        </>\n      )}\n\n      {showSemanticOut && (\n        <>\n          <LabeledHandle\n            id={HANDLE_IDS.semanticOut}\n            title=\"Semantic output\"\n            type=\"source\"\n            position={Position.Right}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n          <LabeledHandle\n            id={HANDLE_IDS.semanticOutBottom}\n            title=\"Semantic output\"\n            type=\"source\"\n            position={Position.Bottom}\n            className=\"absolute inset-0 pointer-events-none\"\n            labelClassName=\"sr-only\"\n            handleClassName={NODE_HANDLE_CLASS}\n          />\n        </>\n      )}\n    </BaseNode>\n  );\n}\n\nexport const RecipeNode = memo(RecipeGraphNodeBase);\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/recipe-graph-semantic-edge.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { BaseEdge, type EdgeProps, getSmoothStepPath } from \"@xyflow/react\";\nimport { memo, type ReactElement } from \"react\";\n\nexport const RecipeGraphSemanticEdge = memo(function RecipeGraphSemanticEdge({\n  id,\n  sourceX,\n  sourceY,\n  targetX,\n  targetY,\n  sourcePosition,\n  targetPosition,\n  style,\n  markerEnd,\n  selected,\n  data,\n}: EdgeProps): ReactElement {\n  const isActive = Boolean((data as { active?: boolean } | undefined)?.active);\n  const [path] = getSmoothStepPath({\n    sourceX,\n    sourceY,\n    sourcePosition,\n    targetX,\n    targetY,\n    targetPosition,\n    borderRadius: 0,\n    offset: 16,\n  });\n\n  return (\n    <BaseEdge\n      id={id}\n      path={path}\n      markerEnd={markerEnd}\n      style={{\n        strokeDasharray: isActive ? \"8 6\" : selected ? \"7 5\" : \"6 5\",\n        strokeWidth: isActive ? 2.4 : selected ? 2.3 : 1.8,\n        stroke: isActive || selected ? \"var(--primary)\" : \"var(--muted-foreground)\",\n        opacity: isActive ? 1 : selected ? 0.95 : 0.62,\n        ...style,\n      }}\n    />\n  );\n});\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/recipe-studio-header.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Popover,\n  PopoverContent,\n  PopoverTrigger,\n} from \"@/components/ui/popover\";\nimport { Tabs, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport {\n  Alert02Icon,\n  AlertDiamondIcon,\n  CookBookIcon,\n  FloppyDiskIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type KeyboardEvent, type ReactElement, useState } from \"react\";\nimport type { RecipeStudioView } from \"../execution-types\";\nimport type { GraphWarning } from \"../utils/graph-warnings\";\nimport {\n  RECIPE_STUDIO_WARNING_BADGE_TONE,\n  RECIPE_STUDIO_WARNING_ICON_TONE,\n} from \"../utils/ui-tones\";\n\ntype StatusTone = \"success\" | \"error\";\n\ntype RecipeStudioHeaderProps = {\n  activeView: RecipeStudioView;\n  saveLoading: boolean;\n  saveTone: StatusTone;\n  savedAtLabel: string;\n  workflowName: string;\n  warnings?: GraphWarning[];\n  onWorkflowNameChange: (value: string) => void;\n  onViewChange: (view: RecipeStudioView) => void;\n  onSaveRecipe: () => void;\n};\n\nconst STATUS_MESSAGE_CLASS: Record<StatusTone, string> = {\n  success: \"Saved\",\n  error: \"Needs saving\",\n};\n\nexport function RecipeStudioHeader({\n  activeView,\n  saveLoading,\n  saveTone,\n  savedAtLabel,\n  workflowName,\n  warnings = [],\n  onWorkflowNameChange,\n  onViewChange,\n  onSaveRecipe,\n}: RecipeStudioHeaderProps): ReactElement {\n  const [editingWorkflowName, setEditingWorkflowName] = useState(false);\n\n  function handleViewValueChange(value: string): void {\n    if (value === \"editor\" || value === \"executions\") {\n      onViewChange(value);\n    }\n  }\n\n  function closeWorkflowNameEditor(): void {\n    if (workflowName.trim().length === 0) {\n      onWorkflowNameChange(\"Untitled recipe\");\n    }\n    setEditingWorkflowName(false);\n  }\n\n  function handleWorkflowNameKeyDown(\n    event: KeyboardEvent<HTMLInputElement>,\n  ): void {\n    if (event.key === \"Enter\") {\n      closeWorkflowNameEditor();\n      return;\n    }\n    if (event.key === \"Escape\") {\n      setEditingWorkflowName(false);\n    }\n  }\n\n  return (\n    <div className=\"grid grid-cols-[minmax(0,1fr)_auto_minmax(0,1fr)] items-center gap-4 border-b px-4 py-3\">\n      <div className=\"flex min-w-0 items-center gap-3\">\n        <div\n          className=\"flex size-8 shrink-0 items-center justify-center rounded-lg corner-squircle border border-border/70 bg-muted/20\"\n          aria-hidden={true}\n        >\n          <HugeiconsIcon\n            icon={CookBookIcon}\n            className=\"size-4 text-muted-foreground\"\n          />\n        </div>\n        <div className=\"flex min-w-0 items-center gap-2\">\n          {editingWorkflowName ? 
(\n            <Input\n              value={workflowName}\n              onChange={(event) => onWorkflowNameChange(event.target.value)}\n              onBlur={closeWorkflowNameEditor}\n              onKeyDown={handleWorkflowNameKeyDown}\n              autoFocus={true}\n              className=\"h-7 w-full max-w-[min(22rem,50vw)]\"\n              aria-label=\"Recipe name\"\n            />\n          ) : (\n            <button\n              type=\"button\"\n              onClick={() => setEditingWorkflowName(true)}\n              className=\"max-w-[min(22rem,50vw)] truncate text-sm font-semibold text-foreground hover:text-primary\"\n              title={workflowName}\n              aria-label={`Edit recipe name: ${workflowName}`}\n            >\n              {workflowName}\n            </button>\n          )}\n          <Badge variant=\"secondary\" className=\"h-6 shrink-0 text-[10px]\">\n            {STATUS_MESSAGE_CLASS[saveTone]}\n          </Badge>\n          <span\n            className=\"hidden max-w-[12rem] truncate text-xs text-muted-foreground sm:inline\"\n            title={savedAtLabel}\n          >\n            {savedAtLabel}\n          </span>\n        </div>\n      </div>\n      <div className=\"justify-self-center\">\n        <Tabs value={activeView} onValueChange={handleViewValueChange}>\n          <TabsList>\n            <TabsTrigger value=\"editor\">Editor</TabsTrigger>\n            <TabsTrigger value=\"executions\">Runs</TabsTrigger>\n          </TabsList>\n        </Tabs>\n      </div>\n      <div className=\"flex items-center justify-self-end gap-2\">\n        {warnings.length > 0 && (\n          <Popover>\n            <PopoverTrigger asChild={true}>\n              <button\n                type=\"button\"\n                className={`inline-flex h-6 shrink-0 items-center gap-1 rounded-md border px-2 text-[10px] font-medium ${RECIPE_STUDIO_WARNING_BADGE_TONE}`}\n              >\n                <HugeiconsIcon icon={Alert02Icon} className=\"size-3\" />\n                {warnings.length}\n              </button>\n            </PopoverTrigger>\n            <PopoverContent align=\"end\" className=\"w-80 p-0\">\n              <div className=\"border-b px-3 py-2\">\n                <p className=\"text-xs font-semibold text-foreground\">\n                  Graph warnings ({warnings.length})\n                </p>\n              </div>\n              <ul className=\"max-h-60 overflow-y-auto py-1\">\n                {warnings.map((w) => (\n                  <li\n                    key={`${w.nodeId ?? \"global\"}-${w.message}`}\n                    className=\"flex items-start gap-2 px-3 py-1.5\"\n                  >\n                    <HugeiconsIcon\n                      icon={\n                        w.severity === \"error\" ? AlertDiamondIcon : Alert02Icon\n                      }\n                      className={`mt-0.5 size-3 shrink-0 ${w.severity === \"error\" ? 
\"text-destructive\" : RECIPE_STUDIO_WARNING_ICON_TONE}`}\n                    />\n                    <span className=\"text-xs text-muted-foreground\">\n                      {(w.nodeName || w.nodeId) && (\n                        <span className=\"font-medium text-foreground\">\n                          {w.nodeName || w.nodeId}:{\" \"}\n                        </span>\n                      )}\n                      {w.message}\n                    </span>\n                  </li>\n                ))}\n              </ul>\n            </PopoverContent>\n          </Popover>\n        )}\n        <Button\n          type=\"button\"\n          size=\"sm\"\n          variant=\"outline\"\n          onClick={onSaveRecipe}\n          disabled={saveLoading}\n        >\n          <HugeiconsIcon icon={FloppyDiskIcon} className=\"size-3.5\" />\n          {saveLoading ? \"Saving...\" : \"Save\"}\n        </Button>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/rf-ui/base-handle.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ComponentProps, ReactElement } from \"react\";\nimport { Handle, type HandleProps } from \"@xyflow/react\";\n\nimport { cn } from \"@/lib/utils\";\n\nexport type BaseHandleProps = HandleProps;\n\nexport function BaseHandle({\n  className,\n  children,\n  ...props\n}: ComponentProps<typeof Handle>): ReactElement {\n  return (\n    <Handle\n      {...props}\n      className={cn(\n        \"h-[12px] w-[12px] rounded-full border border-border/80 bg-muted shadow-[0_0_0_1px_hsl(var(--background))] transition-all hover:scale-110 hover:border-primary/70 hover:bg-primary/20\",\n        className,\n      )}\n    >\n      {children}\n    </Handle>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/rf-ui/base-node.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ComponentProps, ReactElement } from \"react\";\n\nimport { cn } from \"@/lib/utils\";\n\nexport function BaseNode({\n  className,\n  ...props\n}: ComponentProps<\"div\">): ReactElement {\n  return (\n    <div\n      className={cn(\n        \"bg-card text-card-foreground relative rounded-md border transition-[border-color,box-shadow] duration-150\",\n        \"hover:border-primary/40 hover:ring-1 hover:ring-primary/20 hover:shadow-sm\",\n        \"[.react-flow\\\\_\\\\_node.selected_&]:border-primary/45\",\n        \"[.react-flow\\\\_\\\\_node.selected_&]:ring-1 [.react-flow\\\\_\\\\_node.selected_&]:ring-primary/25\",\n        \"[.react-flow\\\\_\\\\_node.selected_&]:shadow-md\",\n        className,\n      )}\n      tabIndex={0}\n      {...props}\n    />\n  );\n}\n\nexport function BaseNodeHeader({\n  className,\n  ...props\n}: ComponentProps<\"header\">): ReactElement {\n  return (\n    <header\n      {...props}\n      className={cn(\n        \"mx-0 my-0 -mb-1 flex flex-row items-center justify-between gap-2 px-3 py-2\",\n        className,\n      )}\n    />\n  );\n}\n\nexport function BaseNodeHeaderTitle({\n  className,\n  ...props\n}: ComponentProps<\"h3\">): ReactElement {\n  return (\n    <h3\n      data-slot=\"base-node-title\"\n      className={cn(\"user-select-none flex-1 font-semibold\", className)}\n      {...props}\n    />\n  );\n}\n\nexport function BaseNodeContent({\n  className,\n  ...props\n}: ComponentProps<\"div\">): ReactElement {\n  return (\n    <div\n      data-slot=\"base-node-content\"\n      className={cn(\"flex flex-col gap-y-2 p-3\", className)}\n      {...props}\n    />\n  );\n}\n\nexport function BaseNodeFooter({\n  className,\n  ...props\n}: ComponentProps<\"div\">): ReactElement {\n  return (\n    <div\n      data-slot=\"base-node-footer\"\n      className={cn(\n        \"flex flex-col items-center gap-y-2 border-t px-3 pt-2 pb-3\",\n        className,\n      )}\n      {...props}\n    />\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/rf-ui/data-edge.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport {\n  BaseEdge,\n  getBezierPath,\n  getSmoothStepPath,\n  getStraightPath,\n  Position,\n  type Edge,\n  type EdgeProps,\n} from \"@xyflow/react\";\n\nexport type DataEdge = Edge<{\n  path?: \"auto\" | \"bezier\" | \"smoothstep\" | \"step\" | \"straight\";\n  active?: boolean;\n}>;\n\nexport function DataEdge({\n  data = { path: \"auto\" },\n  id,\n  markerEnd,\n  selected,\n  sourcePosition,\n  sourceX,\n  sourceY,\n  style,\n  targetPosition,\n  targetX,\n  targetY,\n}: EdgeProps<DataEdge>): ReactElement {\n  const resolvedPathType = resolvePathType({\n    type: data.path ?? \"auto\",\n  });\n  const isActive = Boolean(data.active);\n  const [edgePath] = getPath({\n    type: resolvedPathType,\n    sourceX,\n    sourceY,\n    sourcePosition,\n    targetX,\n    targetY,\n    targetPosition,\n  });\n\n  const edgeStyle = {\n    stroke: isActive || selected ? \"var(--primary)\" : \"var(--muted-foreground)\",\n    strokeWidth: isActive ? 2.6 : selected ? 2.6 : 2.1,\n    opacity: isActive ? 1 : selected ? 0.96 : 0.7,\n    strokeDasharray: isActive ? \"8 6\" : undefined,\n    ...style,\n  };\n\n  return (\n    <BaseEdge\n      id={id}\n      path={edgePath}\n      markerEnd={markerEnd}\n      style={edgeStyle}\n    />\n  );\n}\n\nfunction getPath({\n  type,\n  sourceX,\n  sourceY,\n  targetX,\n  targetY,\n  sourcePosition,\n  targetPosition,\n}: {\n  type: \"bezier\" | \"smoothstep\" | \"step\" | \"straight\";\n  sourceX: number;\n  sourceY: number;\n  targetX: number;\n  targetY: number;\n  sourcePosition: Position;\n  targetPosition: Position;\n}): [string, number, number, ...number[]] {\n  if (type === \"bezier\") {\n    return getBezierPath({\n      sourceX,\n      sourceY,\n      targetX,\n      targetY,\n      sourcePosition,\n      targetPosition,\n    });\n  }\n  if (type === \"smoothstep\") {\n    return getSmoothStepPath({\n      sourceX,\n      sourceY,\n      targetX,\n      targetY,\n      sourcePosition,\n      targetPosition,\n    });\n  }\n  if (type === \"step\") {\n    return getSmoothStepPath({\n      sourceX,\n      sourceY,\n      targetX,\n      targetY,\n      sourcePosition,\n      targetPosition,\n      borderRadius: 0,\n    });\n  }\n  return getStraightPath({\n    sourceX,\n    sourceY,\n    targetX,\n    targetY,\n  });\n}\n\nfunction resolvePathType({\n  type,\n}: {\n  type: \"auto\" | \"bezier\" | \"smoothstep\" | \"step\" | \"straight\";\n}): \"bezier\" | \"smoothstep\" | \"step\" | \"straight\" {\n  if (type !== \"auto\") {\n    return type;\n  }\n  return \"smoothstep\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/rf-ui/labeled-handle.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ComponentProps, type ReactElement } from \"react\";\nimport { type HandleProps } from \"@xyflow/react\";\n\nimport { cn } from \"@/lib/utils\";\nimport { BaseHandle } from \"./base-handle\";\n\nconst flexDirections = {\n  top: \"flex-col\",\n  right: \"flex-row-reverse justify-end\",\n  bottom: \"flex-col-reverse justify-end\",\n  left: \"flex-row\",\n};\n\nexport function LabeledHandle({\n  className,\n  labelClassName,\n  handleClassName,\n  title,\n  position,\n  ...props\n}: HandleProps &\n  ComponentProps<\"div\"> & {\n    title: string;\n    handleClassName?: string;\n    labelClassName?: string;\n  }): ReactElement {\n  const { ref, ...handleProps } = props;\n\n  return (\n    <div\n      title={title}\n      className={cn(\n        \"relative flex items-center\",\n        flexDirections[position],\n        className,\n      )}\n      ref={ref}\n    >\n      <BaseHandle\n        position={position}\n        className={handleClassName}\n        {...handleProps}\n      />\n      <label className={cn(\"text-foreground px-3\", labelClassName)}>\n        {title}\n      </label>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/runtime/execution-progress-island.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  ArrowDown01Icon,\n  ArrowUp01Icon,\n  CheckmarkCircle02Icon,\n  Flag02Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\nimport { Button } from \"@/components/ui/button\";\nimport { Progress } from \"@/components/ui/progress\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { cn } from \"@/lib/utils\";\nimport type { RecipeExecutionRecord } from \"../../execution-types\";\nimport { isExecutionInProgress } from \"../../executions/execution-helpers\";\nimport {\n  formatMetricValue,\n  formatPercent,\n} from \"../executions/executions-view-helpers\";\n\ntype ExecutionProgressIslandProps = {\n  execution: RecipeExecutionRecord;\n  currentColumnIcon: typeof Flag02Icon;\n  minimized: boolean;\n  onMinimizedChange: (value: boolean) => void;\n  onViewExecutions: () => void;\n};\n\nfunction formatEta(value: number | null | undefined): string {\n  const metric = formatMetricValue(value);\n  if (metric === \"--\") {\n    return \"--\";\n  }\n  return `${metric}s`;\n}\n\nfunction statusLabel(input: {\n  complete: boolean;\n  inProgress: boolean;\n}): string {\n  if (input.complete) {\n    return \"Run completed\";\n  }\n  if (input.inProgress) {\n    return \"Run in progress\";\n  }\n  return \"Run status\";\n}\n\nexport function ExecutionProgressIsland({\n  execution,\n  currentColumnIcon,\n  minimized,\n  onMinimizedChange,\n  onViewExecutions,\n}: ExecutionProgressIslandProps): ReactElement {\n  const complete = execution.status === \"completed\";\n  const inProgress = isExecutionInProgress(execution.status);\n  const progressPercent = execution.progress?.percent ?? (complete ? 100 : 0);\n  const hasProgressSignal = Boolean(\n    execution.progress &&\n      (typeof execution.progress.done === \"number\" ||\n        typeof execution.progress.total === \"number\" ||\n        typeof execution.progress.percent === \"number\" ||\n        typeof execution.progress.rate === \"number\" ||\n        typeof execution.progress.eta_sec === \"number\"),\n  );\n  const showLoadingSpinner = inProgress && !hasProgressSignal;\n  const batchTotal = execution.batch?.total ?? null;\n  const showBatch = typeof batchTotal === \"number\" && batchTotal > 1;\n\n  return (\n    <div\n      className={cn(\n        \"w-[clamp(15rem,26vw,20rem)] max-w-[calc(100vw-1rem)] rounded-b-xl border-x border-b bg-card/96 shadow-sm backdrop-blur-sm transition-all\",\n        minimized ? \"min-h-[3rem]\" : \"min-h-[8.5rem]\",\n      )}\n      aria-live=\"polite\"\n    >\n      <div className=\"flex items-center justify-between gap-2 px-3 py-2\">\n        <div className=\"flex min-w-0 items-center gap-2\">\n          <HugeiconsIcon\n            icon={complete ? CheckmarkCircle02Icon : Flag02Icon}\n            className={cn(\n              \"size-3.5\",\n              complete\n                ? 
\"text-emerald-700 dark:text-emerald-300\"\n                : \"text-amber-700 dark:text-amber-300\",\n            )}\n          />\n          <p className=\"truncate text-xs font-medium text-foreground\">\n            {statusLabel({ complete, inProgress })}\n          </p>\n        </div>\n        <div className=\"flex items-center gap-2\">\n          {showLoadingSpinner && (\n            <Spinner className=\"size-3.5 text-muted-foreground\" />\n          )}\n          <span className=\"shrink-0 text-[11px] text-muted-foreground\">\n            {formatPercent(progressPercent)}\n          </span>\n          <button\n            type=\"button\"\n            onClick={() => onMinimizedChange(!minimized)}\n            className=\"inline-flex size-8 shrink-0 items-center justify-center rounded border border-border/70 text-muted-foreground transition hover:bg-muted/50\"\n            aria-label={minimized ? \"Expand progress\" : \"Minimize progress\"}\n            title={minimized ? \"Expand\" : \"Minimize\"}\n          >\n            <HugeiconsIcon\n              icon={minimized ? ArrowDown01Icon : ArrowUp01Icon}\n              className=\"size-3.5\"\n            />\n          </button>\n        </div>\n      </div>\n\n      <div className=\"px-3\">\n        <Progress value={progressPercent} className=\"h-1\" />\n      </div>\n\n      {!minimized && (\n        <>\n          <div className=\"grid grid-cols-2 gap-2 px-3 pt-2 text-[11px] text-muted-foreground sm:grid-cols-4\">\n            <p className=\"truncate\" title={`Done: ${formatMetricValue(execution.progress?.done)}`}>\n              Done: {formatMetricValue(execution.progress?.done)}\n            </p>\n            <p className=\"truncate\" title={`Total: ${formatMetricValue(execution.progress?.total)}`}>\n              Total: {formatMetricValue(execution.progress?.total)}\n            </p>\n            <p className=\"truncate\" title={`Rate: ${formatMetricValue(execution.progress?.rate)}`}>\n              Rate: {formatMetricValue(execution.progress?.rate)}\n            </p>\n            <p className=\"truncate\" title={`ETA: ${formatEta(execution.progress?.eta_sec)}`}>\n              ETA: {formatEta(execution.progress?.eta_sec)}\n            </p>\n          </div>\n          <div className=\"mt-1 flex items-center gap-1.5 px-3 text-[11px] text-muted-foreground\">\n            <HugeiconsIcon\n              icon={currentColumnIcon}\n              className=\"size-3.5 shrink-0\"\n            />\n            <p\n              className=\"truncate\"\n              title={execution.current_column ?? \"--\"}\n            >\n              Column: {execution.current_column ?? \"--\"}\n            </p>\n          </div>\n          {showBatch && (\n            <div\n              className=\"mt-1 truncate px-3 text-[11px] text-muted-foreground\"\n              title={`Batch: ${execution.batch?.idx ?? \"--\"}/${execution.batch?.total ?? \"--\"}`}\n            >\n              Batch: {execution.batch?.idx ?? \"--\"}/{execution.batch?.total ?? \"--\"}\n            </div>\n          )}\n          <div className=\"px-3 pb-2 pt-2\">\n            <Button\n              type=\"button\"\n              variant=\"outline\"\n              size=\"sm\"\n              className=\"h-7 w-full text-[11px]\"\n              onClick={onViewExecutions}\n            >\n              View run details\n            </Button>\n          </div>\n        </>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/shared/available-references-inline.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { type ReactElement, useLayoutEffect, useRef, useState } from \"react\";\nimport type { AvailableVariableEntry } from \"../../utils/variables\";\n\ntype AvailableReferencesInlineProps = {\n  entries: AvailableVariableEntry[];\n};\n\nconst MAX_ROWS = 2;\n\nexport function AvailableReferencesInline({\n  entries,\n}: AvailableReferencesInlineProps): ReactElement | null {\n  const [expanded, setExpanded] = useState(false);\n  const [collapsedCount, setCollapsedCount] = useState(entries.length);\n  const wrapperRef = useRef<HTMLDivElement | null>(null);\n  const measureRefs = useRef<Array<HTMLSpanElement | null>>([]);\n\n  useLayoutEffect(() => {\n    if (expanded) {\n      return;\n    }\n    const wrapper = wrapperRef.current;\n    const items = measureRefs.current.filter(\n      (node): node is HTMLSpanElement => Boolean(node),\n    );\n    if (!(wrapper && items.length > 0)) {\n      setCollapsedCount(entries.length);\n      return;\n    }\n\n    const compute = () => {\n      const rowTops: number[] = [];\n      let cutoff = items.length;\n      for (let i = 0; i < items.length; i += 1) {\n        const top = items[i].offsetTop;\n        if (!rowTops.some((value) => Math.abs(value - top) <= 1)) {\n          rowTops.push(top);\n        }\n        if (rowTops.length > MAX_ROWS) {\n          cutoff = i;\n          break;\n        }\n      }\n      if (cutoff < items.length) {\n        cutoff = Math.max(0, cutoff - 1);\n      }\n      setCollapsedCount(cutoff);\n    };\n\n    compute();\n    const observer = new ResizeObserver(compute);\n    observer.observe(wrapper);\n    return () => observer.disconnect();\n  }, [entries.length, expanded]);\n\n  if (entries.length === 0) {\n    return null;\n  }\n\n  const shown = expanded ? entries : entries.slice(0, collapsedCount);\n  const hiddenCount = Math.max(0, entries.length - shown.length);\n\n  return (\n    <div className=\"space-y-1\">\n      <p className=\"text-[10px] font-medium text-muted-foreground\">\n        Available references\n      </p>\n      <div ref={wrapperRef} className=\"relative\">\n        {!expanded && (\n          <div className=\"invisible pointer-events-none absolute inset-0 -z-10\">\n            <div className=\"flex flex-wrap gap-1\">\n              {entries.map((entry, index) => (\n                <Badge\n                  // biome-ignore lint/suspicious/noArrayIndexKey: static measurement mirror\n                  key={`${entry.source}:${entry.name}:${index}`}\n                  ref={(node) => {\n                    measureRefs.current[index] = node;\n                  }}\n                  variant=\"secondary\"\n                  className={\n                    entry.source === \"seed\"\n                      ? 
\"corner-squircle h-4 border-blue-500/25 bg-blue-500/10 px-1.5 font-mono text-[10px] text-blue-700 dark:text-blue-300\"\n                      : \"corner-squircle h-4 px-1.5 font-mono text-[10px]\"\n                  }\n                >\n                  {entry.name}\n                </Badge>\n              ))}\n            </div>\n          </div>\n        )}\n        <div className=\"flex flex-wrap gap-1\">\n          {shown.map((entry) => (\n            <Badge\n              key={`${entry.source}:${entry.name}`}\n              variant=\"secondary\"\n              className={\n                entry.source === \"seed\"\n                  ? \"corner-squircle h-4 border-blue-500/25 bg-blue-500/10 px-1.5 font-mono text-[10px] text-blue-700 dark:text-blue-300\"\n                  : \"corner-squircle h-4 px-1.5 font-mono text-[10px]\"\n              }\n            >\n              {entry.name}\n            </Badge>\n          ))}\n          {!expanded && hiddenCount > 0 && (\n            <button\n              type=\"button\"\n              className=\"corner-squircle h-4 px-1.5 text-[10px] text-muted-foreground hover:text-foreground\"\n              onClick={() => setExpanded(true)}\n            >\n              +{hiddenCount} more\n            </button>\n          )}\n          {expanded && collapsedCount < entries.length && (\n            <button\n              type=\"button\"\n              className=\"corner-squircle h-4 px-1.5 text-[10px] text-muted-foreground hover:text-foreground\"\n              onClick={() => setExpanded(false)}\n            >\n              Show less\n            </button>\n          )}\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/components/shared/hf-dataset-combobox.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { useDebouncedValue, useHfDatasetSearch } from \"@/hooks\";\nimport { type ReactElement, useEffect, useMemo, useRef, useState } from \"react\";\n\ntype HfDatasetComboboxProps = {\n  value: string;\n  onValueChange: (value: string) => void;\n  accessToken?: string;\n  inputId?: string;\n  placeholder?: string;\n  className?: string;\n};\n\nexport function HfDatasetCombobox({\n  value,\n  onValueChange,\n  accessToken,\n  inputId,\n  placeholder = \"Search datasets...\",\n  className,\n}: HfDatasetComboboxProps): ReactElement {\n  const [inputValue, setInputValue] = useState(value);\n  const selectingRef = useRef(false);\n  const anchorRef = useRef<HTMLDivElement>(null);\n  const debouncedQuery = useDebouncedValue(inputValue);\n\n  useEffect(() => {\n    setInputValue(value);\n  }, [value]);\n\n  const { results, isLoading, error } = useHfDatasetSearch(debouncedQuery, {\n    accessToken,\n  });\n\n  const items = useMemo(() => {\n    const ids = results.map((item) => item.id);\n    const selected = value.trim();\n    if (selected && !ids.includes(selected)) {\n      ids.push(selected);\n    }\n    return ids;\n  }, [results, value]);\n\n  return (\n    <div\n      ref={anchorRef}\n      className={className}\n      onKeyDown={(event) => {\n        if (event.key !== \"Enter\") return;\n        if (!(event.target instanceof HTMLInputElement)) return;\n        event.preventDefault();\n        if (items.length > 0) {\n          onValueChange(items[0]);\n          return;\n        }\n        const typed = event.target.value.trim();\n        if (typed) {\n          onValueChange(typed);\n        }\n      }}\n    >\n      <Combobox\n        items={items}\n        filteredItems={items}\n        filter={null}\n        value={value.trim() ? value : null}\n        onValueChange={(next) => onValueChange(next ?? \"\")}\n        onInputValueChange={(next) => {\n          if (selectingRef.current) {\n            selectingRef.current = false;\n            return;\n          }\n          setInputValue(next);\n        }}\n        itemToStringValue={(item) => item}\n        autoHighlight={true}\n      >\n        <ComboboxInput\n          id={inputId}\n          className=\"nodrag w-full\"\n          placeholder={placeholder}\n        />\n        <ComboboxContent anchor={anchorRef}>\n          {isLoading ? (\n            <div className=\"flex items-center gap-2 px-2 py-3 text-xs text-muted-foreground\">\n              <Spinner className=\"size-3.5\" />\n              Searching...\n            </div>\n          ) : (\n            <ComboboxEmpty>No datasets found</ComboboxEmpty>\n          )}\n          <ComboboxList>\n            {(id: string) => (\n              <ComboboxItem\n                key={id}\n                value={id}\n                onPointerDown={() => {\n                  selectingRef.current = true;\n                }}\n              >\n                {id}\n              </ComboboxItem>\n            )}\n          </ComboboxList>\n        </ComboboxContent>\n      </Combobox>\n      {error && (\n        <p className=\"mt-1 text-xs text-destructive\">\n          {error}\n        </p>\n      )}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/constants.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const DEFAULT_NODE_WIDTH = 400;\nexport const DEFAULT_NODE_HEIGHT = 120;\nexport const MIN_NODE_WIDTH = 260;\nexport const MAX_NODE_WIDTH = 900;\nexport const MAX_NOTE_NODE_WIDTH = 600;\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/data/executions-db.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport Dexie, { type EntityTable } from \"dexie\";\nimport type { RecipeExecutionRecord } from \"../execution-types\";\n\nconst db = new Dexie(\"unsloth-data-recipe-executions\") as Dexie & {\n  executions: EntityTable<RecipeExecutionRecord, \"id\">;\n};\n\ndb.version(1).stores({\n  executions: \"id, recipeId, kind, status, createdAt\",\n});\n\ndb.version(2).stores({\n  executions: \"id, recipeId, kind, status, createdAt, finishedAt, jobId\",\n});\n\nexport async function listRecipeExecutions(\n  recipeId: string,\n): Promise<RecipeExecutionRecord[]> {\n  const executions = await db.executions.where(\"recipeId\").equals(recipeId).toArray();\n  return executions.sort((a, b) => b.createdAt - a.createdAt);\n}\n\nexport async function saveRecipeExecution(\n  execution: RecipeExecutionRecord,\n): Promise<void> {\n  await db.executions.put(execution);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/config-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Dialog, DialogContent, DialogFooter } from \"@/components/ui/dialog\";\nimport { Switch } from \"@/components/ui/switch\";\nimport type { ReactElement } from \"react\";\nimport { getBlockDefinitionForConfig } from \"../blocks/definitions\";\nimport { renderBlockDialog } from \"../blocks/registry\";\nimport type { NodeConfig, SamplerConfig } from \"../types\";\nimport { DialogShell } from \"./shared/dialog-shell\";\nimport { ValidationBanner } from \"./shared/validation-banner\";\n\ntype ConfigDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  config: NodeConfig | null;\n  categoryOptions: SamplerConfig[];\n  modelConfigAliases: string[];\n  modelProviderOptions: string[];\n  toolProfileAliases: string[];\n  datetimeOptions: string[];\n  onUpdate: (id: string, patch: Partial<NodeConfig>) => void;\n  container?: HTMLDivElement | null;\n  readOnly?: boolean;\n};\n\nexport function ConfigDialog({\n  open,\n  onOpenChange,\n  config,\n  categoryOptions,\n  modelConfigAliases,\n  modelProviderOptions,\n  toolProfileAliases,\n  datetimeOptions,\n  onUpdate,\n  container,\n  readOnly = false,\n}: ConfigDialogProps): ReactElement {\n  const blockDefinition = getBlockDefinitionForConfig(config);\n  const showDropToggle =\n    config?.kind === \"sampler\" ||\n    config?.kind === \"llm\" ||\n    config?.kind === \"validator\" ||\n    config?.kind === \"expression\" ||\n    (config?.kind === \"seed\" &&\n      (config.seed_source_type ?? \"hf\") === \"unstructured\");\n\n  return (\n    <Dialog open={open} onOpenChange={onOpenChange}>\n      <DialogContent\n        container={container}\n        position=\"absolute\"\n        overlayPosition=\"absolute\"\n        overlayClassName=\"bg-transparent\"\n        className=\"corner-squircle max-h-[650px] overflow-y-auto overflow-x-hidden sm:max-w-2xl shadow-border\"\n      >\n        <DialogShell\n          title={blockDefinition ? blockDefinition.title : undefined}\n          description={\n            blockDefinition\n              ? blockDefinition.description\n              : \"Choose a step to edit.\"\n          }\n        />\n        {!config && (\n          <div className=\"text-sm text-muted-foreground\">\n            Select a step to edit.\n          </div>\n        )}\n        {config && (\n          <div className=\"min-w-0 space-y-4\">\n            {readOnly && (\n              <div className=\"rounded-lg border border-amber-500/30 bg-amber-500/10 px-3 py-2 text-xs text-amber-700 dark:text-amber-300\">\n                This recipe is locked while a run is in progress.\n              </div>\n            )}\n            <ValidationBanner config={config} />\n            <div\n              className={readOnly ? 
\"pointer-events-none min-w-0 opacity-75\" : \"min-w-0\"}\n            >\n              {showDropToggle && (\n                <div className=\"mb-2 flex items-center corner-squircle justify-between gap-3 rounded-2xl border border-border/60 px-3 pt-2 pb-4\">\n                  <div className=\"min-w-0\">\n                    <p className=\"text-sm font-semibold\">Keep out of final dataset</p>\n                    <p className=\"break-words text-xs text-muted-foreground\">\n                      Use this step while generating, but leave it out of exported rows.\n                    </p>\n                  </div>\n                  <Switch\n                    checked={config.drop ?? false}\n                    disabled={readOnly}\n                    onCheckedChange={(value) => onUpdate(config.id, { drop: value })}\n                  />\n                </div>\n              )}\n              {renderBlockDialog(\n                config,\n                open,\n                categoryOptions,\n                modelConfigAliases,\n                modelProviderOptions,\n                toolProfileAliases,\n                datetimeOptions,\n                onUpdate,\n              )}\n            </div>\n          </div>\n        )}\n        <DialogFooter>\n          <Button\n            type=\"button\"\n            variant=\"outline\"\n            onClick={() => onOpenChange(false)}\n          >\n            Done\n          </Button>\n        </DialogFooter>\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/expression/expression-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport type { ReactElement } from \"react\";\nimport { useMemo } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport type { ExpressionConfig, ExpressionDtype } from \"../../types\";\nimport { findInvalidJinjaReferences } from \"../../utils/refs\";\nimport { getAvailableVariables } from \"../../utils/variables\";\nimport { AvailableVariables } from \"../shared/available-variables\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\nconst DTYPE_OPTIONS: ExpressionDtype[] = [\"str\", \"int\", \"float\", \"bool\"];\n\ntype ExpressionDialogProps = {\n  config: ExpressionConfig;\n  onUpdate: (patch: Partial<ExpressionConfig>) => void;\n};\n\nexport function ExpressionDialog({\n  config,\n  onUpdate,\n}: ExpressionDialogProps): ReactElement {\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const dtypeId = `${config.id}-dtype`;\n  const exprId = `${config.id}-expr`;\n  const validReferences = useMemo(\n    () => getAvailableVariables(configs, config.id),\n    [configs, config.id],\n  );\n  const invalidExprRefs = useMemo(\n    () => findInvalidJinjaReferences(config.expr, validReferences),\n    [config.expr, validReferences],\n  );\n  const invalidExprText = invalidExprRefs\n    .slice(0, 3)\n    .map((ref) => `{{ ${ref} }}`)\n    .join(\", \");\n  const updateField = <K extends keyof ExpressionConfig>(\n    key: K,\n    value: ExpressionConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<ExpressionConfig>);\n  };\n  return (\n    <div className=\"space-y-4\">\n      <AvailableVariables configId={config.id} />\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Output type\"\n          htmlFor={dtypeId}\n          hint=\"Choose how this formula should be stored in the final dataset.\"\n        />\n        <Select\n          value={config.dtype}\n          onValueChange={(value) =>\n            updateField(\"dtype\", value as ExpressionDtype)\n          }\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={dtypeId}>\n            <SelectValue placeholder=\"Select type\" />\n          </SelectTrigger>\n          <SelectContent>\n            {DTYPE_OPTIONS.map((dtype) => (\n              <SelectItem key={dtype} value={dtype}>\n                {dtype}\n              </SelectItem>\n            ))}\n          </SelectContent>\n        </Select>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Formula\"\n          htmlFor={exprId}\n          hint=\"Build this field from other fields.\"\n        />\n        <Textarea\n          id={exprId}\n          className=\"corner-squircle nodrag\"\n          aria-invalid={invalidExprRefs.length > 0}\n          placeholder=\"{{ category_1 }} - {{ subcategory_1 }}\"\n          value={config.expr}\n          onChange={(event) => updateField(\"expr\", event.target.value)}\n        />\n        {invalidExprRefs.length > 0 && (\n          <p className=\"text-xs text-destructive\">\n            Unknown 
field: {invalidExprText}\n            {invalidExprRefs.length > 3\n              ? ` +${invalidExprRefs.length - 3} more`\n              : \"\"}\n          </p>\n        )}\n        <p className=\"text-xs text-muted-foreground\">\n          Insert other fields like {\"{{ field_name }}\"}.\n        </p>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/import-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogFooter,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { type ReactElement, useState } from \"react\";\nimport { FieldLabel } from \"./shared/field-label\";\n\ntype ImportDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  onImport: (value: string) => string | null;\n  container?: HTMLDivElement | null;\n};\n\nexport function ImportDialog({\n  open,\n  onOpenChange,\n  onImport,\n  container,\n}: ImportDialogProps): ReactElement {\n  const [value, setValue] = useState(\"\");\n  const [error, setError] = useState<string | null>(null);\n  const payloadId = \"recipe-import-payload\";\n  const handleOpenChange = (nextOpen: boolean) => {\n    if (!nextOpen) {\n      setValue(\"\");\n      setError(null);\n    }\n    onOpenChange(nextOpen);\n  };\n\n  const handleImport = () => {\n    const message = onImport(value);\n    if (message) {\n      setError(message);\n      return;\n    }\n    handleOpenChange(false);\n  };\n\n  return (\n    <Dialog open={open} onOpenChange={handleOpenChange}>\n      <DialogContent\n        container={container}\n        position=\"absolute\"\n        overlayPosition=\"absolute\"\n        overlayClassName=\"bg-transparent\"\n        className=\"corner-squircle max-h-[650px] overflow-auto sm:max-w-2xl shadow-border\"\n      >\n        <DialogHeader>\n          <DialogTitle>Import recipe</DialogTitle>\n        </DialogHeader>\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Recipe JSON\"\n            htmlFor={payloadId}\n            hint=\"Paste JSON exported from Recipe Studio.\"\n          />\n          <Textarea\n            id={payloadId}\n            className=\"corner-squircle nodrag min-h-[220px] max-h-[450px]\"\n            placeholder='{\"recipe\": { \"columns\": [] }}'\n            value={value}\n            onChange={(event) => setValue(event.target.value)}\n          />\n          {error && (\n            <p className=\"text-xs text-rose-600\" role=\"alert\">\n              {error}\n            </p>\n          )}\n        </div>\n        <DialogFooter>\n          <Button type=\"button\" variant=\"outline\" onClick={handleImport}>\n            Import recipe\n          </Button>\n        </DialogFooter>\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/llm/general-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { ArrowRight01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, type RefObject, useMemo, useRef } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport type { LlmConfig } from \"../../types\";\nimport { isLikelyImageValue } from \"../../utils/image-preview\";\nimport { findInvalidJinjaReferences } from \"../../utils/refs\";\nimport { getAvailableVariables } from \"../../utils/variables\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { AvailableVariables } from \"../shared/available-variables\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\nconst CODE_LANG_OPTIONS = [\n  \"python\",\n  \"javascript\",\n  \"typescript\",\n  \"java\",\n  \"kotlin\",\n  \"go\",\n  \"rust\",\n  \"ruby\",\n  \"scala\",\n  \"swift\",\n  \"sql:sqlite\",\n  \"sql:postgres\",\n  \"sql:mysql\",\n  \"sql:tsql\",\n  \"sql:bigquery\",\n  \"sql:ansi\",\n];\n\nconst TRACE_MODE_OPTIONS = [\"none\", \"last_message\", \"all_messages\"] as const;\n\nfunction normalizeTraceMode(value: string): LlmConfig[\"with_trace\"] {\n  if (value === \"last_message\" || value === \"all_messages\") {\n    return value;\n  }\n  return \"none\";\n}\n\ntype LlmGeneralTabProps = {\n  config: LlmConfig;\n  modelConfigAliases: string[];\n  modelProviderOptions: string[];\n  toolProfileAliases: string[];\n  modelAliasAnchorRef: RefObject<HTMLDivElement | null>;\n  onUpdate: (patch: Partial<LlmConfig>) => void;\n};\n\nexport function LlmGeneralTab({\n  config,\n  modelConfigAliases,\n  modelProviderOptions,\n  toolProfileAliases,\n  modelAliasAnchorRef,\n  onUpdate,\n}: LlmGeneralTabProps): ReactElement {\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const modelAliasId = `${config.id}-model-alias`;\n  const toolAliasId = `${config.id}-tool-alias`;\n  const codeLangId = `${config.id}-code-lang`;\n  const promptId = `${config.id}-prompt`;\n  const outputFormatId = `${config.id}-output-format`;\n  const systemPromptId = `${config.id}-system-prompt`;\n  const hasModelConfigs = modelConfigAliases.length > 0;\n  const hasModelProviders = modelProviderOptions.length > 0;\n  const hasToolProfiles = toolProfileAliases.length > 0;\n  const validReferences = useMemo(\n    () => getAvailableVariables(configs, config.id),\n    [configs, config.id],\n  );\n  const invalidPromptRefs = useMemo(\n    () => findInvalidJinjaReferences(config.prompt, validReferences),\n    [config.prompt, validReferences],\n  );\n  const invalidSystemRefs = useMemo(\n    () => findInvalidJinjaReferences(config.system_prompt, validReferences),\n    [config.system_prompt, validReferences],\n  );\n  const invalidPromptText = invalidPromptRefs\n    .slice(0, 3)\n    .map((ref) => 
`{{ ${ref} }}`)\n    .join(\", \");\n  const invalidSystemText = invalidSystemRefs\n    .slice(0, 3)\n    .map((ref) => `{{ ${ref} }}`)\n    .join(\", \");\n  const seedConfig = useMemo(\n    () => Object.values(configs).find((item) => item.kind === \"seed\"),\n    [configs],\n  );\n  const hasHfSeed = Boolean(\n    seedConfig && (seedConfig.seed_source_type ?? \"hf\") === \"hf\",\n  );\n  const seedColumns = useMemo(\n    () => seedConfig?.seed_columns ?? [],\n    [seedConfig],\n  );\n  const seedPreviewRows = useMemo(\n    () => seedConfig?.seed_preview_rows ?? [],\n    [seedConfig],\n  );\n  const imageColumnOptions = useMemo(() => {\n    if (seedColumns.length === 0) {\n      return [];\n    }\n    const detected = seedColumns.filter((columnName) => {\n      const lower = columnName.toLowerCase();\n      if (\n        lower.includes(\"image\") ||\n        lower.includes(\"img\") ||\n        lower.includes(\"photo\") ||\n        lower.includes(\"picture\") ||\n        lower.includes(\"base64\") ||\n        lower.includes(\"url\")\n      ) {\n        return true;\n      }\n      return seedPreviewRows.some((row) => isLikelyImageValue(row[columnName]));\n    });\n    return detected.length > 0 ? detected : seedColumns;\n  }, [seedColumns, seedPreviewRows]);\n  const imageContext = config.image_context ?? {\n    enabled: false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_name: \"\",\n  };\n  const imageContextToggleId = `${config.id}-image-context-enabled`;\n  const imageContextColumnId = `${config.id}-image-context-column`;\n  const imageContextColumnOptions = useMemo(() => {\n    const preferred =\n      imageColumnOptions.length > 0 ? imageColumnOptions : seedColumns;\n    const deduped = Array.from(\n      new Set(preferred.map((value) => value.trim()).filter(Boolean)),\n    );\n    const selected = imageContext.column_name.trim();\n    if (selected && !deduped.includes(selected)) {\n      deduped.unshift(selected);\n    }\n    return deduped;\n  }, [imageColumnOptions, imageContext.column_name, seedColumns]);\n  const traceModeId = `${config.id}-trace-mode`;\n  const reasoningToggleId = `${config.id}-reasoning-content`;\n  const advancedOpen = config.advancedOpen === true;\n  const toolAliasAnchorRef = useRef<HTMLDivElement>(null);\n  const needsSetupHelp = !hasModelConfigs || !hasModelProviders;\n  const needsModelChoice = !config.model_alias?.trim();\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      {needsSetupHelp ? 
(\n        <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3 text-xs text-muted-foreground\">\n          <p className=\"text-sm font-semibold text-foreground\">\n            Set up the model once, then come back here\n          </p>\n          <div className=\"mt-2 space-y-1.5\">\n            {!hasModelProviders && (\n              <p className=\"flex items-start gap-2\">\n                <HugeiconsIcon\n                  icon={ArrowRight01Icon}\n                  className=\"mt-0.5 size-3.5 shrink-0 text-primary\"\n                />\n                <span>Add a Provider connection step in AI generation → Setup.</span>\n              </p>\n            )}\n            {!hasModelConfigs && (\n              <p className=\"flex items-start gap-2\">\n                <HugeiconsIcon\n                  icon={ArrowRight01Icon}\n                  className=\"mt-0.5 size-3.5 shrink-0 text-primary\"\n                />\n                <span>Add a Model preset step, connect it, then choose it below.</span>\n              </p>\n            )}\n          </div>\n        </div>\n      ) : needsModelChoice ? (\n        <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3 text-xs text-muted-foreground\">\n          <p className=\"text-sm font-semibold text-foreground\">\n            Start by choosing a model preset\n          </p>\n          <p className=\"mt-1\">\n            Once that is in place, write the prompt and add optional tool access\n            if this step needs tools.\n          </p>\n        </div>\n      ) : null}\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Model preset\"\n          htmlFor={modelAliasId}\n          hint=\"Choose the reusable model setup for this step.\"\n        />\n        <div ref={modelAliasAnchorRef}>\n          <Combobox\n            items={modelConfigAliases}\n            filteredItems={modelConfigAliases}\n            filter={null}\n            value={config.model_alias || null}\n            onValueChange={(value) => onUpdate({ model_alias: value ?? \"\" })}\n            itemToStringValue={(value) => value}\n            autoHighlight={true}\n          >\n            <ComboboxInput\n              id={modelAliasId}\n              className=\"nodrag w-full\"\n              placeholder=\"Choose a model preset\"\n              onBlur={(event) => {\n                const inputValue = event.target.value;\n                if (inputValue !== config.model_alias) {\n                  onUpdate({ model_alias: inputValue });\n                }\n              }}\n            />\n            <ComboboxContent anchor={modelAliasAnchorRef}>\n              <ComboboxEmpty>No model configs found</ComboboxEmpty>\n              <ComboboxList>\n                {(alias: string) => (\n                  <ComboboxItem key={alias} value={alias}>\n                    {alias}\n                  </ComboboxItem>\n                )}\n              </ComboboxList>\n            </ComboboxContent>\n          </Combobox>\n        </div>\n      </div>\n      {!hasToolProfiles && (\n        <p className=\"text-xs text-muted-foreground\">\n          Need tools for this step? Add a Tool access step in AI generation →\n          Setup.\n        </p>\n      )}\n      {(hasToolProfiles || Boolean(config.tool_alias?.trim())) && (\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Tool access (optional)\"\n            htmlFor={toolAliasId}\n            hint=\"Choose saved tool access for this step. 
Leave empty if this step should not use tools.\"\n          />\n          <div ref={toolAliasAnchorRef}>\n            <Combobox\n              items={toolProfileAliases}\n              filteredItems={toolProfileAliases}\n              filter={null}\n              value={config.tool_alias || null}\n              onValueChange={(value) => onUpdate({ tool_alias: value ?? \"\" })}\n              itemToStringValue={(value) => value}\n              autoHighlight={true}\n            >\n              <ComboboxInput\n                id={toolAliasId}\n                className=\"nodrag w-full\"\n                placeholder=\"Choose tool access\"\n                onBlur={(event) => {\n                  const inputValue = event.target.value;\n                  if (inputValue !== (config.tool_alias ?? \"\")) {\n                    onUpdate({ tool_alias: inputValue });\n                  }\n                }}\n              />\n              <ComboboxContent anchor={toolAliasAnchorRef}>\n                <ComboboxEmpty>No tool access found</ComboboxEmpty>\n                <ComboboxList>\n                  {(alias: string) => (\n                    <ComboboxItem key={alias} value={alias}>\n                      {alias}\n                    </ComboboxItem>\n                  )}\n                </ComboboxList>\n              </ComboboxContent>\n            </Combobox>\n          </div>\n        </div>\n      )}\n      {config.llm_type === \"code\" && (\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Code language\"\n            htmlFor={codeLangId}\n            hint=\"Choose the language this AI step should generate.\"\n          />\n          <Select\n            value={config.code_lang ?? \"python\"}\n            onValueChange={(value) => onUpdate({ code_lang: value })}\n          >\n            <SelectTrigger className=\"nodrag w-full\" id={codeLangId}>\n              <SelectValue placeholder=\"Select language\" />\n            </SelectTrigger>\n            <SelectContent>\n              {CODE_LANG_OPTIONS.map((lang) => (\n                <SelectItem key={lang} value={lang}>\n                  {lang}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n        </div>\n      )}\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Prompt\"\n          htmlFor={promptId}\n          hint=\"Write the prompt for this step. Insert other fields with {{ field_name }}.\"\n        />\n        <Textarea\n          id={promptId}\n          className=\"corner-squircle nodrag max-h-[450px] overflow-auto\"\n          aria-invalid={invalidPromptRefs.length > 0}\n          value={config.prompt}\n          onChange={(event) => onUpdate({ prompt: event.target.value })}\n        />\n        {invalidPromptRefs.length > 0 && (\n          <p className=\"text-xs text-destructive\">\n            Unknown field: {invalidPromptText}\n            {invalidPromptRefs.length > 3\n              ? 
` +${invalidPromptRefs.length - 3} more`\n              : \"\"}\n          </p>\n        )}\n      </div>\n      <AvailableVariables configId={config.id} />\n      {hasHfSeed && (\n        <div className=\"space-y-2\">\n          <div className=\"flex items-center justify-between gap-3\">\n            <FieldLabel\n              label=\"Use image context\"\n              htmlFor={imageContextToggleId}\n              hint=\"Attach one image field from your source data to this AI step.\"\n            />\n            <Switch\n              id={imageContextToggleId}\n              checked={imageContext.enabled}\n              onCheckedChange={(checked) => {\n                onUpdate({\n                  image_context: {\n                    ...imageContext,\n                    enabled: checked,\n                    // biome-ignore lint/style/useNamingConvention: api schema\n                    column_name:\n                      checked && !imageContext.column_name\n                        ? (imageContextColumnOptions[0] ?? \"\")\n                        : imageContext.column_name,\n                  },\n                });\n              }}\n            />\n          </div>\n          {imageContext.enabled && (\n            <div className=\"grid gap-2\">\n              <FieldLabel\n                label=\"Image field\"\n                htmlFor={imageContextColumnId}\n                hint=\"Choose the source-data field that contains the image.\"\n              />\n              <Select\n                value={imageContext.column_name || undefined}\n                onValueChange={(value) =>\n                  onUpdate({\n                    image_context: {\n                      ...imageContext,\n                      // biome-ignore lint/style/useNamingConvention: api schema\n                      column_name: value,\n                    },\n                  })\n                }\n              >\n                <SelectTrigger\n                  className=\"nodrag w-full\"\n                  id={imageContextColumnId}\n                >\n                  <SelectValue placeholder=\"Select image column\" />\n                </SelectTrigger>\n                <SelectContent>\n                  {imageContextColumnOptions.map((columnName) => (\n                    <SelectItem key={columnName} value={columnName}>\n                      {columnName}\n                    </SelectItem>\n                  ))}\n                </SelectContent>\n              </Select>\n            </div>\n          )}\n        </div>\n      )}\n      {config.llm_type === \"structured\" && (\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Response format\"\n            htmlFor={outputFormatId}\n            hint=\"Describe the JSON shape you want back.\"\n          />\n          <Textarea\n            id={outputFormatId}\n            className=\"corner-squircle nodrag\"\n            value={config.output_format ?? 
\"\"}\n            onChange={(event) =>\n              onUpdate({ output_format: event.target.value })\n            }\n          />\n        </div>\n      )}\n      <Collapsible\n        open={advancedOpen}\n        onOpenChange={(open) => onUpdate({ advancedOpen: open })}\n      >\n        <CollapsibleTrigger asChild={true}>\n          <CollapsibleSectionTriggerButton\n            label=\"Trace and extra controls\"\n            open={advancedOpen}\n          />\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-3 space-y-4\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Instructions (optional)\"\n              htmlFor={systemPromptId}\n              hint=\"Add extra guidance that should apply before the prompt.\"\n            />\n            <Textarea\n              id={systemPromptId}\n              className=\"corner-squircle nodrag max-h-[450px] overflow-auto\"\n              aria-invalid={invalidSystemRefs.length > 0}\n              value={config.system_prompt}\n              onChange={(event) =>\n                onUpdate({ system_prompt: event.target.value })\n              }\n            />\n            {invalidSystemRefs.length > 0 && (\n              <p className=\"text-xs text-destructive\">\n                Unknown field: {invalidSystemText}\n                {invalidSystemRefs.length > 3\n                  ? ` +${invalidSystemRefs.length - 3} more`\n                  : \"\"}\n              </p>\n            )}\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Save trace details\"\n              htmlFor={traceModeId}\n              hint=\"Adds a trace field you can inspect later.\"\n            />\n            <Select\n              value={config.with_trace ?? \"none\"}\n              onValueChange={(value) =>\n                onUpdate({\n                  // biome-ignore lint/style/useNamingConvention: api schema\n                  with_trace: normalizeTraceMode(value),\n                })\n              }\n            >\n              <SelectTrigger className=\"nodrag w-full\" id={traceModeId}>\n                <SelectValue placeholder=\"Select trace mode\" />\n              </SelectTrigger>\n              <SelectContent>\n                {TRACE_MODE_OPTIONS.map((traceMode) => (\n                  <SelectItem key={traceMode} value={traceMode}>\n                    {traceMode}\n                  </SelectItem>\n                ))}\n              </SelectContent>\n            </Select>\n          </div>\n          <div className=\"flex items-center justify-between gap-3\">\n            <FieldLabel\n              label=\"Save reasoning text\"\n              htmlFor={reasoningToggleId}\n              hint=\"Adds a reasoning field when the model returns one.\"\n            />\n            <Switch\n              id={reasoningToggleId}\n              checked={config.extract_reasoning_content === true}\n              onCheckedChange={(checked) =>\n                onUpdate({\n                  // biome-ignore lint/style/useNamingConvention: api schema\n                  extract_reasoning_content: checked,\n                })\n              }\n            />\n          </div>\n        </CollapsibleContent>\n      </Collapsible>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/llm/llm-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ReactElement, useRef } from \"react\";\nimport type { LlmConfig } from \"../../types\";\nimport { LlmGeneralTab } from \"./general-tab\";\nimport { LlmScoresTab } from \"./scores-tab\";\nimport {\n  Tabs,\n  TabsContent,\n  TabsList,\n  TabsTrigger,\n} from \"@/components/ui/tabs\";\n\ntype LlmDialogProps = {\n  config: LlmConfig;\n  modelConfigAliases: string[];\n  modelProviderOptions: string[];\n  toolProfileAliases: string[];\n  onUpdate: (patch: Partial<LlmConfig>) => void;\n};\n\nexport function LlmDialog({\n  config,\n  modelConfigAliases,\n  modelProviderOptions,\n  toolProfileAliases,\n  onUpdate,\n}: LlmDialogProps): ReactElement {\n  const modelAliasAnchorRef = useRef<HTMLDivElement>(null);\n\n  if (config.llm_type !== \"judge\") {\n    return (\n      <LlmGeneralTab\n        config={config}\n        modelConfigAliases={modelConfigAliases}\n        modelProviderOptions={modelProviderOptions}\n        toolProfileAliases={toolProfileAliases}\n        modelAliasAnchorRef={modelAliasAnchorRef}\n        onUpdate={onUpdate}\n      />\n    );\n  }\n\n  return (\n    <Tabs defaultValue=\"general\" className=\"w-full\">\n      <TabsList className=\"w-full\">\n        <TabsTrigger value=\"general\">General</TabsTrigger>\n        {config.llm_type === \"judge\" && <TabsTrigger value=\"scores\">Scores</TabsTrigger>}\n      </TabsList>\n      <TabsContent value=\"general\" className=\"pt-3\">\n        <LlmGeneralTab\n          config={config}\n          modelConfigAliases={modelConfigAliases}\n          modelProviderOptions={modelProviderOptions}\n          toolProfileAliases={toolProfileAliases}\n          modelAliasAnchorRef={modelAliasAnchorRef}\n          onUpdate={onUpdate}\n        />\n      </TabsContent>\n      {config.llm_type === \"judge\" && (\n        <TabsContent value=\"scores\" className=\"pt-3\">\n          <LlmScoresTab config={config} onUpdate={onUpdate} />\n        </TabsContent>\n      )}\n    </Tabs>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/llm/scores-tab.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Empty,\n  EmptyContent,\n  EmptyDescription,\n  EmptyHeader,\n  EmptyTitle,\n} from \"@/components/ui/empty\";\nimport { Input } from \"@/components/ui/input\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { type ReactElement } from \"react\";\nimport type { LlmConfig, Score } from \"../../types\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype LlmScoresTabProps = {\n  config: LlmConfig;\n  onUpdate: (patch: Partial<LlmConfig>) => void;\n};\n\nexport function LlmScoresTab({\n  config,\n  onUpdate,\n}: LlmScoresTabProps): ReactElement {\n  const scores = config.scores ?? [];\n\n  function updateScores(nextScores: Score[]): void {\n    onUpdate({ scores: nextScores });\n  }\n\n  function removeScore(index: number): void {\n    updateScores(scores.filter((_, currentIndex) => currentIndex !== index));\n  }\n\n  function addScore(): void {\n    updateScores([\n      ...scores,\n      {\n        name: \"\",\n        description: \"\",\n        options: [\n          { value: \"1\", description: \"\" },\n          { value: \"5\", description: \"\" },\n        ],\n      },\n    ]);\n  }\n\n  function updateScore(index: number, patch: Partial<Score>): void {\n    updateScores(\n      scores.map((score, currentIndex) =>\n        currentIndex === index ? { ...score, ...patch } : score,\n      ),\n    );\n  }\n\n  function addOption(scoreIndex: number): void {\n    const score = scores[scoreIndex];\n    if (!score) {\n      return;\n    }\n    updateScore(scoreIndex, {\n      options: [...(score.options ?? []), { value: \"\", description: \"\" }],\n    });\n  }\n\n  function removeOption(scoreIndex: number, optionIndex: number): void {\n    const score = scores[scoreIndex];\n    if (!score) {\n      return;\n    }\n    updateScore(scoreIndex, {\n      options: (score.options ?? []).filter(\n        (_option, currentIndex) => currentIndex !== optionIndex,\n      ),\n    });\n  }\n\n  function updateOption(\n    scoreIndex: number,\n    optionIndex: number,\n    patch: { value?: string; description?: string },\n  ): void {\n    const score = scores[scoreIndex];\n    if (!score) {\n      return;\n    }\n    updateScore(scoreIndex, {\n      options: (score.options ?? []).map((option, currentIndex) =>\n        currentIndex === optionIndex ? 
{ ...option, ...patch } : option,\n      ),\n    });\n  }\n\n  return (\n    <div className=\"space-y-3\">\n      <div className=\"flex items-center justify-between\">\n        <FieldLabel\n          label=\"Scorers\"\n          hint=\"Rubrics used by LLM Judge to score each generated row.\"\n        />\n        {scores.length > 0 && (\n          <Button type=\"button\" size=\"xs\" variant=\"outline\" onClick={addScore}>\n            Add scorer\n          </Button>\n        )}\n      </div>\n      {scores.length === 0 && (\n        <Empty className=\"rounded-xl border border-dashed border-border/70 p-5\">\n          <EmptyHeader>\n            <EmptyTitle className=\"text-sm\">No scorers yet</EmptyTitle>\n            <EmptyDescription className=\"text-xs\">\n              Add a scorer rubric before running judge generation.\n            </EmptyDescription>\n          </EmptyHeader>\n          <EmptyContent className=\"max-w-none\">\n            <Button type=\"button\" size=\"sm\" onClick={addScore}>\n              Add first scorer\n            </Button>\n          </EmptyContent>\n        </Empty>\n      )}\n      {scores.map((score, index) => (\n        <div\n          key={`${config.id}-score-${index}`}\n          className=\"space-y-2 rounded-xl corner-squircle border border-border/60 px-3 py-2\"\n        >\n          <div className=\"flex items-center justify-between gap-2\">\n            <p className=\"text-xs font-semibold text-foreground\">\n              {score.name.trim() || `Scorer ${index + 1}`}\n            </p>\n            <Button\n              type=\"button\"\n              size=\"xs\"\n              variant=\"ghost\"\n              onClick={() => removeScore(index)}\n            >\n              Remove\n            </Button>\n          </div>\n          <Input\n            className=\"nodrag h-8 text-xs\"\n            placeholder=\"Score name\"\n            value={score.name}\n            onChange={(event) =>\n              updateScore(index, { name: event.target.value })\n            }\n          />\n          <Textarea\n            className=\"corner-squircle nodrag min-h-[56px] text-xs\"\n            placeholder=\"Score description\"\n            value={score.description}\n            onChange={(event) =>\n              updateScore(index, { description: event.target.value })\n            }\n          />\n          <div className=\"space-y-1\">\n            {(score.options ?? 
[]).map((option, optionIndex) => (\n              <div\n                key={`${config.id}-score-${index}-option-${optionIndex}`}\n                className=\"grid grid-cols-[74px_1fr_auto] gap-1\"\n              >\n                <Input\n                  className=\"nodrag h-7 text-xs\"\n                  placeholder=\"Value\"\n                  value={option.value}\n                  onChange={(event) =>\n                    updateOption(index, optionIndex, {\n                      value: event.target.value,\n                    })\n                  }\n                />\n                <Input\n                  className=\"nodrag h-7 text-xs\"\n                  placeholder=\"Description\"\n                  value={option.description}\n                  onChange={(event) =>\n                    updateOption(index, optionIndex, {\n                      description: event.target.value,\n                    })\n                  }\n                />\n                <Button\n                  type=\"button\"\n                  size=\"xs\"\n                  variant=\"ghost\"\n                  onClick={() => removeOption(index, optionIndex)}\n                >\n                  x\n                </Button>\n              </div>\n            ))}\n            <Button\n              type=\"button\"\n              size=\"xs\"\n              variant=\"outline\"\n              onClick={() => addOption(index)}\n            >\n              Add option\n            </Button>\n          </div>\n        </div>\n      ))}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/markdown-note/markdown-note-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Slider } from \"@/components/ui/slider\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport type { ReactElement } from \"react\";\nimport type { MarkdownNoteConfig } from \"../../types\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\ntype MarkdownNoteDialogProps = {\n  config: MarkdownNoteConfig;\n  onUpdate: (patch: Partial<MarkdownNoteConfig>) => void;\n};\n\nexport function MarkdownNoteDialog({\n  config,\n  onUpdate,\n}: MarkdownNoteDialogProps): ReactElement {\n  const markdownId = `${config.id}-markdown`;\n  const colorId = `${config.id}-note-color`;\n  const opacity =\n    Number.parseInt(config.note_opacity ?? \"35\", 10) > 0\n      ? Math.max(0, Math.min(100, Number.parseInt(config.note_opacity ?? \"35\", 10)))\n      : 35;\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField value={config.name} onChange={(value) => onUpdate({ name: value })} />\n      <div className=\"grid gap-3\">\n        <FieldLabel\n          label=\"Note style\"\n          htmlFor={colorId}\n          hint=\"Pick a color and opacity for this note block.\"\n        />\n        <div className=\"flex items-center gap-3\">\n          <input\n            id={colorId}\n            type=\"color\"\n            className=\"nodrag h-9 w-14 cursor-pointer rounded-md border border-border/60 bg-transparent p-1\"\n            value={config.note_color ?? \"#FDE68A\"}\n            onChange={(event) => onUpdate({ note_color: event.target.value })}\n          />\n          <div className=\"flex-1 space-y-1\">\n            <div className=\"flex items-center justify-between\">\n              <span className=\"text-xs text-muted-foreground\">Opacity</span>\n              <span className=\"text-xs tabular-nums text-muted-foreground\">{opacity}%</span>\n            </div>\n            <Slider\n              min={5}\n              max={100}\n              step={1}\n              value={[opacity]}\n              onValueChange={([value]) =>\n                onUpdate({ note_opacity: String(Math.round(value)) })\n              }\n            />\n          </div>\n        </div>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Markdown\"\n          htmlFor={markdownId}\n          hint=\"UI-only note. Not sent to backend payload recipe.\"\n        />\n        <Textarea\n          id={markdownId}\n          className=\"corner-squircle nodrag min-h-[180px]\"\n          placeholder=\"## Note\"\n          value={config.markdown}\n          onChange={(event) => onUpdate({ markdown: event.target.value })}\n        />\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/models/model-config-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { Checkbox } from \"@/components/ui/checkbox\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport { Input } from \"@/components/ui/input\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { type ReactElement, useRef, useState } from \"react\";\nimport type { ModelConfig } from \"../../types\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\ntype ModelConfigDialogProps = {\n  config: ModelConfig;\n  providerOptions: string[];\n  onUpdate: (patch: Partial<ModelConfig>) => void;\n};\n\nexport function ModelConfigDialog({\n  config,\n  providerOptions,\n  onUpdate,\n}: ModelConfigDialogProps): ReactElement {\n  const [optionalOpen, setOptionalOpen] = useState(false);\n  const modelId = `${config.id}-model`;\n  const providerId = `${config.id}-provider`;\n  const tempId = `${config.id}-temperature`;\n  const topPId = `${config.id}-top-p`;\n  const maxTokensId = `${config.id}-max-tokens`;\n  const timeoutId = `${config.id}-timeout`;\n  const extraBodyId = `${config.id}-inference-extra-body`;\n  const providerAnchorRef = useRef<HTMLDivElement>(null);\n  const providerInputRef = useRef(config.provider);\n  const lastProviderRef = useRef(config.provider);\n  if (lastProviderRef.current !== config.provider) {\n    lastProviderRef.current = config.provider;\n    providerInputRef.current = config.provider;\n  }\n  const updateField = <K extends keyof ModelConfig>(\n    key: K,\n    value: ModelConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<ModelConfig>);\n  };\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        label=\"Model preset name\"\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3\">\n        <p className=\"text-sm font-semibold text-foreground\">\n          Set up one reusable model choice for your AI steps\n        </p>\n        <p className=\"mt-1 text-xs text-muted-foreground\">\n          Choose the provider connection, enter the exact model ID, then save any\n          generation defaults you want to reuse.\n        </p>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Provider connection\"\n          htmlFor={providerId}\n          hint=\"Choose where this model should run.\"\n        />\n        <div ref={providerAnchorRef}>\n          <Combobox\n            items={providerOptions}\n            filteredItems={providerOptions}\n            filter={null}\n            value={config.provider || null}\n            onValueChange={(value) => updateField(\"provider\", value ?? 
\"\")}\n            onInputValueChange={(value) => {\n              providerInputRef.current = value;\n            }}\n            itemToStringValue={(value) => value}\n            autoHighlight={true}\n          >\n            <ComboboxInput\n              id={providerId}\n              className=\"nodrag w-full\"\n              placeholder=\"Choose a provider connection\"\n              onBlur={() => {\n                const next = providerInputRef.current;\n                if (next !== config.provider) {\n                  updateField(\"provider\", next);\n                }\n              }}\n            />\n            <ComboboxContent anchor={providerAnchorRef}>\n              <ComboboxEmpty>No providers found</ComboboxEmpty>\n              <ComboboxList>\n                {(provider: string) => (\n                  <ComboboxItem key={provider} value={provider}>\n                    {provider}\n                  </ComboboxItem>\n                )}\n              </ComboboxList>\n            </ComboboxContent>\n          </Combobox>\n        </div>\n        <p className=\"text-xs text-muted-foreground\">\n          {providerOptions.length === 0\n            ? \"Add a Provider connection step first, then come back here.\"\n            : \"Matching blocks are linked automatically on the canvas.\"}\n        </p>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Model ID\"\n          htmlFor={modelId}\n          hint=\"The exact model name sent to the connection.\"\n        />\n        <Input\n          id={modelId}\n          className=\"nodrag\"\n          placeholder=\"gpt-4o-mini\"\n          value={config.model}\n          onChange={(event) => updateField(\"model\", event.target.value)}\n        />\n      </div>\n      <div className=\"grid gap-3\">\n        <div className=\"space-y-1\">\n          <p className=\"text-sm font-semibold text-foreground\">\n            Default generation settings\n          </p>\n          <p className=\"text-xs text-muted-foreground\">\n            These defaults are reused anywhere you choose this model preset.\n          </p>\n        </div>\n        <div className=\"grid gap-3 sm:grid-cols-2\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Temperature\"\n              htmlFor={tempId}\n              hint=\"Higher values make responses more varied.\"\n            />\n            <Input\n              id={tempId}\n              className=\"nodrag\"\n              value={config.inference_temperature ?? \"\"}\n              onChange={(event) =>\n                updateField(\"inference_temperature\", event.target.value)\n              }\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Top-p\"\n              htmlFor={topPId}\n              hint=\"Use this to limit how broad token selection can be.\"\n            />\n            <Input\n              id={topPId}\n              className=\"nodrag\"\n              value={config.inference_top_p ?? 
\"\"}\n              onChange={(event) =>\n                updateField(\"inference_top_p\", event.target.value)\n              }\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Max tokens\"\n              htmlFor={maxTokensId}\n              hint=\"Maximum length of the model response.\"\n            />\n            <Input\n              id={maxTokensId}\n              className=\"nodrag\"\n              value={config.inference_max_tokens ?? \"\"}\n              onChange={(event) =>\n                updateField(\"inference_max_tokens\", event.target.value)\n              }\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Timeout (seconds)\"\n              htmlFor={timeoutId}\n              hint=\"How long to wait before a request is treated as failed.\"\n            />\n            <Input\n              id={timeoutId}\n              className=\"nodrag\"\n              value={config.inference_timeout ?? \"\"}\n              onChange={(event) =>\n                updateField(\"inference_timeout\", event.target.value)\n              }\n            />\n          </div>\n        </div>\n      </div>\n      <Collapsible open={optionalOpen} onOpenChange={setOptionalOpen}>\n        <CollapsibleTrigger asChild={true}>\n          <CollapsibleSectionTriggerButton\n            label=\"Advanced request fields\"\n            open={optionalOpen}\n          />\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-3 space-y-4\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Advanced request fields (JSON)\"\n              htmlFor={extraBodyId}\n              hint=\"Extra request fields to send with every call.\"\n            />\n            <Textarea\n              id={extraBodyId}\n              className=\"corner-squircle nodrag\"\n              placeholder='{\"top_k\": 20, \"min_p\": 0.0}'\n              value={config.inference_extra_body ?? \"\"}\n              onChange={(event) =>\n                updateField(\"inference_extra_body\", event.target.value)\n              }\n            />\n          </div>\n          <label className=\"flex items-center gap-2 text-xs font-semibold uppercase text-muted-foreground\">\n            <Checkbox\n              checked={config.skip_health_check ?? false}\n              onCheckedChange={(value) =>\n                updateField(\"skip_health_check\", Boolean(value))\n              }\n            />\n            Skip connection check\n          </label>\n        </CollapsibleContent>\n      </Collapsible>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/models/model-provider-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { Input } from \"@/components/ui/input\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { type ReactElement, useState } from \"react\";\nimport type { ModelProviderConfig } from \"../../types\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\ntype ModelProviderDialogProps = {\n  config: ModelProviderConfig;\n  onUpdate: (patch: Partial<ModelProviderConfig>) => void;\n};\n\nexport function ModelProviderDialog({\n  config,\n  onUpdate,\n}: ModelProviderDialogProps): ReactElement {\n  const [optionalOpen, setOptionalOpen] = useState(false);\n  const endpointId = `${config.id}-endpoint`;\n  const apiKeyEnvId = `${config.id}-api-key-env`;\n  const apiKeyId = `${config.id}-api-key`;\n  const extraHeadersId = `${config.id}-extra-headers`;\n  const extraBodyId = `${config.id}-extra-body`;\n  const updateField = <K extends keyof ModelProviderConfig>(\n    key: K,\n    value: ModelProviderConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<ModelProviderConfig>);\n  };\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        label=\"Connection name\"\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3\">\n        <p className=\"text-sm font-semibold text-foreground\">\n          Start with the endpoint you want this model to use\n        </p>\n        <p className=\"mt-1 text-xs text-muted-foreground\">\n          Most connections only need an endpoint. Add an API key if that\n          service requires one.\n        </p>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Endpoint\"\n          htmlFor={endpointId}\n          hint=\"Base URL for the model service or gateway.\"\n        />\n        <Input\n          id={endpointId}\n          className=\"nodrag\"\n          placeholder=\"https://...\"\n          value={config.endpoint}\n          onChange={(event) => updateField(\"endpoint\", event.target.value)}\n        />\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"API key (optional)\"\n          htmlFor={apiKeyId}\n          hint=\"Paste a key here, or use an environment variable below.\"\n        />\n        <Input\n          id={apiKeyId}\n          className=\"nodrag\"\n          value={config.api_key ?? 
\"\"}\n          onChange={(event) => updateField(\"api_key\", event.target.value)}\n        />\n      </div>\n      <Collapsible open={optionalOpen} onOpenChange={setOptionalOpen}>\n        <CollapsibleTrigger asChild={true}>\n          <CollapsibleSectionTriggerButton\n            label=\"Advanced request overrides\"\n            open={optionalOpen}\n          />\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-3 space-y-4\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"API key environment variable\"\n              htmlFor={apiKeyEnvId}\n              hint=\"Name of the environment variable that stores the key.\"\n            />\n            <Input\n              id={apiKeyEnvId}\n              className=\"nodrag\"\n              placeholder=\"OPENAI_API_KEY\"\n              value={config.api_key_env ?? \"\"}\n              onChange={(event) => updateField(\"api_key_env\", event.target.value)}\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Extra headers (JSON)\"\n              htmlFor={extraHeadersId}\n              hint=\"Optional headers to send with every request.\"\n            />\n            <Textarea\n              id={extraHeadersId}\n              className=\"corner-squircle nodrag\"\n              placeholder='{\"X-Header\": \"value\"}'\n              value={config.extra_headers ?? \"\"}\n              onChange={(event) => updateField(\"extra_headers\", event.target.value)}\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Extra body (JSON)\"\n              htmlFor={extraBodyId}\n              hint=\"Optional request fields to send every time.\"\n            />\n            <Textarea\n              id={extraBodyId}\n              className=\"corner-squircle nodrag\"\n              placeholder='{\"key\": \"value\"}'\n              value={config.extra_body ?? \"\"}\n              onChange={(event) => updateField(\"extra_body\", event.target.value)}\n            />\n          </div>\n        </CollapsibleContent>\n      </Collapsible>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/preview-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogFooter,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { Input } from \"@/components/ui/input\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  AlertCircleIcon,\n  ArrowDown01Icon,\n  CheckmarkCircle02Icon,\n  CookBookIcon,\n  TestTube01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, type ReactNode, useState } from \"react\";\nimport type { RecipeExecutionKind } from \"../execution-types\";\nimport type { RecipeRunSettings } from \"../stores/recipe-executions\";\nimport { FieldLabel } from \"./shared/field-label\";\n\ntype RunDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  kind: RecipeExecutionKind;\n  onKindChange: (kind: RecipeExecutionKind) => void;\n  rows: number;\n  fullRunName: string;\n  onFullRunNameChange: (name: string) => void;\n  onRowsChange: (rows: number) => void;\n  settings: RecipeRunSettings;\n  onSettingsChange: (patch: Partial<RecipeRunSettings>) => void;\n  loading: boolean;\n  validateLoading: boolean;\n  validateResult: {\n    valid: boolean;\n    errors: string[];\n    rawDetail: string | null;\n  } | null;\n  errors: string[];\n  onRun: () => void;\n  onValidate: () => void;\n  container?: HTMLDivElement | null;\n};\n\ntype ValidationResult = RunDialogProps[\"validateResult\"];\n\nconst MAX_RECORDS = 200_000;\nconst MAX_WORKERS = 2_048;\nconst MAX_SHUTDOWN_WINDOW = 10_000;\nconst MAX_RETRY_STEPS = 100;\n\nfunction clampInt(value: number, min: number, max: number): number {\n  if (!Number.isFinite(value)) {\n    return min;\n  }\n  const next = Math.floor(value);\n  if (next < min) {\n    return min;\n  }\n  if (next > max) {\n    return max;\n  }\n  return next;\n}\n\nfunction clampFloat(value: number, min: number, max: number): number {\n  if (!Number.isFinite(value)) {\n    return min;\n  }\n  if (value < min) {\n    return min;\n  }\n  if (value > max) {\n    return max;\n  }\n  return value;\n}\n\nfunction commitInt(\n  raw: string,\n  current: number,\n  min: number,\n  max: number,\n  apply: (value: number) => void,\n  setDraft: (value: string) => void,\n): void {\n  const trimmed = raw.trim();\n  if (!trimmed) {\n    setDraft(String(current));\n    return;\n  }\n  const parsed = Number(trimmed);\n  if (!Number.isFinite(parsed)) {\n    setDraft(String(current));\n    return;\n  }\n  const next = clampInt(parsed, min, max);\n  apply(next);\n  setDraft(String(next));\n}\n\nfunction commitFloat(\n  raw: string,\n  current: number,\n  min: number,\n  max: number,\n  apply: (value: number) => void,\n  setDraft: (value: string) => void,\n): void {\n  const trimmed = raw.trim();\n  if (!trimmed) {\n    setDraft(String(current));\n    return;\n  }\n  const parsed = Number(trimmed);\n  if (!Number.isFinite(parsed)) {\n    setDraft(String(current));\n    return;\n  }\n  const next = clampFloat(parsed, min, max);\n  apply(next);\n  setDraft(String(next));\n}\n\ntype DraftInputFieldProps = {\n  id: string;\n  label: string;\n  hint: string;\n  inputMode: 
\"numeric\" | \"decimal\";\n  value: string;\n  onChange: (value: string) => void;\n  onBlur: () => void;\n  placeholder?: string;\n};\n\nfunction DraftInputField({\n  id,\n  label,\n  hint,\n  inputMode,\n  value,\n  onChange,\n  onBlur,\n  placeholder,\n}: DraftInputFieldProps): ReactElement {\n  return (\n    <div className=\"grid gap-2\">\n      <FieldLabel label={label} htmlFor={id} hint={hint} />\n      <Input\n        id={id}\n        type=\"text\"\n        inputMode={inputMode}\n        value={value}\n        onChange={(event) => onChange(event.target.value)}\n        onBlur={onBlur}\n        placeholder={placeholder}\n      />\n    </div>\n  );\n}\n\nfunction AdvancedSettingsSection({\n  title,\n  description,\n  children,\n}: {\n  title: string;\n  description: string;\n  children: ReactNode;\n}): ReactElement {\n  return (\n    <div className=\"space-y-3 rounded-2xl border border-border/70 bg-card/60 p-4\">\n      <div className=\"space-y-0.5\">\n        <p className=\"text-sm font-semibold text-foreground\">{title}</p>\n        <p className=\"text-xs text-muted-foreground\">{description}</p>\n      </div>\n      {children}\n    </div>\n  );\n}\n\nfunction ValidationResultPanel({\n  validateResult,\n}: {\n  validateResult: ValidationResult;\n}): ReactElement | null {\n  if (!validateResult) {\n    return null;\n  }\n\n  return (\n    <div\n      className={cn(\n        \"space-y-3 rounded-2xl border p-4 shadow-border backdrop-blur-sm\",\n        validateResult.valid\n          ? \"border-emerald-300/70 bg-emerald-50/80 dark:border-emerald-900/60 dark:bg-emerald-950/30\"\n          : \"border-destructive/30 bg-destructive/5\",\n      )}\n    >\n      <div className=\"flex items-start gap-3\">\n        <div\n          className={cn(\n            \"mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-full border\",\n            validateResult.valid\n              ? \"border-emerald-300/70 bg-emerald-500/10 text-emerald-700 dark:border-emerald-900/60 dark:text-emerald-300\"\n              : \"border-destructive/30 bg-destructive/10 text-destructive\",\n          )}\n        >\n          <HugeiconsIcon\n            icon={\n              validateResult.valid ? CheckmarkCircle02Icon : AlertCircleIcon\n            }\n            className=\"size-4\"\n          />\n        </div>\n        <div className=\"min-w-0 flex-1 space-y-1\">\n          <p\n            className={cn(\n              \"text-sm font-semibold\",\n              validateResult.valid\n                ? \"text-emerald-700 dark:text-emerald-300\"\n                : \"text-destructive\",\n            )}\n          >\n            {validateResult.valid ? \"Ready to run\" : \"Fix these issues first\"}\n          </p>\n          <p className=\"text-xs text-muted-foreground\">\n            {validateResult.valid\n              ? \"Everything checks out. 
Start the run when you're ready.\"\n              : \"Update the recipe, then check it again.\"}\n          </p>\n        </div>\n      </div>\n      {!validateResult.valid && validateResult.errors.length > 0 && (\n        <div className=\"space-y-1\">\n          {validateResult.errors.map((error) => (\n            <p key={error} className=\"break-words text-xs text-destructive\">\n              {error}\n            </p>\n          ))}\n        </div>\n      )}\n      {!validateResult.valid && validateResult.rawDetail && (\n        <p className=\"break-words text-xs text-destructive\">\n          {validateResult.rawDetail}\n        </p>\n      )}\n    </div>\n  );\n}\n\ntype RunDialogBodyProps = Omit<\n  RunDialogProps,\n  \"open\" | \"onOpenChange\" | \"container\"\n> & {\n  onClose: () => void;\n};\n\nfunction RunDialogBody({\n  kind,\n  onKindChange,\n  rows,\n  fullRunName,\n  onFullRunNameChange,\n  onRowsChange,\n  settings,\n  onSettingsChange,\n  loading,\n  validateLoading,\n  validateResult,\n  errors,\n  onRun,\n  onValidate,\n  onClose,\n}: RunDialogBodyProps): ReactElement {\n  const [advancedOpen, setAdvancedOpen] = useState(false);\n  const kindLabel = kind === \"preview\" ? \"Test run\" : \"Full run\";\n  const normalizedFullRunName = fullRunName.trim();\n  const isFullRunNameMissing =\n    kind === \"full\" && normalizedFullRunName.length === 0;\n  const rowHint =\n    kind === \"preview\"\n      ? \"How many sample rows to generate for a quick check.\"\n      : \"How many rows to generate in total.\";\n\n  const [rowsDraft, setRowsDraft] = useState(String(rows));\n  const [batchSizeDraft, setBatchSizeDraft] = useState(\n    String(settings.batchSize),\n  );\n  const [llmParallelDraft, setLlmParallelDraft] = useState(\n    settings.llmParallelRequests === null\n      ? \"\"\n      : String(settings.llmParallelRequests),\n  );\n  const [workersDraft, setWorkersDraft] = useState(\n    String(settings.nonInferenceWorkers),\n  );\n  const [windowDraft, setWindowDraft] = useState(\n    String(settings.shutdownErrorWindow),\n  );\n  const [restartsDraft, setRestartsDraft] = useState(\n    String(settings.maxConversationRestarts),\n  );\n  const [correctionsDraft, setCorrectionsDraft] = useState(\n    String(settings.maxConversationCorrectionSteps),\n  );\n  const [shutdownRateDraft, setShutdownRateDraft] = useState(\n    String(settings.shutdownErrorRate),\n  );\n\n  return (\n    <>\n      <DialogHeader className=\"space-y-2\">\n        <DialogTitle>{kindLabel}</DialogTitle>\n        <p className=\"text-sm text-muted-foreground\">\n          Choose a quick test or a full run. Advanced settings are optional.\n        </p>\n      </DialogHeader>\n\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Run type\"\n          hint=\"Start with a quick check or generate the full dataset.\"\n        />\n        <div className=\"grid grid-cols-2 gap-2\">\n          <Button\n            type=\"button\"\n            variant={kind === \"preview\" ? \"default\" : \"outline\"}\n            className=\"corner-squircle min-h-10 justify-center whitespace-normal px-3 text-center\"\n            aria-pressed={kind === \"preview\"}\n            onClick={() => onKindChange(\"preview\")}\n          >\n            Test run\n          </Button>\n          <Button\n            type=\"button\"\n            variant={kind === \"full\" ? 
\"default\" : \"outline\"}\n            className=\"corner-squircle min-h-10 justify-center whitespace-normal px-3 text-center\"\n            aria-pressed={kind === \"full\"}\n            onClick={() => onKindChange(\"full\")}\n          >\n            Full run\n          </Button>\n        </div>\n      </div>\n\n      {kind === \"full\" && (\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Run name\"\n            htmlFor=\"run-name\"\n            hint=\"Name shown in your run history.\"\n          />\n          <Input\n            id=\"run-name\"\n            type=\"text\"\n            value={fullRunName}\n            onChange={(event) => onFullRunNameChange(event.target.value)}\n            placeholder=\"Sprint dataset v2\"\n            aria-invalid={isFullRunNameMissing}\n          />\n          {isFullRunNameMissing ? (\n            <p className=\"text-xs text-destructive\">\n              Give this full run a name before you start.\n            </p>\n          ) : null}\n        </div>\n      )}\n\n      <div className=\"grid gap-2\">\n        <FieldLabel label=\"Records\" htmlFor=\"run-rows\" hint={rowHint} />\n        <Input\n          id=\"run-rows\"\n          type=\"text\"\n          inputMode=\"numeric\"\n          value={rowsDraft}\n          onChange={(event) => setRowsDraft(event.target.value)}\n          onBlur={() =>\n            commitInt(\n              rowsDraft,\n              rows,\n              1,\n              MAX_RECORDS,\n              onRowsChange,\n              setRowsDraft,\n            )\n          }\n        />\n      </div>\n\n      <Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>\n        <CollapsibleTrigger asChild={true}>\n          <button\n            type=\"button\"\n            className=\"flex items-center gap-2 text-xs font-semibold uppercase tracking-wide text-muted-foreground hover:text-foreground\"\n          >\n            <HugeiconsIcon\n              icon={ArrowDown01Icon}\n              className={cn(\n                \"size-3.5 transition-transform\",\n                advancedOpen && \"rotate-180\",\n              )}\n            />\n            {advancedOpen\n              ? \"Hide advanced run settings\"\n              : \"Show advanced run settings\"}\n          </button>\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-3 space-y-4\">\n          {kind === \"full\" && (\n            <AdvancedSettingsSection\n              title=\"Batching\"\n              description=\"Use batches when you want to split a larger run into smaller pieces.\"\n            >\n              <div className=\"flex items-center justify-between gap-3 text-sm\">\n                <div className=\"space-y-0.5\">\n                  <span className=\"font-medium\">Enable batching</span>\n                  <p className=\"text-xs text-muted-foreground\">\n                    Split a larger run into smaller chunks.\n                  </p>\n                </div>\n                <Switch\n                  checked={settings.batchEnabled}\n                  onCheckedChange={(checked) =>\n                    onSettingsChange({ batchEnabled: Boolean(checked) })\n                  }\n                />\n              </div>\n              {rows >= 1000 && !settings.batchEnabled ? 
(\n                <p className=\"text-xs text-muted-foreground\">\n                  Larger runs are usually easier to manage in batches.\n                </p>\n              ) : null}\n            </AdvancedSettingsSection>\n          )}\n          <AdvancedSettingsSection\n            title=\"Throughput\"\n            description=\"Control how much work runs at the same time.\"\n          >\n            <div className=\"grid gap-4 md:grid-cols-2\">\n              <DraftInputField\n                id=\"run-llm-parallel\"\n                label=\"AI requests at once\"\n                hint=\"Leave empty to use each saved model's own setting.\"\n                inputMode=\"numeric\"\n                value={llmParallelDraft}\n                onChange={setLlmParallelDraft}\n                onBlur={() => {\n                  const trimmed = llmParallelDraft.trim();\n                  if (!trimmed) {\n                    onSettingsChange({ llmParallelRequests: null });\n                    setLlmParallelDraft(\"\");\n                    return;\n                  }\n                  const parsed = Number(trimmed);\n                  if (!Number.isFinite(parsed)) {\n                    setLlmParallelDraft(\n                      settings.llmParallelRequests === null\n                        ? \"\"\n                        : String(settings.llmParallelRequests),\n                    );\n                    return;\n                  }\n                  const next = clampInt(parsed, 1, MAX_WORKERS);\n                  onSettingsChange({ llmParallelRequests: next });\n                  setLlmParallelDraft(String(next));\n                }}\n                placeholder=\"Use saved model setting\"\n              />\n              <DraftInputField\n                id=\"run-non-inference-workers\"\n                label=\"CPU workers\"\n                hint=\"Used for steps like source data, generated fields, and formulas.\"\n                inputMode=\"numeric\"\n                value={workersDraft}\n                onChange={setWorkersDraft}\n                onBlur={() =>\n                  commitInt(\n                    workersDraft,\n                    settings.nonInferenceWorkers,\n                    1,\n                    MAX_WORKERS,\n                    (value) => onSettingsChange({ nonInferenceWorkers: value }),\n                    setWorkersDraft,\n                  )\n                }\n              />\n              {kind === \"full\" && settings.batchEnabled && (\n                <>\n                  <DraftInputField\n                    id=\"run-batch-size\"\n                    label=\"Batch size\"\n                    hint=\"How many rows to generate in each batch.\"\n                    inputMode=\"numeric\"\n                    value={batchSizeDraft}\n                    onChange={setBatchSizeDraft}\n                    onBlur={() =>\n                      commitInt(\n                        batchSizeDraft,\n                        settings.batchSize,\n                        1,\n                        MAX_RECORDS,\n                        (value) => onSettingsChange({ batchSize: value }),\n                        setBatchSizeDraft,\n                      )\n                    }\n                  />\n                  <div className=\"flex items-center justify-between gap-3 rounded-xl border border-border/60 bg-background/60 px-3 py-2 text-sm text-foreground\">\n                    <div className=\"space-y-0.5\">\n                      <p className=\"font-medium\">Merge 
batches into one file</p>\n                      <p className=\"text-xs text-muted-foreground\">\n                        Combine every batch output into one final file.\n                      </p>\n                    </div>\n                    <Switch\n                      checked={settings.mergeBatches}\n                      onCheckedChange={(checked) =>\n                        onSettingsChange({ mergeBatches: Boolean(checked) })\n                      }\n                    />\n                  </div>\n                </>\n              )}\n            </div>\n          </AdvancedSettingsSection>\n          <AdvancedSettingsSection\n            title=\"Retries and recovery\"\n            description=\"Choose how hard the run should try before it gives up.\"\n          >\n            <div className=\"grid gap-4 md:grid-cols-2\">\n              <DraftInputField\n                id=\"run-shutdown-window\"\n                label=\"Failure check window\"\n                hint=\"How many recent attempts to inspect before stopping early.\"\n                inputMode=\"numeric\"\n                value={windowDraft}\n                onChange={setWindowDraft}\n                onBlur={() =>\n                  commitInt(\n                    windowDraft,\n                    settings.shutdownErrorWindow,\n                    1,\n                    MAX_SHUTDOWN_WINDOW,\n                    (value) => onSettingsChange({ shutdownErrorWindow: value }),\n                    setWindowDraft,\n                  )\n                }\n              />\n              <DraftInputField\n                id=\"run-shutdown-rate\"\n                label=\"Stop after too many failures\"\n                hint=\"Example: 0.5 stops when about half of recent attempts fail.\"\n                inputMode=\"decimal\"\n                value={shutdownRateDraft}\n                onChange={setShutdownRateDraft}\n                onBlur={() =>\n                  commitFloat(\n                    shutdownRateDraft,\n                    settings.shutdownErrorRate,\n                    0,\n                    1,\n                    (value) => onSettingsChange({ shutdownErrorRate: value }),\n                    setShutdownRateDraft,\n                  )\n                }\n              />\n              <DraftInputField\n                id=\"run-max-restarts\"\n                label=\"Full retries\"\n                hint=\"How many times to retry when a model answer fails checks.\"\n                inputMode=\"numeric\"\n                value={restartsDraft}\n                onChange={setRestartsDraft}\n                onBlur={() =>\n                  commitInt(\n                    restartsDraft,\n                    settings.maxConversationRestarts,\n                    0,\n                    MAX_RETRY_STEPS,\n                    (value) =>\n                      onSettingsChange({ maxConversationRestarts: value }),\n                    setRestartsDraft,\n                  )\n                }\n              />\n              <DraftInputField\n                id=\"run-correction-steps\"\n                label=\"Correction attempts\"\n                hint=\"How many follow-up fixes to try before starting over.\"\n                inputMode=\"numeric\"\n                value={correctionsDraft}\n                onChange={setCorrectionsDraft}\n                onBlur={() =>\n                  commitInt(\n                    correctionsDraft,\n                    settings.maxConversationCorrectionSteps,\n                    0,\n 
                   MAX_RETRY_STEPS,\n                    (value) =>\n                      onSettingsChange({ maxConversationCorrectionSteps: value }),\n                    setCorrectionsDraft,\n                  )\n                }\n              />\n              <div className=\"flex items-center justify-between gap-3 rounded-xl border border-border/60 bg-background/60 px-3 py-2 text-sm text-foreground md:col-span-2\">\n                <div className=\"space-y-0.5\">\n                  <p className=\"font-medium\">Keep running through failures</p>\n                  <p className=\"text-xs text-muted-foreground\">\n                    Useful for longer runs when you want as many rows as possible.\n                  </p>\n                </div>\n                <Switch\n                  checked={settings.disableEarlyShutdown}\n                  onCheckedChange={(checked) =>\n                    onSettingsChange({\n                      disableEarlyShutdown: Boolean(checked),\n                    })\n                  }\n                />\n              </div>\n            </div>\n          </AdvancedSettingsSection>\n        </CollapsibleContent>\n      </Collapsible>\n\n      {errors.length > 0 && (\n        <div className=\"max-h-44 space-y-2 overflow-y-auto rounded-2xl border border-destructive/30 bg-destructive/5 p-4 shadow-border\">\n          <div className=\"flex items-center gap-2\">\n            <HugeiconsIcon\n              icon={AlertCircleIcon}\n              className=\"size-4 text-destructive\"\n            />\n            <Badge\n              variant=\"outline\"\n              className=\"rounded-full text-[10px] text-destructive\"\n            >\n              Before you run\n            </Badge>\n          </div>\n          {errors.map((error) => (\n            <p key={error} className=\"break-words text-xs text-destructive\">\n              {error}\n            </p>\n          ))}\n        </div>\n      )}\n\n      <ValidationResultPanel validateResult={validateResult} />\n\n      <DialogFooter>\n        <Button\n          type=\"button\"\n          variant=\"outline\"\n          onClick={onClose}\n          disabled={loading}\n          className=\"corner-squircle border-border/70 bg-card/70\"\n        >\n          Cancel\n        </Button>\n        <Button\n          type=\"button\"\n          variant=\"outline\"\n          onClick={onValidate}\n          disabled={loading || validateLoading}\n          className=\"corner-squircle border-border/70 bg-card/70\"\n        >\n          <HugeiconsIcon icon={TestTube01Icon} className=\"size-3.5\" />\n          {validateLoading ? \"Checking...\" : \"Check recipe\"}\n        </Button>\n        <Button\n          type=\"button\"\n          onClick={onRun}\n          disabled={loading || isFullRunNameMissing}\n          className=\"corner-squircle\"\n        >\n          <HugeiconsIcon icon={CookBookIcon} className=\"size-3.5\" />\n          {loading ? \"Starting...\" : `Start ${kindLabel.toLowerCase()}`}\n        </Button>\n      </DialogFooter>\n    </>\n  );\n}\n\nexport function RunDialog({\n  open,\n  onOpenChange,\n  container,\n  ...contentProps\n}: RunDialogProps): ReactElement {\n  const draftKey = [open ? 
\"open\" : \"closed\", contentProps.kind].join(\"|\");\n\n  return (\n    <Dialog open={open} onOpenChange={onOpenChange}>\n      <DialogContent\n        container={container}\n        position=\"absolute\"\n        overlayPosition=\"absolute\"\n        overlayClassName=\"bg-transparent\"\n        className=\"corner-squircle max-h-[650px] overflow-y-auto overflow-x-hidden border-border/70 bg-background/95 sm:max-w-2xl shadow-border backdrop-blur-xl\"\n      >\n        <RunDialogBody\n          key={draftKey}\n          {...contentProps}\n          onClose={() => onOpenChange(false)}\n        />\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/processors-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Dialog, DialogContent, DialogFooter, DialogTitle } from \"@/components/ui/dialog\";\nimport { Input } from \"@/components/ui/input\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { Textarea } from \"@/components/ui/textarea\";\nimport { VisuallyHidden } from \"radix-ui\";\nimport { type ReactElement, useMemo } from \"react\";\nimport type { RecipeProcessorConfig } from \"../types\";\nimport { buildDefaultSchemaTransform } from \"../utils/processors\";\nimport { AvailableVariables } from \"./shared/available-variables\";\nimport { FieldLabel } from \"./shared/field-label\";\ntype ProcessorsDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  processors: RecipeProcessorConfig[];\n  onProcessorsChange: (processors: RecipeProcessorConfig[]) => void;\n  container?: HTMLDivElement | null;\n};\n\nexport function ProcessorsDialog({\n  open,\n  onOpenChange,\n  processors,\n  onProcessorsChange,\n  container,\n}: ProcessorsDialogProps): ReactElement {\n  const schemaIndex = useMemo(\n    () =>\n      processors.findIndex(\n        (processor) => processor.processor_type === \"schema_transform\",\n      ),\n    [processors],\n  );\n  const schemaProcessor = schemaIndex >= 0 ? processors[schemaIndex] : null;\n  const nameId = schemaProcessor ? `${schemaProcessor.id}-name` : \"schema-transform-name\";\n  const templateId = schemaProcessor\n    ? `${schemaProcessor.id}-template`\n    : \"schema-transform-template\";\n\n  const setSchemaEnabled = (enabled: boolean) => {\n    if (enabled) {\n      if (schemaProcessor) {\n        return;\n      }\n      onProcessorsChange([...processors, buildDefaultSchemaTransform()]);\n      return;\n    }\n    onProcessorsChange(\n      processors.filter(\n        (processor) => processor.processor_type !== \"schema_transform\",\n      ),\n    );\n  };\n\n  const updateSchema = (patch: Partial<RecipeProcessorConfig>) => {\n    if (!schemaProcessor) {\n      return;\n    }\n    const next = [...processors];\n    next[schemaIndex] = { ...schemaProcessor, ...patch };\n    onProcessorsChange(next);\n  };\n\n  return (\n    <Dialog open={open} onOpenChange={onOpenChange}>\n      <DialogContent\n        container={container}\n        position=\"absolute\"\n        overlayPosition=\"absolute\"\n        overlayClassName=\"bg-transparent\"\n        className=\"corner-squircle max-h-[650px] overflow-auto sm:max-w-2xl shadow-border\"\n      >\n        <VisuallyHidden.Root>\n          <DialogTitle>Processors</DialogTitle>\n        </VisuallyHidden.Root>\n        <div className=\"space-y-4\">\n          <div className=\"flex items-center justify-between gap-3 corner-squircle rounded-2xl border border-border/60 px-3 py-2\">\n            <div>\n              <p className=\"text-sm font-semibold\">Schema transform</p>\n              <p className=\"text-xs text-muted-foreground\">\n                Transform final rows to target schema (post-batch).\n              </p>\n            </div>\n            <Switch\n              checked={Boolean(schemaProcessor)}\n              onCheckedChange={setSchemaEnabled}\n            />\n          </div>\n\n          {schemaProcessor && (\n            <div className=\"space-y-3\">\n              <AvailableVariables configId=\"\" />\n              <div className=\"grid gap-2\">\n            
    <FieldLabel\n                  label=\"Name\"\n                  htmlFor={nameId}\n                  hint=\"Processor name shown in graph and payload.\"\n                />\n                <Input\n                  id={nameId}\n                  className=\"nodrag\"\n                  value={schemaProcessor.name}\n                  onChange={(event) => updateSchema({ name: event.target.value })}\n                />\n              </div>\n              <div className=\"grid gap-2\">\n                <FieldLabel\n                  label=\"Template (JSON)\"\n                  htmlFor={templateId}\n                  hint=\"Target output schema template using Jinja references.\"\n                />\n                <Textarea\n                  id={templateId}\n                  className=\"corner-squircle nodrag min-h-[220px]\"\n                  value={schemaProcessor.template}\n                  onChange={(event) =>\n                    updateSchema({ template: event.target.value })\n                  }\n                />\n                <p className=\"text-xs text-muted-foreground\">\n                  Use Jinja refs like {\"{{ customer_review }}\"} in values.\n                </p>\n              </div>\n            </div>\n          )}\n        </div>\n        <DialogFooter>\n          <Button\n            type=\"button\"\n            variant=\"outline\"\n            onClick={() => onOpenChange(false)}\n          >\n            Done\n          </Button>\n        </DialogFooter>\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/bernoulli-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype BernoulliDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function BernoulliDialog({\n  config,\n  onUpdate,\n}: BernoulliDialogProps): ReactElement {\n  const pId = `${config.id}-bernoulli-p`;\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Probability (p)\"\n          htmlFor={pId}\n          hint=\"Success probability in [0, 1].\"\n        />\n        <Input\n          id={pId}\n          type=\"number\"\n          min=\"0\"\n          max=\"1\"\n          step=\"0.01\"\n          className=\"nodrag\"\n          value={config.p ?? \"\"}\n          onChange={(event) => onUpdate({ p: event.target.value })}\n        />\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/category-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { Input } from \"@/components/ui/input\";\nimport { type ReactElement, useState } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { ChipInput } from \"../../components/chip-input\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\ntype CategoryDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nfunction addChipWithWeight(\n  values: string[] | undefined,\n  weights: Array<number | null> | undefined,\n  value: string,\n): { values: string[]; weights: Array<number | null> } {\n  return {\n    values: [...(values ?? []), value],\n    weights: [...(weights ?? []), null],\n  };\n}\n\nfunction removeChipWithWeight(\n  values: string[] | undefined,\n  weights: Array<number | null> | undefined,\n  index: number,\n): { values: string[]; weights: Array<number | null> } {\n  const nextValues = [...(values ?? [])];\n  const nextWeights = [...(weights ?? [])];\n  nextValues.splice(index, 1);\n  nextWeights.splice(index, 1);\n  return { values: nextValues, weights: nextWeights };\n}\n\nexport function CategoryDialog({\n  config,\n  onUpdate,\n}: CategoryDialogProps): ReactElement {\n  const [conditionDraft, setConditionDraft] = useState(\"\");\n  const advancedOpen = config.advancedOpen === true;\n  const conditionInputId = `${config.id}-conditional-rule`;\n  const conditional = config.conditional_params ?? {};\n  const conditionalCount = Object.keys(conditional).length;\n\n  const handleAddCondition = () => {\n    const condition = conditionDraft.trim();\n    if (!condition || conditional[condition]) {\n      return;\n    }\n    onUpdate({\n      // biome-ignore lint/style/useNamingConvention: api schema\n      conditional_params: {\n        ...conditional,\n        [condition]: {\n          // biome-ignore lint/style/useNamingConvention: api schema\n          sampler_type: \"category\",\n          values: [],\n          weights: [],\n        },\n      },\n    });\n    setConditionDraft(\"\");\n  };\n\n  const removeCondition = (condition: string) => {\n    const next = { ...conditional };\n    delete next[condition];\n    onUpdate({\n      // biome-ignore lint/style/useNamingConvention: api schema\n      conditional_params: Object.keys(next).length > 0 ? next : undefined,\n    });\n  };\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"space-y-3\">\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Values\"\n            hint=\"Define allowed categorical values for this column.\"\n          />\n          <ChipInput\n            values={config.values ?? 
[]}\n            onAdd={(value) => {\n              const { values, weights } = addChipWithWeight(\n                config.values,\n                config.weights,\n                value,\n              );\n              onUpdate({ values, weights });\n            }}\n            onRemove={(index) => {\n              const { values, weights } = removeChipWithWeight(\n                config.values,\n                config.weights,\n                index,\n              );\n              onUpdate({ values, weights });\n            }}\n            placeholder=\"Type a value and press Enter\"\n          />\n        </div>\n      </div>\n      <Collapsible\n        open={advancedOpen}\n        onOpenChange={(open) => onUpdate({ advancedOpen: open })}\n      >\n        <CollapsibleTrigger asChild={true}>\n          <CollapsibleSectionTriggerButton\n            label=\"Advanced list settings\"\n            open={advancedOpen}\n          />\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-2 space-y-3\">\n            <div className=\"grid gap-2\">\n              <FieldLabel\n                label=\"Weights (optional)\"\n                hint=\"Set selection probability per value.\"\n              />\n              {(config.values ?? []).length === 0 ? (\n                <p className=\"text-xs text-muted-foreground\">\n                  Add values first, then set optional weights.\n                </p>\n              ) : (\n                <div className=\"grid gap-2 sm:grid-cols-2 lg:grid-cols-3\">\n                  {(config.values ?? []).map((value, index) => (\n                    <div key={`${value}-weight`} className=\"space-y-1\">\n                      <p\n                        className=\"truncate text-xs text-muted-foreground\"\n                        title={value}\n                      >\n                        {value}\n                      </p>\n                      <Input\n                        type=\"number\"\n                        className=\"nodrag w-full\"\n                        placeholder=\"Weight\"\n                        value={config.weights?.[index] ?? \"\"}\n                        onChange={(event) => {\n                          const weights = [...(config.weights ?? [])];\n                          weights[index] = event.target.value\n                            ? 
Number(event.target.value)\n                            : null;\n                          onUpdate({ weights });\n                        }}\n                      />\n                    </div>\n                  ))}\n                </div>\n              )}\n            </div>\n            <div className=\"flex items-center justify-between gap-2\">\n              <FieldLabel\n                label=\"Conditional params (category)\"\n                hint=\"Override category values/weights when condition matches.\"\n              />\n              <span className=\"text-xs text-muted-foreground\">\n                {conditionalCount} rules\n              </span>\n            </div>\n            <div className=\"flex gap-2\">\n              <Input\n                id={conditionInputId}\n                className=\"nodrag\"\n                placeholder=\"Condition (e.g., {{ region }} == 'US')\"\n                value={conditionDraft}\n                onChange={(event) => setConditionDraft(event.target.value)}\n                onKeyDown={(event) => {\n                  if (event.key === \"Enter\") {\n                    event.preventDefault();\n                    handleAddCondition();\n                  }\n                }}\n              />\n              <Button type=\"button\" size=\"sm\" onClick={handleAddCondition}>\n                Add rule\n              </Button>\n            </div>\n            {Object.entries(conditional).map(([condition, params]) => (\n              <div\n                key={condition}\n                className=\"space-y-3 rounded-2xl border border-border/60 p-3\"\n              >\n                <div className=\"flex items-center justify-between gap-2\">\n                  <p className=\"text-xs font-semibold text-foreground\">{condition}</p>\n                  <Button\n                    type=\"button\"\n                    size=\"xs\"\n                    variant=\"ghost\"\n                    onClick={() => removeCondition(condition)}\n                  >\n                    Remove\n                  </Button>\n                </div>\n                <ChipInput\n                  values={params.values ?? 
[]}\n                  onAdd={(value) => {\n                    const { values, weights } = addChipWithWeight(\n                      params.values,\n                      params.weights,\n                      value,\n                    );\n                    onUpdate({\n                      // biome-ignore lint/style/useNamingConvention: api schema\n                      conditional_params: {\n                        ...conditional,\n                        [condition]: { ...params, values, weights },\n                      },\n                    });\n                  }}\n                  onRemove={(index) => {\n                    const { values, weights } = removeChipWithWeight(\n                      params.values,\n                      params.weights,\n                      index,\n                    );\n                    onUpdate({\n                      // biome-ignore lint/style/useNamingConvention: api schema\n                      conditional_params: {\n                        ...conditional,\n                        [condition]: { ...params, values, weights },\n                      },\n                    });\n                  }}\n                  placeholder=\"Type a conditional value and press Enter\"\n                />\n                <div className=\"grid gap-2\">\n                  <p className=\"text-xs font-semibold uppercase text-muted-foreground\">\n                    Rule weights (optional)\n                  </p>\n                  <div className=\"grid gap-2 sm:grid-cols-2 lg:grid-cols-3\">\n                    {(params.values ?? []).map((value, index) => (\n                      <div\n                        key={`${condition}-${value}-${index}-weight`}\n                        className=\"space-y-1\"\n                      >\n                        <p\n                          className=\"truncate text-xs text-muted-foreground\"\n                          title={value}\n                        >\n                          {value}\n                        </p>\n                        <Input\n                          type=\"number\"\n                          className=\"nodrag\"\n                          placeholder=\"Weight\"\n                          value={params.weights?.[index] ?? \"\"}\n                          onChange={(event) => {\n                            const weights = [\n                              ...(params.weights ??\n                                Array.from(\n                                  { length: (params.values ?? []).length },\n                                  () => null,\n                                )),\n                            ];\n                            weights[index] = event.target.value\n                              ? Number(event.target.value)\n                              : null;\n                            onUpdate({\n                              // biome-ignore lint/style/useNamingConvention: api schema\n                              conditional_params: {\n                                ...conditional,\n                                [condition]: { ...params, weights },\n                              },\n                            });\n                          }}\n                        />\n                      </div>\n                    ))}\n                  </div>\n                </div>\n              </div>\n            ))}\n        </CollapsibleContent>\n      </Collapsible>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/datetime-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\nconst DATETIME_UNITS = [\n  \"second\",\n  \"minute\",\n  \"hour\",\n  \"day\",\n  \"week\",\n  \"month\",\n  \"year\",\n];\n\ntype DatetimeDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function DatetimeDialog({\n  config,\n  onUpdate,\n}: DatetimeDialogProps): ReactElement {\n  const startId = `${config.id}-datetime-start`;\n  const endId = `${config.id}-datetime-end`;\n  const unitId = `${config.id}-datetime-unit`;\n  const updateField = <K extends keyof SamplerConfig>(\n    key: K,\n    value: SamplerConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<SamplerConfig>);\n  };\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-3\">\n        <div className=\"grid gap-2 sm:grid-cols-2\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Start\"\n              htmlFor={startId}\n              hint=\"Earliest datetime allowed.\"\n            />\n            <Input\n              id={startId}\n              type=\"datetime-local\"\n              className=\"nodrag\"\n              value={config.datetime_start ?? \"\"}\n              onChange={(event) =>\n                updateField(\"datetime_start\", event.target.value)\n              }\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"End\"\n              htmlFor={endId}\n              hint=\"Latest datetime allowed.\"\n            />\n            <Input\n              id={endId}\n              type=\"datetime-local\"\n              className=\"nodrag\"\n              value={config.datetime_end ?? \"\"}\n              onChange={(event) =>\n                updateField(\"datetime_end\", event.target.value)\n              }\n            />\n          </div>\n        </div>\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Unit\"\n            htmlFor={unitId}\n            hint=\"Sampling granularity for generated timestamps.\"\n          />\n          <Select\n            value={config.datetime_unit ?? \"\"}\n            onValueChange={(value) => updateField(\"datetime_unit\", value)}\n          >\n            <SelectTrigger className=\"nodrag w-full\" id={unitId}>\n              <SelectValue placeholder=\"Select unit\" />\n            </SelectTrigger>\n            <SelectContent>\n              {DATETIME_UNITS.map((unit) => (\n                <SelectItem key={unit} value={unit}>\n                  {unit}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/gaussian-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype GaussianDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function GaussianDialog({\n  config,\n  onUpdate,\n}: GaussianDialogProps): ReactElement {\n  const meanId = `${config.id}-gaussian-mean`;\n  const stdId = `${config.id}-gaussian-std`;\n  const convertId = `${config.id}-gaussian-convert`;\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-3 sm:grid-cols-2\">\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Mean\"\n            htmlFor={meanId}\n            hint=\"Center of the normal distribution.\"\n          />\n          <Input\n            id={meanId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.mean ?? \"\"}\n            onChange={(event) => onUpdate({ mean: event.target.value })}\n          />\n        </div>\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Std\"\n            htmlFor={stdId}\n            hint=\"Standard deviation. must be > 0.\"\n          />\n          <Input\n            id={stdId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.std ?? \"\"}\n            onChange={(event) => onUpdate({ std: event.target.value })}\n          />\n        </div>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Convert to\"\n          htmlFor={convertId}\n          hint=\"Optionally cast sampled values before output.\"\n        />\n        <Select\n          value={config.convert_to ?? \"none\"}\n          onValueChange={(value) =>\n            onUpdate({\n              // biome-ignore lint/style/useNamingConvention: api schema\n              convert_to: value === \"none\" ? undefined : (value as \"int\" | \"float\" | \"str\"),\n            })\n          }\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={convertId}>\n            <SelectValue placeholder=\"No conversion\" />\n          </SelectTrigger>\n          <SelectContent>\n            <SelectItem value=\"none\">None</SelectItem>\n            <SelectItem value=\"int\">int</SelectItem>\n            <SelectItem value=\"float\">float</SelectItem>\n            <SelectItem value=\"str\">str</SelectItem>\n          </SelectContent>\n        </Select>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/person-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { type ReactElement, useEffect } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype PersonDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function PersonDialog({\n  config,\n  onUpdate,\n}: PersonDialogProps): ReactElement {\n  const localeId = `${config.id}-person-locale`;\n  const sexId = `${config.id}-person-sex`;\n  const ageRangeId = `${config.id}-person-age-range`;\n  const cityId = `${config.id}-person-city`;\n\n  const updateField = <K extends keyof SamplerConfig>(\n    key: K,\n    value: SamplerConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<SamplerConfig>);\n  };\n\n  useEffect(() => {\n    if (config.sampler_type !== \"person_from_faker\") {\n      onUpdate({\n        sampler_type: \"person_from_faker\",\n        person_with_synthetic_personas: undefined,\n      });\n    }\n  }, [config.sampler_type, onUpdate]);\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-3\">\n        <div className=\"rounded-2xl border border-border/60 px-3 py-2\">\n          <p className=\"text-xs font-semibold uppercase text-muted-foreground\">\n            Source\n          </p>\n          <p className=\"text-sm text-foreground\">Faker</p>\n        </div>\n        <div className=\"grid gap-3 sm:grid-cols-2\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Locale\"\n              htmlFor={localeId}\n              hint=\"Faker locale e.g. en_US.\"\n            />\n            <Input\n              id={localeId}\n              className=\"nodrag\"\n              value={config.person_locale ?? \"\"}\n              onChange={(event) =>\n                updateField(\"person_locale\", event.target.value)\n              }\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Sex\"\n              htmlFor={sexId}\n              hint=\"Optional demographic filter.\"\n            />\n            <Select\n              value={config.person_sex?.trim() ? config.person_sex : \"any\"}\n              onValueChange={(value) =>\n                updateField(\"person_sex\", value === \"any\" ? \"\" : value)\n              }\n            >\n              <SelectTrigger className=\"nodrag w-full\" id={sexId}>\n                <SelectValue placeholder=\"Any\" />\n              </SelectTrigger>\n              <SelectContent>\n                <SelectItem value=\"any\">Any</SelectItem>\n                <SelectItem value=\"Male\">Male</SelectItem>\n                <SelectItem value=\"Female\">Female</SelectItem>\n              </SelectContent>\n            </Select>\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Age range\"\n              htmlFor={ageRangeId}\n              hint=\"Range format: min-max, e.g. 
18-70.\"\n            />\n            <Input\n              id={ageRangeId}\n              className=\"nodrag\"\n              value={config.person_age_range ?? \"\"}\n              onChange={(event) =>\n                updateField(\"person_age_range\", event.target.value)\n              }\n              placeholder=\"18-70\"\n            />\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"City\"\n              htmlFor={cityId}\n              hint=\"Optional city bias for faker generation.\"\n            />\n            <Input\n              id={cityId}\n              className=\"nodrag\"\n              value={config.person_city ?? \"\"}\n              onChange={(event) =>\n                updateField(\"person_city\", event.target.value)\n              }\n            />\n          </div>\n        </div>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/subcategory-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { type ReactElement, useCallback, useEffect, useMemo } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { ChipInput } from \"../../components/chip-input\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype SubcategoryDialogProps = {\n  config: SamplerConfig;\n  categoryOptions: SamplerConfig[];\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function SubcategoryDialog({\n  config,\n  categoryOptions,\n  onUpdate,\n}: SubcategoryDialogProps): ReactElement {\n  const parentSelectId = `${config.id}-parent-category`;\n  const updateField = useCallback(\n    <K extends keyof SamplerConfig>(key: K, value: SamplerConfig[K]) => {\n      onUpdate({ [key]: value } as Partial<SamplerConfig>);\n    },\n    [onUpdate],\n  );\n  const parent = useMemo(\n    () =>\n      categoryOptions.find(\n        (option) => option.name === config.subcategory_parent,\n      ) ?? null,\n    [categoryOptions, config.subcategory_parent],\n  );\n  const categoryValues = parent?.values ?? [];\n  const mapping = config.subcategory_mapping ?? {};\n\n  const ensureMapping = useCallback(\n    (nextParent?: SamplerConfig | null) => {\n      const values = nextParent?.values ?? [];\n      const nextMapping: Record<string, string[]> = {};\n      for (const value of values) {\n        nextMapping[value] = config.subcategory_mapping?.[value] ?? [];\n      }\n      const currentKeys = Object.keys(config.subcategory_mapping ?? {});\n      const nextKeys = Object.keys(nextMapping);\n      const changed =\n        currentKeys.length !== nextKeys.length ||\n        currentKeys.some((key) => !nextKeys.includes(key));\n      if (changed) {\n        updateField(\"subcategory_mapping\", nextMapping);\n      }\n    },\n    [config.subcategory_mapping, updateField],\n  );\n\n  useEffect(() => {\n    if (parent) {\n      ensureMapping(parent);\n    }\n  }, [ensureMapping, parent]);\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"space-y-3\">\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Parent category column\"\n            htmlFor={parentSelectId}\n            hint=\"Category column this block maps from.\"\n          />\n          <Select\n            value={config.subcategory_parent ?? \"\"}\n            onValueChange={(value) => {\n              const nextParent =\n                categoryOptions.find((option) => option.name === value) ?? 
null;\n              updateField(\"subcategory_parent\", value);\n              ensureMapping(nextParent);\n            }}\n          >\n            <SelectTrigger className=\"nodrag w-full\" id={parentSelectId}>\n              <SelectValue placeholder=\"Select category column\" />\n            </SelectTrigger>\n            <SelectContent>\n              {categoryOptions.map((option) => (\n                <SelectItem key={option.id} value={option.name}>\n                  {option.name}\n                </SelectItem>\n              ))}\n            </SelectContent>\n          </Select>\n          <p className=\"text-xs text-muted-foreground\">\n            Map each parent category value to its subcategory options below.\n          </p>\n        </div>\n        {categoryValues.length > 0 && (\n          <div className=\"grid gap-4\">\n            {categoryValues.map((value) => (\n              <div key={value}>\n                <div className=\"mb-2 flex items-center justify-between gap-2\">\n                  <p className=\"text-sm font-semibold text-foreground\">\n                    {value}\n                  </p>\n                  <span className=\"text-xs text-muted-foreground\">\n                    {mapping[value]?.length ?? 0} subvalues\n                  </span>\n                </div>\n                <ChipInput\n                  values={mapping[value] ?? []}\n                  onAdd={(item) => {\n                    const next = { ...mapping };\n                    const list = next[value] ? [...next[value]] : [];\n                    list.push(item);\n                    next[value] = list;\n                    updateField(\"subcategory_mapping\", next);\n                  }}\n                  onRemove={(index) => {\n                    const next = { ...mapping };\n                    const list = [...(next[value] ?? [])];\n                    list.splice(index, 1);\n                    next[value] = list;\n                    updateField(\"subcategory_mapping\", next);\n                  }}\n                  placeholder=\"Type subcategory and press Enter\"\n                />\n                {(mapping[value] ?? []).length === 0 && (\n                  <p className=\"mt-2 text-xs text-rose-500\">\n                    Add at least 1 subcategory.\n                  </p>\n                )}\n              </div>\n            ))}\n          </div>\n        )}\n      </div>\n    </div>\n  );\n}\n"
  },
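  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/subcategory-mapping.example.ts",
    "content": "// Illustrative sketch (hypothetical file, not part of the repo): shows the\n// subcategory_mapping shape used by SubcategoryDialog and how the mapping is\n// pruned to the parent category's values, mirroring ensureMapping above. The\n// pruneMapping helper is an assumption for illustration only.\n\ntype SubcategoryMapping = Record<string, string[]>;\n\n// Keep one entry per parent value, preserving any previously entered subvalues.\nexport function pruneMapping(\n  parentValues: string[],\n  current: SubcategoryMapping,\n): SubcategoryMapping {\n  const next: SubcategoryMapping = {};\n  for (const value of parentValues) {\n    next[value] = current[value] ?? [];\n  }\n  return next;\n}\n\n// Switching the parent from [\"fiction\", \"poetry\"] to [\"fiction\", \"essay\"]\n// drops the \"poetry\" entry and seeds \"essay\" with an empty list:\n// pruneMapping([\"fiction\", \"essay\"], { fiction: [\"sci-fi\"], poetry: [\"haiku\"] })\n// => { fiction: [\"sci-fi\"], essay: [] }\n"
  },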
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/timedelta-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\nconst TIMEDELTA_UNITS: Array<\"D\" | \"h\" | \"m\" | \"s\"> = [\"D\", \"h\", \"m\", \"s\"];\nconst NONE_VALUE = \"__none\";\n\ntype TimedeltaDialogProps = {\n  config: SamplerConfig;\n  datetimeOptions: string[];\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function TimedeltaDialog({\n  config,\n  datetimeOptions,\n  onUpdate,\n}: TimedeltaDialogProps): ReactElement {\n  const dtMinId = `${config.id}-timedelta-min`;\n  const dtMaxId = `${config.id}-timedelta-max`;\n  const unitId = `${config.id}-timedelta-unit`;\n  const referenceId = `${config.id}-timedelta-reference`;\n  const updateField = <K extends keyof SamplerConfig>(\n    key: K,\n    value: SamplerConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<SamplerConfig>);\n  };\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-3 sm:grid-cols-2\">\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"dt_min\"\n            htmlFor={dtMinId}\n            hint=\"Minimum offset from reference datetime.\"\n          />\n          <Input\n            id={dtMinId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.dt_min ?? \"\"}\n            onChange={(event) => updateField(\"dt_min\", event.target.value)}\n          />\n        </div>\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"dt_max\"\n            htmlFor={dtMaxId}\n            hint=\"Maximum offset from reference datetime.\"\n          />\n          <Input\n            id={dtMaxId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.dt_max ?? \"\"}\n            onChange={(event) => updateField(\"dt_max\", event.target.value)}\n          />\n        </div>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Unit\"\n          htmlFor={unitId}\n          hint=\"Offset unit. D/h/m/s.\"\n        />\n        <Select\n          value={config.timedelta_unit ?? 
\"D\"}\n          onValueChange={(value) =>\n            updateField(\"timedelta_unit\", value as \"D\" | \"h\" | \"m\" | \"s\")\n          }\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={unitId}>\n            <SelectValue placeholder=\"Select unit\" />\n          </SelectTrigger>\n          <SelectContent>\n            {TIMEDELTA_UNITS.map((unit) => (\n              <SelectItem key={unit} value={unit}>\n                {unit}\n              </SelectItem>\n            ))}\n          </SelectContent>\n        </Select>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Reference datetime column\"\n          htmlFor={referenceId}\n          hint=\"Datetime column used as anchor before offset.\"\n        />\n        <Select\n          value={config.reference_column_name?.trim() || NONE_VALUE}\n          onValueChange={(value) =>\n            updateField(\"reference_column_name\", value === NONE_VALUE ? \"\" : value)\n          }\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={referenceId}>\n            <SelectValue placeholder=\"Select datetime column\" />\n          </SelectTrigger>\n          <SelectContent>\n            <SelectItem value={NONE_VALUE}>None</SelectItem>\n            {datetimeOptions.map((name) => (\n              <SelectItem key={name} value={name}>\n                {name}\n              </SelectItem>\n            ))}\n          </SelectContent>\n        </Select>\n      </div>\n    </div>\n  );\n}\n"
  },
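  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/timedelta-none-sentinel.example.ts",
    "content": "// Illustrative sketch (hypothetical file, not part of the repo): TimedeltaDialog\n// stores \"no reference column\" as an empty string on the config but shows it in\n// the Select via the \"__none\" sentinel. The helpers below are assumptions that\n// spell out that round-trip.\n\nconst NONE_VALUE = \"__none\";\n\n// Config value (\"\" or undefined) -> value shown in the Select.\nexport function toSelectValue(referenceColumnName: string | undefined): string {\n  return referenceColumnName?.trim() || NONE_VALUE;\n}\n\n// Value chosen in the Select -> value stored back on the config.\nexport function fromSelectValue(selected: string): string {\n  return selected === NONE_VALUE ? \"\" : selected;\n}\n\n// fromSelectValue(toSelectValue(undefined)) === \"\"\n// fromSelectValue(toSelectValue(\"created_at\")) === \"created_at\"\n"
  },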
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/uniform-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype UniformDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function UniformDialog({\n  config,\n  onUpdate,\n}: UniformDialogProps): ReactElement {\n  const lowId = `${config.id}-uniform-low`;\n  const highId = `${config.id}-uniform-high`;\n  const convertId = `${config.id}-uniform-convert`;\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-3 sm:grid-cols-2\">\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"Low\"\n            htmlFor={lowId}\n            hint=\"Minimum sampled value.\"\n          />\n          <Input\n            id={lowId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.low ?? \"\"}\n            onChange={(event) => onUpdate({ low: event.target.value })}\n          />\n        </div>\n        <div className=\"grid gap-2\">\n          <FieldLabel\n            label=\"High\"\n            htmlFor={highId}\n            hint=\"Maximum sampled value.\"\n          />\n          <Input\n            id={highId}\n            type=\"number\"\n            className=\"nodrag\"\n            value={config.high ?? \"\"}\n            onChange={(event) => onUpdate({ high: event.target.value })}\n          />\n        </div>\n      </div>\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Convert to\"\n          htmlFor={convertId}\n          hint=\"Optionally cast sampled values before output.\"\n        />\n        <Select\n          value={config.convert_to ?? \"none\"}\n          onValueChange={(value) =>\n            onUpdate({\n              // biome-ignore lint/style/useNamingConvention: api schema\n              convert_to: value === \"none\" ? undefined : (value as \"int\" | \"float\" | \"str\"),\n            })\n          }\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={convertId}>\n            <SelectValue placeholder=\"No conversion\" />\n          </SelectTrigger>\n          <SelectContent>\n            <SelectItem value=\"none\">None</SelectItem>\n            <SelectItem value=\"int\">int</SelectItem>\n            <SelectItem value=\"float\">float</SelectItem>\n            <SelectItem value=\"str\">str</SelectItem>\n          </SelectContent>\n        </Select>\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/samplers/uuid-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport type { ReactElement } from \"react\";\nimport type { SamplerConfig } from \"../../types\";\nimport { NameField } from \"../shared/name-field\";\nimport { FieldLabel } from \"../shared/field-label\";\n\ntype UuidDialogProps = {\n  config: SamplerConfig;\n  onUpdate: (patch: Partial<SamplerConfig>) => void;\n};\n\nexport function UuidDialog({\n  config,\n  onUpdate,\n}: UuidDialogProps): ReactElement {\n  const uuidId = `${config.id}-uuid-format`;\n  const updateField = <K extends keyof SamplerConfig>(\n    key: K,\n    value: SamplerConfig[K],\n  ) => {\n    onUpdate({ [key]: value } as Partial<SamplerConfig>);\n  };\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"UUID format (optional)\"\n          htmlFor={uuidId}\n          hint=\"Optional formatter e.g. prefix:, short, uppercase.\"\n        />\n        <Input\n          id={uuidId}\n          className=\"nodrag\"\n          value={config.uuid_format ?? \"\"}\n          onChange={(event) => updateField(\"uuid_format\", event.target.value)}\n        />\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/seed/seed-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Checkbox } from \"@/components/ui/checkbox\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport {\n  Empty,\n  EmptyContent,\n  EmptyDescription,\n  EmptyHeader,\n  EmptyTitle,\n} from \"@/components/ui/empty\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport {\n  Table,\n  TableBody,\n  TableCell,\n  TableHead,\n  TableHeader,\n  TableRow,\n} from \"@/components/ui/table\";\nimport {\n  Tabs,\n  TabsContent,\n  TabsList,\n  TabsTrigger,\n} from \"@/components/ui/tabs\";\nimport mammoth from \"mammoth\";\nimport { type ReactElement, useCallback, useEffect, useMemo, useRef, useState } from \"react\";\nimport { extractText, getDocumentProxy } from \"unpdf\";\nimport { cn } from \"@/lib/utils\";\nimport { inspectSeedDataset, inspectSeedUpload } from \"../../api\";\nimport { resolveImagePreview } from \"../../utils/image-preview\";\nimport type {\n  SeedConfig,\n  SeedSamplingStrategy,\n  SeedSelectionType,\n} from \"../../types\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { HfDatasetCombobox } from \"../../components/shared/hf-dataset-combobox\";\nimport { FieldLabel } from \"../shared/field-label\";\n\nconst SAMPLING_OPTIONS: Array<{ value: SeedSamplingStrategy; label: string }> = [\n  { value: \"ordered\", label: \"Ordered\" },\n  { value: \"shuffle\", label: \"Shuffle\" },\n];\n\nconst SELECTION_OPTIONS: Array<{ value: SeedSelectionType; label: string }> = [\n  { value: \"none\", label: \"None\" },\n  { value: \"index_range\", label: \"Index range\" },\n  { value: \"partition_block\", label: \"Partition block\" },\n];\n\nconst LOCAL_ACCEPT = \".csv,.json,.jsonl\";\nconst UNSTRUCTURED_ACCEPT = \".txt,.pdf,.docx\";\nconst MAX_UPLOAD_BYTES = 50 * 1024 * 1024;\nconst DEFAULT_CHUNK_SIZE = 1200;\nconst DEFAULT_CHUNK_OVERLAP = 200;\nconst MAX_CHUNK_SIZE = 20000;\nconst PREVIEW_TRUNCATE_AT = 320;\n\ntype SeedDialogProps = {\n  config: SeedConfig;\n  onUpdate: (patch: Partial<SeedConfig>) => void;\n  open: boolean;\n};\n\nfunction getErrorMessage(error: unknown, fallback: string): string {\n  if (error instanceof Error && error.message) {\n    return error.message;\n  }\n  return fallback;\n}\n\nfunction stringifyCell(value: unknown): string {\n  if (value === null || value === undefined) return \"\";\n  if (typeof value === \"string\") return value;\n  if (typeof value === \"number\" || typeof value === \"boolean\") return String(value);\n  try {\n    return JSON.stringify(value);\n  } catch {\n    return String(value);\n  }\n}\n\nfunction isExpandablePreviewValue(value: string): boolean {\n  return value.length > PREVIEW_TRUNCATE_AT;\n}\n\nfunction truncatePreviewValue(value: string): string {\n  if (!isExpandablePreviewValue(value)) {\n    return value;\n  }\n  return `${value.slice(0, PREVIEW_TRUNCATE_AT)}…`;\n}\n\nfunction getPreviewEmptyStateCopy(mode: SeedConfig[\"seed_source_type\"]): {\n  title: string;\n  description: string;\n} {\n  if (mode === \"local\") {\n    return {\n      title: \"No local preview yet\",\n      description: \"Choose a CSV/JSON/JSONL file, then click Load to fetch 10 rows.\",\n    };\n  
}\n  if (mode === \"unstructured\") {\n    return {\n      title: \"No chunk preview yet\",\n      description:\n        \"Choose a TXT/PDF/DOCX file, then click Load to extract + preview chunk_text rows.\",\n    };\n  }\n  return {\n    title: \"No dataset preview yet\",\n    description: \"Pick a Hugging Face dataset and click Load to fetch 10 sample rows.\",\n  };\n}\n\nfunction parseChunkNumber(\n  value: string | undefined,\n  fallback: number,\n  min: number,\n  max: number,\n): number {\n  const raw = value?.trim();\n  if (!raw) return fallback;\n  const parsed = Number(raw);\n  if (!Number.isFinite(parsed)) return fallback;\n  const int = Math.floor(parsed);\n  if (int < min) return min;\n  if (int > max) return max;\n  return int;\n}\n\nfunction resolveChunking(config: SeedConfig): {\n  chunkSize: number;\n  chunkOverlap: number;\n} {\n  const chunkSize = parseChunkNumber(\n    config.unstructured_chunk_size,\n    DEFAULT_CHUNK_SIZE,\n    1,\n    MAX_CHUNK_SIZE,\n  );\n  const chunkOverlap = parseChunkNumber(\n    config.unstructured_chunk_overlap,\n    DEFAULT_CHUNK_OVERLAP,\n    0,\n    Math.max(0, chunkSize - 1),\n  );\n  return { chunkSize, chunkOverlap };\n}\n\nasync function fileToBase64Payload(file: File): Promise<string> {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = () => {\n      const value = String(reader.result ?? \"\");\n      const parts = value.split(\",\");\n      resolve(parts.length > 1 ? parts[1] : value);\n    };\n    reader.onerror = () => reject(new Error(\"Failed to read file\"));\n    reader.readAsDataURL(file);\n  });\n}\n\nasync function extractUnstructuredText(file: File): Promise<string> {\n  const lower = file.name.toLowerCase();\n  if (lower.endsWith(\".txt\")) {\n    return file.text();\n  }\n  if (lower.endsWith(\".pdf\")) {\n    const buffer = new Uint8Array(await file.arrayBuffer());\n    const pdf = await getDocumentProxy(buffer);\n    const { text } = await extractText(pdf, { mergePages: true });\n    return text;\n  }\n  if (lower.endsWith(\".docx\")) {\n    const arrayBuffer = await file.arrayBuffer();\n    const { value } = await mammoth.extractRawText({ arrayBuffer });\n    return value;\n  }\n  throw new Error(\"Unsupported unstructured file type\");\n}\n\nasync function toUnstructuredUploadFile(file: File): Promise<File> {\n  const lower = file.name.toLowerCase();\n  if (lower.endsWith(\".txt\") || lower.endsWith(\".md\")) {\n    return file;\n  }\n\n  const text = (await extractUnstructuredText(file)).trim();\n  if (!text) {\n    throw new Error(\"No text found in file.\");\n  }\n  const normalized = text.replace(/\\r\\n/g, \"\\n\").replace(/\\r/g, \"\\n\");\n  const stem = file.name.replace(/\\.(pdf|docx)$/i, \"\") || \"unstructured_seed\";\n  return new File([normalized], `${stem}.txt`, {\n    type: \"text/plain\",\n  });\n}\n\nexport function SeedDialog({ config, onUpdate, open }: SeedDialogProps): ReactElement {\n  const [inspectError, setInspectError] = useState<string | null>(null);\n  const [isInspecting, setIsInspecting] = useState(false);\n  const advancedOpen = config.advancedOpen === true;\n  const [previewRows, setPreviewRows] = useState<Record<string, unknown>[]>([]);\n  const [expandedPreviewRows, setExpandedPreviewRows] = useState<Record<number, boolean>>({});\n  const [localFile, setLocalFile] = useState<File | null>(null);\n  const [unstructuredFile, setUnstructuredFile] = useState<File | null>(null);\n\n  const mode = config.seed_source_type ?? 
\"hf\";\n  const previewEmpty = getPreviewEmptyStateCopy(mode);\n\n  useEffect(() => {\n    setInspectError(null);\n    setLocalFile(null);\n    setUnstructuredFile(null);\n  }, [mode]);\n\n  useEffect(() => {\n    setPreviewRows(config.seed_preview_rows ?? []);\n    setExpandedPreviewRows({});\n  }, [config.seed_preview_rows]);\n\n  const samplingId = `${config.id}-sampling`;\n  const selectionId = `${config.id}-selection`;\n  const tokenId = `${config.id}-hf-token`;\n  const datasetId = `${config.id}-hf-dataset`;\n  const chunkSizeId = `${config.id}-chunk-size`;\n  const chunkOverlapId = `${config.id}-chunk-overlap`;\n  const [lastLoadedKey, setLastLoadedKey] = useState<string | null>(null);\n  const wasOpenRef = useRef(open);\n\n  const getCurrentLoadKey = useCallback((): string | null => {\n    if (mode === \"hf\") {\n      const dataset = config.hf_repo_id.trim();\n      if (!dataset) return null;\n      const token = config.hf_token?.trim() ?? \"\";\n      return `hf:${dataset}|${token}`;\n    }\n    if (mode === \"local\") {\n      if (!localFile) return null;\n      return `local:${localFile.name}|${localFile.size}|${localFile.lastModified}`;\n    }\n    if (!unstructuredFile) return null;\n    const { chunkSize, chunkOverlap } = resolveChunking(config);\n    return `unstructured:${unstructuredFile.name}|${unstructuredFile.size}|${unstructuredFile.lastModified}|${chunkSize}|${chunkOverlap}`;\n  }, [\n    config,\n    localFile,\n    mode,\n    unstructuredFile,\n  ]);\n\n  const loadSeedMetadata = useCallback(async (opts?: { silent?: boolean }): Promise<boolean> => {\n    const loadKey = getCurrentLoadKey();\n    if (!opts?.silent) {\n      setInspectError(null);\n    }\n    setIsInspecting(true);\n    try {\n      if (mode === \"hf\") {\n        const datasetName = config.hf_repo_id.trim();\n        if (!datasetName) {\n          throw new Error(\"Dataset repo is required.\");\n        }\n        const response = await inspectSeedDataset({\n          dataset_name: datasetName,\n          hf_token: config.hf_token?.trim() || undefined,\n          split: config.hf_split?.trim() || undefined,\n          subset: config.hf_subset?.trim() || undefined,\n          preview_size: 10,\n        });\n        onUpdate({\n          hf_path: response.resolved_path,\n          seed_columns: response.columns,\n          seed_drop_columns: (config.seed_drop_columns ?? []).filter((name) =>\n            response.columns.includes(name),\n          ),\n          seed_preview_rows: response.preview_rows ?? [],\n          hf_split: response.split ?? \"\",\n          hf_subset: response.subset ?? \"\",\n          local_file_name: \"\",\n          unstructured_file_name: \"\",\n        });\n        setPreviewRows(response.preview_rows ?? []);\n        setLastLoadedKey(loadKey);\n        return true;\n      }\n\n      if (mode === \"local\") {\n        if (!localFile) {\n          throw new Error(\"Select a local CSV/JSON/JSONL file first.\");\n        }\n        if (localFile.size > MAX_UPLOAD_BYTES) {\n          throw new Error(\"File too large (max 50MB).\");\n        }\n        const payload = await fileToBase64Payload(localFile);\n        const response = await inspectSeedUpload({\n          filename: localFile.name,\n          content_base64: payload,\n          preview_size: 10,\n        });\n        onUpdate({\n          hf_path: response.resolved_path,\n          seed_columns: response.columns,\n          seed_drop_columns: (config.seed_drop_columns ?? 
[]).filter((name) =>\n            response.columns.includes(name),\n          ),\n          seed_preview_rows: response.preview_rows ?? [],\n          hf_repo_id: \"\",\n          hf_subset: \"\",\n          hf_split: \"\",\n          local_file_name: localFile.name,\n          unstructured_file_name: \"\",\n        });\n        setPreviewRows(response.preview_rows ?? []);\n        setLastLoadedKey(loadKey);\n        return true;\n      }\n\n      if (!unstructuredFile) {\n        throw new Error(\"Select a PDF/DOCX/TXT file first.\");\n      }\n      if (unstructuredFile.size > MAX_UPLOAD_BYTES) {\n        throw new Error(\"File too large (max 50MB).\");\n      }\n\n      const { chunkSize, chunkOverlap } = resolveChunking(config);\n      const uploadFile = await toUnstructuredUploadFile(unstructuredFile);\n      if (uploadFile.size > MAX_UPLOAD_BYTES) {\n        throw new Error(\"Processed text is too large (max 50MB).\");\n      }\n      const payload = await fileToBase64Payload(uploadFile);\n      const response = await inspectSeedUpload({\n        filename: uploadFile.name,\n        content_base64: payload,\n        preview_size: 10,\n        seed_source_type: \"unstructured\",\n        unstructured_chunk_size: chunkSize,\n        unstructured_chunk_overlap: chunkOverlap,\n      });\n      onUpdate({\n        hf_path: response.resolved_path,\n        seed_columns: response.columns,\n        seed_drop_columns: (config.seed_drop_columns ?? []).filter((name) =>\n          response.columns.includes(name),\n        ),\n        seed_preview_rows: response.preview_rows ?? [],\n        hf_repo_id: \"\",\n        hf_subset: \"\",\n        hf_split: \"\",\n        local_file_name: \"\",\n        unstructured_file_name: unstructuredFile.name,\n      });\n      setPreviewRows(response.preview_rows ?? []);\n      setLastLoadedKey(loadKey);\n      return true;\n    } catch (error) {\n      if (!opts?.silent) {\n        setInspectError(getErrorMessage(error, \"Failed to load seed metadata.\"));\n      }\n      setPreviewRows([]);\n      return false;\n    } finally {\n      setIsInspecting(false);\n    }\n  }, [\n    config,\n    getCurrentLoadKey,\n    localFile,\n    mode,\n    onUpdate,\n    unstructuredFile,\n  ]);\n\n  useEffect(() => {\n    const wasOpen = wasOpenRef.current;\n    wasOpenRef.current = open;\n    if (!wasOpen || open || isInspecting) {\n      return;\n    }\n    const key = getCurrentLoadKey();\n    if (!key || key === lastLoadedKey) {\n      return;\n    }\n    void loadSeedMetadata({ silent: true });\n  }, [getCurrentLoadKey, isInspecting, lastLoadedKey, loadSeedMetadata, open]);\n\n  const previewColumns = useMemo(() => {\n    const loadedColumns = config.seed_columns ?? [];\n    if (loadedColumns.length > 0) return loadedColumns;\n    if (previewRows[0]) return Object.keys(previewRows[0]);\n    return [];\n  }, [config.seed_columns, previewRows]);\n  const selectedSeedDropColumns = useMemo(\n    () => (config.seed_drop_columns ?? 
[]).filter((name) => name.trim().length > 0),\n    [config.seed_drop_columns],\n  );\n  const selectedSeedDropSet = useMemo(\n    () => new Set(selectedSeedDropColumns),\n    [selectedSeedDropColumns],\n  );\n  const rowHasExpandableText = useCallback(\n    (row: Record<string, unknown>): boolean =>\n      previewColumns.some((columnName) => {\n        if (resolveImagePreview(row[columnName])) {\n          return false;\n        }\n        return isExpandablePreviewValue(stringifyCell(row[columnName]));\n      }),\n    [previewColumns],\n  );\n\n  return (\n    <Tabs defaultValue=\"config\" className=\"w-full min-w-0\">\n      <TabsList className=\"w-full\">\n        <TabsTrigger value=\"config\">Config</TabsTrigger>\n        <TabsTrigger value=\"preview\">Preview</TabsTrigger>\n      </TabsList>\n\n      <TabsContent value=\"config\" className=\"min-w-0 pt-3\">\n        <div className=\"space-y-4\">\n          {mode === \"hf\" && (\n            <>\n              <div className=\"grid gap-2\">\n                <FieldLabel\n                  label=\"Dataset\"\n                  htmlFor={datasetId}\n                  hint=\"Hugging Face dataset repo id (org/repo).\"\n                />\n                <div className=\"flex items-center gap-2\">\n                  <HfDatasetCombobox\n                    inputId={datasetId}\n                    className=\"flex-1\"\n                    value={config.hf_repo_id}\n                    accessToken={config.hf_token?.trim() || undefined}\n                    placeholder=\"org/repo\"\n                    onValueChange={(nextValue) =>\n                      onUpdate({\n                        hf_repo_id: nextValue,\n                        hf_subset: \"\",\n                        hf_split: \"\",\n                        hf_path: \"\",\n                        seed_columns: [],\n                        seed_drop_columns: [],\n                        seed_preview_rows: [],\n                      })\n                    }\n                  />\n                  <Button\n                    type=\"button\"\n                    variant=\"outline\"\n                    className=\"nodrag shrink-0\"\n                    onClick={() => void loadSeedMetadata()}\n                    disabled={isInspecting || !config.hf_repo_id.trim()}\n                  >\n                    {isInspecting ? \"Loading...\" : \"Load\"}\n                  </Button>\n                </div>\n              </div>\n\n              <div className=\"grid gap-2\">\n                <FieldLabel\n                  label=\"HF token (optional)\"\n                  htmlFor={tokenId}\n                  hint=\"Only needed for private/gated datasets.\"\n                />\n                <Input\n                  id={tokenId}\n                  className=\"nodrag\"\n                  placeholder=\"hf_...\"\n                  value={config.hf_token ?? 
\"\"}\n                  onChange={(event) => onUpdate({ hf_token: event.target.value })}\n                />\n              </div>\n\n            </>\n          )}\n\n          {mode === \"local\" && (\n            <div className=\"grid gap-2\">\n              <FieldLabel\n                label=\"Structured file\"\n                hint=\"Upload CSV, JSON, or JSONL seed file.\"\n              />\n              <div className=\"flex items-center gap-2\">\n                <Input\n                  className=\"nodrag flex-1\"\n                  type=\"file\"\n                  accept={LOCAL_ACCEPT}\n                  onChange={(event) => {\n                    const file = event.target.files?.[0] ?? null;\n                    setLocalFile(file);\n                    onUpdate({\n                      hf_path: \"\",\n                      seed_columns: [],\n                      seed_drop_columns: [],\n                      seed_preview_rows: [],\n                      local_file_name: file?.name ?? \"\",\n                    });\n                  }}\n                />\n                <Button\n                  type=\"button\"\n                  variant=\"outline\"\n                  className=\"nodrag shrink-0\"\n                  onClick={() => void loadSeedMetadata()}\n                  disabled={isInspecting || !localFile}\n                >\n                  {isInspecting ? \"Loading...\" : \"Load\"}\n                </Button>\n              </div>\n              <p className=\"text-xs text-muted-foreground\">\n                Upload-only. Max 50MB.\n              </p>\n              {(localFile?.name || config.local_file_name?.trim()) && (\n                <p className=\"text-xs text-muted-foreground\">\n                  Selected: {localFile?.name ?? config.local_file_name?.trim()}\n                </p>\n              )}\n            </div>\n          )}\n\n          {mode === \"unstructured\" && (\n            <div className=\"grid gap-2\">\n              <FieldLabel\n                label=\"Unstructured file\"\n                hint=\"Upload PDF, DOCX, or TXT. We chunk text into seed rows.\"\n              />\n              <div className=\"flex items-center gap-2\">\n                <Input\n                  className=\"nodrag flex-1\"\n                  type=\"file\"\n                  accept={UNSTRUCTURED_ACCEPT}\n                  onChange={(event) => {\n                    const file = event.target.files?.[0] ?? null;\n                    setUnstructuredFile(file);\n                    onUpdate({\n                      hf_path: \"\",\n                      seed_columns: [],\n                      seed_drop_columns: [],\n                      seed_preview_rows: [],\n                      unstructured_file_name: file?.name ?? \"\",\n                    });\n                  }}\n                />\n                <Button\n                  type=\"button\"\n                  variant=\"outline\"\n                  className=\"nodrag shrink-0\"\n                  onClick={() => void loadSeedMetadata()}\n                  disabled={isInspecting || !unstructuredFile}\n                >\n                  {isInspecting ? \"Loading...\" : \"Load\"}\n                </Button>\n              </div>\n              <p className=\"text-xs text-muted-foreground\">\n                File is converted to text, then chunked server-side into chunk_text rows. 
Max 50MB.\n              </p>\n              {(unstructuredFile?.name ||\n                config.unstructured_file_name?.trim()) && (\n                <p className=\"text-xs text-muted-foreground\">\n                  Selected:{\" \"}\n                  {unstructuredFile?.name ?? config.unstructured_file_name?.trim()}\n                </p>\n              )}\n            </div>\n          )}\n\n          {inspectError && <p className=\"text-xs text-red-600\">{inspectError}</p>}\n\n          {mode !== \"unstructured\" && (\n            <div className=\"space-y-2 rounded-xl corner-squircle border border-border/60 p-3\">\n              <FieldLabel\n                label=\"Drop specific seed columns\"\n                hint=\"Dropped columns stay usable in prompts/expressions but are omitted from final dataset.\"\n              />\n              {previewColumns.length === 0 ? (\n                <p className=\"text-xs text-muted-foreground\">\n                  Load columns to select which seed fields to drop.\n                </p>\n              ) : (\n                <div className=\"grid gap-2 sm:grid-cols-2\">\n                  {previewColumns.map((columnName) => {\n                    const checked = selectedSeedDropSet.has(columnName);\n                    return (\n                      <label\n                        key={columnName}\n                        className=\"flex cursor-pointer items-center gap-2 rounded-md border border-border/60 px-2 py-1.5 text-xs\"\n                      >\n                        <Checkbox\n                          checked={checked}\n                          onCheckedChange={(value) => {\n                            const isChecked = value === true;\n                            const next = isChecked\n                              ? Array.from(new Set([...selectedSeedDropColumns, columnName]))\n                              : selectedSeedDropColumns.filter((name) => name !== columnName);\n                            onUpdate({ seed_drop_columns: next });\n                          }}\n                        />\n                        <span className=\"truncate\">{columnName}</span>\n                      </label>\n                    );\n                  })}\n                </div>\n              )}\n            </div>\n          )}\n\n          <Collapsible\n            open={advancedOpen}\n            onOpenChange={(openState) => onUpdate({ advancedOpen: openState })}\n          >\n            <CollapsibleTrigger asChild={true}>\n              <CollapsibleSectionTriggerButton\n                label=\"Advanced source options\"\n                open={advancedOpen}\n              />\n            </CollapsibleTrigger>\n            <CollapsibleContent className=\"mt-2 space-y-3\">\n              <div className=\"grid gap-2\">\n                <FieldLabel\n                  label=\"Sampling strategy\"\n                  htmlFor={samplingId}\n                  hint=\"Ordered keeps row order. 
Shuffle randomizes sampled rows.\"\n                />\n                <Select\n                  value={config.sampling_strategy}\n                  onValueChange={(value) =>\n                    onUpdate({ sampling_strategy: value as SeedSamplingStrategy })\n                  }\n                >\n                  <SelectTrigger className=\"nodrag w-full\" id={samplingId}>\n                    <SelectValue placeholder=\"Select sampling\" />\n                  </SelectTrigger>\n                  <SelectContent>\n                    {SAMPLING_OPTIONS.map((option) => (\n                      <SelectItem key={option.value} value={option.value}>\n                        {option.label}\n                      </SelectItem>\n                    ))}\n                  </SelectContent>\n                </Select>\n              </div>\n\n              <div className=\"grid gap-2\">\n                <FieldLabel\n                  label=\"Selection strategy\"\n                  htmlFor={selectionId}\n                  hint=\"Select all, a row range, or partition block.\"\n                />\n                <Select\n                  value={config.selection_type}\n                  onValueChange={(value) =>\n                    onUpdate({ selection_type: value as SeedSelectionType })\n                  }\n                >\n                  <SelectTrigger className=\"nodrag w-full\" id={selectionId}>\n                    <SelectValue placeholder=\"Select selection\" />\n                  </SelectTrigger>\n                  <SelectContent>\n                    {SELECTION_OPTIONS.map((option) => (\n                      <SelectItem key={option.value} value={option.value}>\n                        {option.label}\n                      </SelectItem>\n                    ))}\n                  </SelectContent>\n                </Select>\n              </div>\n\n              {mode === \"unstructured\" && (\n                <div className=\"grid grid-cols-2 gap-3\">\n                  <div className=\"grid gap-2\">\n                    <FieldLabel\n                      label=\"Chunk size\"\n                      htmlFor={chunkSizeId}\n                      hint=\"Characters per chunk.\"\n                    />\n                    <Input\n                      id={chunkSizeId}\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={config.unstructured_chunk_size ?? 
String(DEFAULT_CHUNK_SIZE)}\n                      onChange={(event) =>\n                        onUpdate({ unstructured_chunk_size: event.target.value })\n                      }\n                    />\n                  </div>\n                  <div className=\"grid gap-2\">\n                    <FieldLabel\n                      label=\"Chunk overlap\"\n                      htmlFor={chunkOverlapId}\n                      hint=\"Shared chars between adjacent chunks.\"\n                    />\n                    <Input\n                      id={chunkOverlapId}\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={\n                        config.unstructured_chunk_overlap ??\n                        String(DEFAULT_CHUNK_OVERLAP)\n                      }\n                      onChange={(event) =>\n                        onUpdate({ unstructured_chunk_overlap: event.target.value })\n                      }\n                    />\n                  </div>\n                </div>\n              )}\n\n              {config.selection_type === \"index_range\" && (\n                <div className=\"grid grid-cols-2 gap-3\">\n                  <div className=\"grid gap-2\">\n                    <FieldLabel label=\"Start\" hint=\"Inclusive start row index for index_range.\" />\n                    <Input\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={config.selection_start ?? \"\"}\n                      onChange={(event) => onUpdate({ selection_start: event.target.value })}\n                    />\n                  </div>\n                  <div className=\"grid gap-2\">\n                    <FieldLabel label=\"End\" hint=\"Inclusive end row index for index_range.\" />\n                    <Input\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={config.selection_end ?? \"\"}\n                      onChange={(event) => onUpdate({ selection_end: event.target.value })}\n                    />\n                  </div>\n                </div>\n              )}\n\n              {config.selection_type === \"partition_block\" && (\n                <div className=\"grid grid-cols-2 gap-3\">\n                  <div className=\"grid gap-2\">\n                    <FieldLabel label=\"Index\" hint=\"Partition index to load.\" />\n                    <Input\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={config.selection_index ?? \"\"}\n                      onChange={(event) => onUpdate({ selection_index: event.target.value })}\n                    />\n                  </div>\n                  <div className=\"grid gap-2\">\n                    <FieldLabel label=\"Partitions\" hint=\"Total number of partitions.\" />\n                    <Input\n                      className=\"nodrag\"\n                      inputMode=\"numeric\"\n                      value={config.selection_num_partitions ?? 
\"\"}\n                      onChange={(event) =>\n                        onUpdate({ selection_num_partitions: event.target.value })\n                      }\n                    />\n                  </div>\n                </div>\n              )}\n            </CollapsibleContent>\n          </Collapsible>\n        </div>\n      </TabsContent>\n\n      <TabsContent value=\"preview\" className=\"min-w-0 pt-3\">\n        <div className=\"space-y-4\">\n          {previewRows.length === 0 ? (\n            <div className=\"flex w-full items-center justify-center\">\n              <Empty className=\"max-w-lg\">\n                <EmptyHeader>\n                  <EmptyTitle>{previewEmpty.title}</EmptyTitle>\n                  <EmptyDescription>\n                    {previewEmpty.description}\n                  </EmptyDescription>\n                </EmptyHeader>\n                <EmptyContent className=\"text-xs text-muted-foreground\">\n                  Preview appears here after loading source metadata.\n                </EmptyContent>\n              </Empty>\n            </div>\n          ) : (\n            <div className=\"space-y-2\">\n              <div className=\"text-xs text-muted-foreground\">\n                Loaded columns: {previewColumns.join(\", \") || \"None\"}\n              </div>\n              <div className=\"max-h-[360px] overflow-y-auto overflow-x-hidden rounded-xl corner-squircle border border-border/60\">\n                <Table className=\"corner-squircle min-w-max\">\n                  <TableHeader>\n                    <TableRow>\n                      {previewColumns.map((col) => (\n                        <TableHead key={col} className=\"whitespace-nowrap\">\n                          {col}\n                        </TableHead>\n                      ))}\n                    </TableRow>\n                  </TableHeader>\n                  <TableBody>\n                    {previewRows.map((row, rowIdx) => (\n                      <TableRow\n                        key={`row-${rowIdx}`}\n                        className={cn(\n                          rowHasExpandableText(row) && \"cursor-pointer hover:bg-primary/[0.06]\",\n                          expandedPreviewRows[rowIdx] && \"bg-primary/[0.05]\",\n                        )}\n                        onClick={() => {\n                          const canExpand = rowHasExpandableText(row);\n                          if (!canExpand) {\n                            return;\n                          }\n                          setExpandedPreviewRows((current) => ({\n                            ...current,\n                            [rowIdx]: !current[rowIdx],\n                          }));\n                        }}\n                      >\n                        {previewColumns.map((col) => (\n                          <TableCell\n                            key={`${rowIdx}-${col}`}\n                            className=\"max-w-[260px] whitespace-pre-wrap break-words text-xs\"\n                          >\n                            {(() => {\n                              const imagePreview = resolveImagePreview(row[col]);\n                              if (imagePreview?.kind === \"ready\") {\n                                return (\n                                  <img\n                                    src={imagePreview.src}\n                                    alt={`${col} preview`}\n                                    loading=\"lazy\"\n                                    className=\"h-20 w-auto 
max-w-[220px] rounded-md border border-border/60 bg-muted/20 object-contain\"\n                                  />\n                                );\n                              }\n                              if (imagePreview?.kind === \"too_large\") {\n                                return \"Image too large to preview\";\n                              }\n                              const value = stringifyCell(row[col]);\n                              const rowHasExpandableCell = rowHasExpandableText(row);\n                              const rowExpanded = Boolean(expandedPreviewRows[rowIdx]);\n                              return rowHasExpandableCell && !rowExpanded\n                                ? truncatePreviewValue(value)\n                                : value;\n                            })()}\n                          </TableCell>\n                        ))}\n                      </TableRow>\n                    ))}\n                  </TableBody>\n                </Table>\n              </div>\n            </div>\n          )}\n        </div>\n      </TabsContent>\n    </Tabs>\n  );\n}\n"
  },
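  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/seed/seed-dialog-chunking.example.ts",
    "content": "// Illustrative sketch (hypothetical file, not part of the repo): restates the\n// chunk-size/overlap clamping that SeedDialog's parseChunkNumber and\n// resolveChunking apply before an unstructured upload is inspected. The\n// constants and clampChunkNumber below are assumptions that mirror those\n// private helpers for illustration.\n\nconst DEFAULT_CHUNK_SIZE = 1200;\nconst DEFAULT_CHUNK_OVERLAP = 200;\nconst MAX_CHUNK_SIZE = 20000;\n\n// Blank or non-numeric input falls back to the default; otherwise the value\n// is floored to an integer and clamped to [min, max].\nfunction clampChunkNumber(\n  value: string | undefined,\n  fallback: number,\n  min: number,\n  max: number,\n): number {\n  const raw = value?.trim();\n  if (!raw) return fallback;\n  const parsed = Number(raw);\n  if (!Number.isFinite(parsed)) return fallback;\n  return Math.min(max, Math.max(min, Math.floor(parsed)));\n}\n\n// Overlap can never reach the chunk size: its upper bound is chunkSize - 1.\nexport function resolveChunkingExample(\n  sizeInput?: string,\n  overlapInput?: string,\n): { chunkSize: number; chunkOverlap: number } {\n  const chunkSize = clampChunkNumber(sizeInput, DEFAULT_CHUNK_SIZE, 1, MAX_CHUNK_SIZE);\n  const chunkOverlap = clampChunkNumber(\n    overlapInput,\n    DEFAULT_CHUNK_OVERLAP,\n    0,\n    Math.max(0, chunkSize - 1),\n  );\n  return { chunkSize, chunkOverlap };\n}\n\n// resolveChunkingExample() => { chunkSize: 1200, chunkOverlap: 200 }\n// resolveChunkingExample(\"500\", \"9999\") => { chunkSize: 500, chunkOverlap: 499 }\n"
  },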
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/available-variables.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { ArrowDown01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, useMemo, useState } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport { getAvailableVariableEntries } from \"../../utils/variables\";\nimport { RECIPE_STUDIO_REFERENCE_BADGE_TONES } from \"../../utils/ui-tones\";\n\ntype AvailableVariablesProps = {\n  configId: string;\n};\n\nconst USER_EXPANDED_FIELDS = [\n  \"first_name\",\n  \"last_name\",\n  \"sex\",\n  \"city\",\n  \"state\",\n  \"age\",\n] as const;\nexport function AvailableVariables({\n  configId,\n}: AvailableVariablesProps): ReactElement | null {\n  const [showUserFields, setShowUserFields] = useState(false);\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const vars = getAvailableVariableEntries(configs, configId);\n  const variableNames = useMemo(() => new Set(vars.map((entry) => entry.name)), [vars]);\n  const hasUserRoot = variableNames.has(\"user\");\n  const userFieldEntries = useMemo(\n    () =>\n      USER_EXPANDED_FIELDS.map((field) => ({\n        source: \"column\" as const,\n        name: `user.${field}`,\n      })).filter((entry) => !variableNames.has(entry.name)),\n    [variableNames],\n  );\n\n  if (vars.length === 0) return null;\n\n  return (\n    <div className=\"corner-squircle rounded-2xl border border-border/60 px-3 py-2\">\n      <p className=\"mb-2 text-xs font-semibold uppercase text-muted-foreground\">\n        Available references\n      </p>\n      <div className=\"flex flex-wrap gap-1.5\">\n        {vars.map((v) => {\n          const className =\n            v.name === \"user\" || v.name.startsWith(\"user.\")\n              ? RECIPE_STUDIO_REFERENCE_BADGE_TONES.user\n              : v.source === \"seed\"\n                ? RECIPE_STUDIO_REFERENCE_BADGE_TONES.seed\n                : RECIPE_STUDIO_REFERENCE_BADGE_TONES.default;\n          if (v.name !== \"user\") {\n            return (\n              <Badge\n                key={`${v.source}:${v.name}`}\n                variant=\"secondary\"\n                className={className}\n              >\n                {`{{ ${v.name} }}`}\n              </Badge>\n            );\n          }\n          return (\n            <button\n              key={`${v.source}:${v.name}`}\n              type=\"button\"\n              onClick={() => setShowUserFields((prev) => !prev)}\n              className=\"cursor-pointer\"\n              aria-expanded={showUserFields}\n              aria-label={showUserFields ? \"Hide user fields\" : \"Show user fields\"}\n            >\n              <Badge variant=\"secondary\" className={className}>\n                <span>{`{{ ${v.name} }}`}</span>\n                <HugeiconsIcon\n                  icon={ArrowDown01Icon}\n                  className={`size-3 transition-transform ${showUserFields ? 
\"rotate-180\" : \"\"}`}\n                />\n              </Badge>\n            </button>\n          );\n        })}\n        {hasUserRoot && showUserFields &&\n          userFieldEntries.map((entry) => (\n            <Badge\n              key={`user-expanded:${entry.name}`}\n              variant=\"secondary\"\n              className={RECIPE_STUDIO_REFERENCE_BADGE_TONES.user}\n            >\n              {`{{ ${entry.name} }}`}\n            </Badge>\n          ))}\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/collapsible-section-trigger.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport { ArrowDown01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  forwardRef,\n  type ButtonHTMLAttributes,\n  type ReactElement,\n} from \"react\";\n\ntype CollapsibleSectionTriggerProps = {\n  label: string;\n  open: boolean;\n  summary?: string;\n} & ButtonHTMLAttributes<HTMLButtonElement>;\n\nexport const CollapsibleSectionTriggerButton = forwardRef<\n  HTMLButtonElement,\n  CollapsibleSectionTriggerProps\n>(function CollapsibleSectionTriggerButton(\n  {\n    label,\n    open,\n    summary,\n    className,\n    type = \"button\",\n    ...props\n  }: CollapsibleSectionTriggerProps,\n  ref,\n): ReactElement {\n  return (\n    <button\n      ref={ref}\n      type={type}\n      className={cn(\n        \"flex w-full items-center justify-between gap-3 text-left text-xs text-muted-foreground transition hover:text-foreground\",\n        className,\n      )}\n      {...props}\n    >\n      <span className=\"flex min-w-0 items-center gap-2\">\n        <HugeiconsIcon\n          icon={ArrowDown01Icon}\n          className={cn(\n            \"size-3.5 shrink-0 transition-transform\",\n            open && \"rotate-180\",\n          )}\n        />\n        <span className=\"font-semibold uppercase\">{label}</span>\n      </span>\n      <span className=\"shrink-0\">{summary ?? (open ? \"Hide\" : \"Show\")}</span>\n    </button>\n  );\n});\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/dialog-shell.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  DialogDescription,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport type { ReactElement } from \"react\";\n\ntype DialogShellProps = {\n  title?: string;\n  description?: string;\n};\n\nexport function DialogShell({\n  title = \"Edit step\",\n  description = \"Update this step before you run the recipe.\",\n}: DialogShellProps): ReactElement {\n  return (\n    <DialogHeader>\n      <DialogTitle>{title}</DialogTitle>\n      <DialogDescription>{description}</DialogDescription>\n    </DialogHeader>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/field-label.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Tooltip, TooltipContent, TooltipTrigger } from \"@/components/ui/tooltip\";\nimport { InformationCircleIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\n\ntype FieldLabelProps = {\n  label: string;\n  htmlFor?: string;\n  hint?: string;\n};\n\nexport function FieldLabel({\n  label,\n  htmlFor,\n  hint,\n}: FieldLabelProps): ReactElement {\n  return (\n    <div className=\"flex min-w-0 items-start gap-1.5 text-xs font-semibold uppercase text-muted-foreground\">\n      {htmlFor ? (\n        <label className=\"min-w-0 cursor-pointer\" htmlFor={htmlFor}>\n          <span className=\"break-words\">{label}</span>\n        </label>\n      ) : (\n        <span className=\"min-w-0 break-words\">{label}</span>\n      )}\n      {hint && (\n        <Tooltip>\n          <TooltipTrigger asChild={true}>\n            <button\n              type=\"button\"\n              className=\"inline-flex size-6 shrink-0 items-center justify-center rounded-full text-muted-foreground/80 transition hover:text-foreground\"\n              aria-label={`More info: ${label}`}\n              title={`More info about ${label}`}\n            >\n              <HugeiconsIcon icon={InformationCircleIcon} className=\"size-4\" />\n            </button>\n          </TooltipTrigger>\n          <TooltipContent className=\"max-w-64 break-words text-xs leading-relaxed\">\n            {hint}\n          </TooltipContent>\n        </Tooltip>\n      )}\n    </div>\n  );\n}\n"
  },
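  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/field-label.example.tsx",
    "content": "// Illustrative sketch (hypothetical file, not part of the repo): a minimal\n// field row showing how FieldLabel is paired with an Input in the dialogs\n// above. The htmlFor/id wiring makes the label clickable and the hint string\n// becomes the tooltip content; the label text here is made up.\n\nimport { Input } from \"@/components/ui/input\";\nimport { type ReactElement, useId } from \"react\";\nimport { FieldLabel } from \"./field-label\";\n\nexport function ExampleFieldRow(): ReactElement {\n  const inputId = useId();\n  return (\n    <div className=\"grid gap-2\">\n      <FieldLabel\n        label=\"Temperature\"\n        htmlFor={inputId}\n        hint=\"Optional sampling temperature; leave blank for the default.\"\n      />\n      <Input id={inputId} className=\"nodrag\" placeholder=\"0.7\" />\n    </div>\n  );\n}\n"
  },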
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/name-field.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Input } from \"@/components/ui/input\";\nimport { type ReactElement, useId } from \"react\";\nimport { FieldLabel } from \"./field-label\";\n\ntype NameFieldProps = {\n  id?: string;\n  value: string;\n  onChange: (value: string) => void;\n  label?: string;\n  hint?: string;\n};\n\nexport function NameField({\n  id,\n  value,\n  onChange,\n  label,\n  hint,\n}: NameFieldProps): ReactElement {\n  const fallbackId = useId();\n  const inputId = id ?? fallbackId;\n  return (\n    <div className=\"grid gap-2\">\n      <FieldLabel\n        label={label ?? \"Field name\"}\n        htmlFor={inputId}\n        hint={\n          hint ??\n          \"This name is used in prompts and in the final dataset.\"\n        }\n      />\n      <Input\n        id={inputId}\n        className=\"nodrag\"\n        value={value}\n        onChange={(event) => onChange(event.target.value)}\n      />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/shared/validation-banner.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactElement } from \"react\";\nimport type { NodeConfig } from \"../../types\";\nimport { getConfigErrors } from \"../../utils\";\n\nexport function ValidationBanner({\n  config,\n}: {\n  config: NodeConfig | null;\n}): ReactElement | null {\n  const errors = getConfigErrors(config);\n  if (errors.length === 0) {\n    return null;\n  }\n  return (\n    <p className=\"text-xs text-amber-600\">\n      <span className=\"font-semibold\">Needs attention: </span>\n      {errors.join(\". \")}.\n    </p>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/tool-profile/helpers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { LlmMcpProviderConfig } from \"../../types\";\n\nexport function createMcpProviderId(prefix: string, index: number): string {\n  return `${prefix}-mcp-${Date.now()}-${index + 1}`;\n}\n\nexport function addUnique(items: string[], value: string): string[] {\n  const trimmed = value.trim();\n  if (!trimmed || items.includes(trimmed)) {\n    return items;\n  }\n  return [...items, trimmed];\n}\n\nexport function collectToolSuggestions(\n  providerNames: string[],\n  toolsByProvider: Record<string, string[]>,\n): string[] {\n  return Array.from(\n    new Set(\n      providerNames.flatMap(\n        (providerName) => toolsByProvider[providerName.trim()] ?? [],\n      ),\n    ),\n  );\n}\n\nexport function isProviderReadyForToolFetch(\n  provider: LlmMcpProviderConfig,\n): boolean {\n  const hasName = provider.name.trim().length > 0;\n  if (!hasName) {\n    return false;\n  }\n  if (provider.provider_type === \"stdio\") {\n    return (provider.command?.trim().length ?? 0) > 0;\n  }\n  return (provider.endpoint?.trim().length ?? 0) > 0;\n}\n\nexport function toApiProvider(\n  provider: LlmMcpProviderConfig,\n): Record<string, unknown> {\n  if (provider.provider_type === \"stdio\") {\n    const env = Object.fromEntries(\n      (provider.env ?? [])\n        .map((item) => [item.key.trim(), item.value.trim()] as const)\n        .filter(([key, value]) => key && value),\n    );\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      provider_type: \"stdio\",\n      name: provider.name.trim(),\n      command: provider.command?.trim() ?? \"\",\n      args: (provider.args ?? []).map((value) => value.trim()).filter(Boolean),\n      env,\n    };\n  }\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    provider_type: \"streamable_http\",\n    name: provider.name.trim(),\n    endpoint: provider.endpoint?.trim() ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key: provider.api_key?.trim() || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key_env: provider.api_key_env?.trim() || undefined,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/tool-profile/tool-profile-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { Input } from \"@/components/ui/input\";\nimport { Tabs, TabsContent, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport { toastError } from \"@/shared/toast\";\nimport {\n  ArrowRight01Icon,\n  Delete02Icon,\n  PlusSignIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, useEffect, useMemo, useRef, useState } from \"react\";\nimport { listMcpTools } from \"../../api\";\nimport { ChipInput } from \"../../components/chip-input\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport type { LlmMcpProviderConfig, McpEnvVar, ToolProfileConfig } from \"../../types\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\nimport {\n  addUnique,\n  collectToolSuggestions,\n  createMcpProviderId,\n  isProviderReadyForToolFetch,\n  toApiProvider,\n} from \"./helpers\";\n\ntype ToolProfileDialogProps = {\n  config: ToolProfileConfig;\n  onUpdate: (patch: Partial<ToolProfileConfig>) => void;\n};\n\nfunction EmptyState({\n  title,\n  description,\n}: {\n  title: string;\n  description: string;\n}): ReactElement {\n  return (\n    <div className=\"rounded-2xl border border-dashed border-border/70 bg-muted/15 px-4 py-5 text-sm\">\n      <p className=\"font-semibold text-foreground\">{title}</p>\n      <p className=\"mt-1 text-xs text-muted-foreground\">{description}</p>\n    </div>\n  );\n}\n\nfunction isProviderConfigured(provider: LlmMcpProviderConfig): boolean {\n  const hasName = provider.name.trim().length > 0;\n  if (!hasName) {\n    return false;\n  }\n  if (provider.provider_type === \"stdio\") {\n    return (provider.command?.trim().length ?? 0) > 0;\n  }\n  return (provider.endpoint?.trim().length ?? 0) > 0;\n}\n\nfunction McpServerCard({\n  provider,\n  index,\n  toolsCount,\n  error,\n  open,\n  onOpenChange,\n  onUpdateProviderAt,\n  onRemoveProvider,\n  onAddProviderArg,\n  onUpdateProviderArg,\n  onRemoveProviderArg,\n  onAddProviderEnv,\n  onUpdateProviderEnv,\n  onRemoveProviderEnv,\n}: {\n  provider: LlmMcpProviderConfig;\n  index: number;\n  toolsCount?: number;\n  error?: string;\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  onUpdateProviderAt: (\n    index: number,\n    patch: Partial<LlmMcpProviderConfig>,\n  ) => void;\n  onRemoveProvider: (index: number) => void;\n  onAddProviderArg: (index: number) => void;\n  onUpdateProviderArg: (index: number, argIndex: number, value: string) => void;\n  onRemoveProviderArg: (index: number, argIndex: number) => void;\n  onAddProviderEnv: (index: number) => void;\n  onUpdateProviderEnv: (\n    index: number,\n    envIndex: number,\n    patch: Partial<McpEnvVar>,\n  ) => void;\n  onRemoveProviderEnv: (index: number, envIndex: number) => void;\n}): ReactElement {\n  const args = provider.args && provider.args.length > 0 ? provider.args : [\"\"];\n  const envVars =\n    provider.env && provider.env.length > 0\n      ? 
provider.env\n      : [{ key: \"\", value: \"\" }];\n  const summaryTitle = provider.name.trim() || `Tool server ${index + 1}`;\n  const transportLabel =\n    provider.provider_type === \"stdio\" ? \"Local command\" : \"HTTP\";\n  const toolsLabel = typeof toolsCount === \"number\" ? `${toolsCount} tools` : null;\n  const description =\n    provider.provider_type === \"stdio\"\n      ? \"Runs a local tool server.\"\n      : \"Calls a remote tool server.\";\n\n  return (\n    <Collapsible open={open} onOpenChange={onOpenChange}>\n      <div className=\"rounded-2xl border border-border/60 bg-background/80\">\n        <div className=\"flex items-start gap-2 px-4 py-4\">\n          <CollapsibleTrigger asChild={true}>\n            <button\n              type=\"button\"\n              className=\"flex min-w-0 flex-1 items-start gap-3 text-left\"\n            >\n              <HugeiconsIcon\n                icon={ArrowRight01Icon}\n                className={`mt-0.5 size-4 shrink-0 text-muted-foreground transition-transform ${\n                  open ? \"rotate-90\" : \"\"\n                }`}\n              />\n              <div className=\"min-w-0\">\n                <div className=\"flex flex-wrap items-center gap-2\">\n                  <p className=\"truncate text-sm font-semibold text-foreground\">\n                    {summaryTitle}\n                  </p>\n                  <Badge variant=\"outline\" className=\"rounded-full text-[10px] uppercase\">\n                    {transportLabel}\n                  </Badge>\n                  {toolsLabel ? (\n                    <Badge variant=\"secondary\" className=\"rounded-full text-[10px]\">\n                      {toolsLabel}\n                    </Badge>\n                  ) : null}\n                </div>\n                <p className=\"mt-1 text-xs text-muted-foreground\">{description}</p>\n              </div>\n            </button>\n          </CollapsibleTrigger>\n          <Button\n            type=\"button\"\n            size=\"icon-sm\"\n            variant=\"ghost\"\n            onClick={() => onRemoveProvider(index)}\n          >\n            <HugeiconsIcon icon={Delete02Icon} className=\"size-4\" />\n          </Button>\n        </div>\n\n        <CollapsibleContent className=\"space-y-4 border-t border-border/50 px-4 pt-4 pb-4\">\n          {error && (\n            <div className=\"rounded-xl border border-destructive/30 bg-destructive/5 px-3 py-2 text-xs text-destructive\">\n              {error}\n            </div>\n          )}\n\n          <div className=\"grid gap-2\">\n            <FieldLabel label=\"Server name\" hint=\"Name shown in this tool access setup.\" />\n            <Input\n              className=\"nodrag\"\n              value={provider.name}\n              placeholder=\"context7\"\n              onChange={(event) =>\n                onUpdateProviderAt(index, { name: event.target.value })\n              }\n            />\n          </div>\n\n          <Tabs\n            value={provider.provider_type}\n            onValueChange={(value) =>\n              onUpdateProviderAt(index, {\n                // biome-ignore lint/style/useNamingConvention: ui schema\n                provider_type: value === \"stdio\" ? 
\"stdio\" : \"streamable_http\",\n              })\n            }\n          >\n              <TabsList className=\"w-full\">\n                <TabsTrigger value=\"stdio\">Local command</TabsTrigger>\n                <TabsTrigger value=\"streamable_http\">HTTP endpoint</TabsTrigger>\n              </TabsList>\n          </Tabs>\n\n          {provider.provider_type === \"stdio\" ? (\n            <div className=\"space-y-4\">\n              <div className=\"grid gap-2\">\n                <FieldLabel label=\"Command\" hint=\"Command used to start the tool server.\" />\n                <Input\n                  className=\"nodrag\"\n                  value={provider.command ?? \"\"}\n                  placeholder=\"npx\"\n                  onChange={(event) =>\n                    onUpdateProviderAt(index, { command: event.target.value })\n                  }\n                />\n              </div>\n\n              <div className=\"space-y-2\">\n                <div className=\"flex items-center justify-between gap-3\">\n                  <FieldLabel label=\"Arguments\" hint=\"Optional command arguments.\" />\n                  <Button\n                    type=\"button\"\n                    size=\"xs\"\n                    variant=\"outline\"\n                    onClick={() => onAddProviderArg(index)}\n                  >\n                    <HugeiconsIcon icon={PlusSignIcon} className=\"size-3.5\" />\n                    Add arg\n                  </Button>\n                </div>\n                {args.map((arg, argIndex) => (\n                  <div key={`${provider.id}-arg-${argIndex}`} className=\"flex gap-2\">\n                    <Input\n                      className=\"nodrag\"\n                      value={arg}\n                      placeholder={argIndex === 0 ? 
\"-y\" : \"argument\"}\n                      onChange={(event) =>\n                        onUpdateProviderArg(index, argIndex, event.target.value)\n                      }\n                    />\n                    <Button\n                      type=\"button\"\n                      size=\"icon-sm\"\n                      variant=\"ghost\"\n                      onClick={() => onRemoveProviderArg(index, argIndex)}\n                    >\n                      <HugeiconsIcon icon={Delete02Icon} className=\"size-4\" />\n                    </Button>\n                  </div>\n                ))}\n              </div>\n\n              <div className=\"space-y-2\">\n                <div className=\"flex items-center justify-between gap-3\">\n                  <FieldLabel label=\"Environment variables\" hint=\"Optional values passed to the tool server.\" />\n                  <Button\n                    type=\"button\"\n                    size=\"xs\"\n                    variant=\"outline\"\n                    onClick={() => onAddProviderEnv(index)}\n                  >\n                    <HugeiconsIcon icon={PlusSignIcon} className=\"size-3.5\" />\n                    Add env\n                  </Button>\n                </div>\n                {envVars.map((item, envIndex) => (\n                  <div\n                    key={`${provider.id}-env-${envIndex}`}\n                    className=\"grid grid-cols-[1fr_1fr_auto] gap-2\"\n                  >\n                    <Input\n                      className=\"nodrag\"\n                      value={item.key}\n                      placeholder=\"KEY\"\n                      onChange={(event) =>\n                        onUpdateProviderEnv(index, envIndex, {\n                          key: event.target.value,\n                        })\n                      }\n                    />\n                    <Input\n                      className=\"nodrag\"\n                      value={item.value}\n                      placeholder=\"value\"\n                      onChange={(event) =>\n                        onUpdateProviderEnv(index, envIndex, {\n                          value: event.target.value,\n                        })\n                      }\n                    />\n                    <Button\n                      type=\"button\"\n                      size=\"icon-sm\"\n                      variant=\"ghost\"\n                      onClick={() => onRemoveProviderEnv(index, envIndex)}\n                    >\n                      <HugeiconsIcon icon={Delete02Icon} className=\"size-4\" />\n                    </Button>\n                  </div>\n                ))}\n              </div>\n            </div>\n          ) : (\n            <div className=\"space-y-4\">\n              <div className=\"grid gap-2\">\n                <FieldLabel label=\"Endpoint\" hint=\"URL for the tool server.\" />\n                <Input\n                  className=\"nodrag\"\n                  value={provider.endpoint ?? 
\"\"}\n                  placeholder=\"https://example.com/mcp\"\n                  onChange={(event) =>\n                    onUpdateProviderAt(index, { endpoint: event.target.value })\n                  }\n                />\n              </div>\n              <div className=\"grid gap-2 sm:grid-cols-2\">\n                <div className=\"grid gap-2\">\n                  <FieldLabel\n                    label=\"API key environment variable\"\n                    hint=\"Optional environment variable that stores the API key.\"\n                  />\n                  <Input\n                    className=\"nodrag\"\n                    value={provider.api_key_env ?? \"\"}\n                    placeholder=\"TOOL_SERVER_API_KEY\"\n                    onChange={(event) =>\n                      onUpdateProviderAt(index, {\n                        // biome-ignore lint/style/useNamingConvention: api schema\n                        api_key_env: event.target.value,\n                      })\n                    }\n                  />\n                </div>\n                <div className=\"grid gap-2\">\n                  <FieldLabel\n                    label=\"API key\"\n                    hint=\"Optional API key.\"\n                  />\n                  <Input\n                    className=\"nodrag\"\n                    value={provider.api_key ?? \"\"}\n                    placeholder=\"token\"\n                    onChange={(event) =>\n                      onUpdateProviderAt(index, {\n                        // biome-ignore lint/style/useNamingConvention: api schema\n                        api_key: event.target.value,\n                      })\n                    }\n                  />\n                </div>\n              </div>\n            </div>\n          )}\n        </CollapsibleContent>\n      </div>\n    </Collapsible>\n  );\n}\n\nexport function ToolProfileDialog({\n  config,\n  onUpdate,\n}: ToolProfileDialogProps): ReactElement {\n  const providers = config.mcp_providers;\n  const [activeTab, setActiveTab] = useState<\"profile\" | \"servers\">(\n    providers.length > 0 ? \"profile\" : \"servers\",\n  );\n  const [advancedOpen, setAdvancedOpen] = useState(false);\n  const [loadingTools, setLoadingTools] = useState(false);\n  const [toolsByProvider, setToolsByProvider] = useState<Record<string, string[]>>(\n    config.fetched_tools_by_provider ?? 
{},\n  );\n  const [providerErrors, setProviderErrors] = useState<Record<string, string>>({});\n  const [duplicateTools, setDuplicateTools] = useState<Record<string, string[]>>({});\n  const [openProviders, setOpenProviders] = useState<Record<string, boolean>>({});\n  const previousProviderSignatureRef = useRef<string | null>(null);\n\n  const providerSignature = useMemo(\n    () =>\n      JSON.stringify(\n        providers.map((provider) => ({\n          name: provider.name,\n          // biome-ignore lint/style/useNamingConvention: ui schema\n          provider_type: provider.provider_type,\n          command: provider.command,\n          args: provider.args,\n          env: provider.env,\n          endpoint: provider.endpoint,\n          // biome-ignore lint/style/useNamingConvention: api schema\n          api_key: provider.api_key,\n          // biome-ignore lint/style/useNamingConvention: api schema\n          api_key_env: provider.api_key_env,\n        })),\n      ),\n    [providers],\n  );\n\n  useEffect(() => {\n    const previousSignature = previousProviderSignatureRef.current;\n    previousProviderSignatureRef.current = providerSignature;\n    if (previousSignature === null) {\n      setToolsByProvider(config.fetched_tools_by_provider ?? {});\n      return;\n    }\n    if (previousSignature === providerSignature) {\n      return;\n    }\n    setToolsByProvider({});\n    setProviderErrors({});\n    setDuplicateTools({});\n    if (Object.keys(config.fetched_tools_by_provider ?? {}).length > 0) {\n      onUpdate({\n        // biome-ignore lint/style/useNamingConvention: ui schema\n        fetched_tools_by_provider: {},\n      });\n    }\n  }, [config.fetched_tools_by_provider, onUpdate, providerSignature]);\n\n  useEffect(() => {\n    const tools = config.fetched_tools_by_provider ?? {};\n    setToolsByProvider(tools);\n  }, [config.fetched_tools_by_provider]);\n\n  useEffect(() => {\n    setOpenProviders((current) => {\n      const next: Record<string, boolean> = {};\n      for (const provider of providers) {\n        next[provider.id] =\n          current[provider.id] ?? !isProviderConfigured(provider);\n      }\n      return next;\n    });\n  }, [providers]);\n\n  function updateProviders(nextProviders: LlmMcpProviderConfig[]): void {\n    onUpdate({\n      // biome-ignore lint/style/useNamingConvention: ui schema\n      mcp_providers: nextProviders,\n    });\n  }\n\n  function updateProviderAt(\n    index: number,\n    patch: Partial<LlmMcpProviderConfig>,\n  ): void {\n    updateProviders(\n      providers.map((provider, currentIndex) =>\n        currentIndex === index ? 
{ ...provider, ...patch } : provider,\n      ),\n    );\n  }\n\n  function mutateProviderAt(\n    index: number,\n    mapProvider: (provider: LlmMcpProviderConfig) => Partial<LlmMcpProviderConfig>,\n  ): void {\n    const provider = providers[index];\n    if (!provider) {\n      return;\n    }\n    updateProviderAt(index, mapProvider(provider));\n  }\n\n  function removeProvider(index: number): void {\n    updateProviders(providers.filter((_, currentIndex) => currentIndex !== index));\n  }\n\n  function addProvider(): void {\n    updateProviders([\n      ...providers,\n      {\n        id: createMcpProviderId(config.id, providers.length),\n        name: \"\",\n        // biome-ignore lint/style/useNamingConvention: ui schema\n        provider_type: \"stdio\",\n        command: \"\",\n        args: [],\n        env: [],\n        endpoint: \"\",\n        // biome-ignore lint/style/useNamingConvention: api schema\n        api_key: \"\",\n        // biome-ignore lint/style/useNamingConvention: api schema\n        api_key_env: \"\",\n      },\n    ]);\n  }\n\n  function addProviderArg(providerIndex: number): void {\n    mutateProviderAt(providerIndex, (provider) => ({\n      args: [...(provider.args ?? []), \"\"],\n    }));\n  }\n\n  function updateProviderArg(\n    providerIndex: number,\n    argIndex: number,\n    value: string,\n  ): void {\n    mutateProviderAt(providerIndex, (provider) => {\n      const nextArgs =\n        provider.args && provider.args.length > 0 ? [...provider.args] : [\"\"];\n      nextArgs[argIndex] = value;\n      return { args: nextArgs };\n    });\n  }\n\n  function removeProviderArg(providerIndex: number, argIndex: number): void {\n    mutateProviderAt(providerIndex, (provider) => ({\n      args: (provider.args ?? []).filter((_, currentIndex) => currentIndex !== argIndex),\n    }));\n  }\n\n  function addProviderEnv(providerIndex: number): void {\n    mutateProviderAt(providerIndex, (provider) => ({\n      env: [...(provider.env ?? []), { key: \"\", value: \"\" }],\n    }));\n  }\n\n  function updateProviderEnv(\n    providerIndex: number,\n    envIndex: number,\n    patch: Partial<McpEnvVar>,\n  ): void {\n    mutateProviderAt(providerIndex, (provider) => ({\n      env: (\n        provider.env && provider.env.length > 0\n          ? provider.env\n          : [{ key: \"\", value: \"\" }]\n      ).map((item, currentIndex) =>\n        currentIndex === envIndex ? { ...item, ...patch } : item,\n      ),\n    }));\n  }\n\n  function removeProviderEnv(providerIndex: number, envIndex: number): void {\n    mutateProviderAt(providerIndex, (provider) => ({\n      env: (provider.env ?? []).filter((_, currentIndex) => currentIndex !== envIndex),\n    }));\n  }\n\n  async function loadTools(): Promise<void> {\n    const readyProviders = providers.filter(isProviderReadyForToolFetch);\n    if (readyProviders.length === 0) {\n      toastError(\n        \"No tool servers are ready\",\n        \"Add a server name plus a command or endpoint first.\",\n      );\n      return;\n    }\n\n    setLoadingTools(true);\n    try {\n      const timeoutRaw = config.timeout_sec?.trim();\n      const timeoutSec =\n        timeoutRaw && Number.isFinite(Number(timeoutRaw))\n          ? 
Number(timeoutRaw)\n          : 15;\n      const response = await listMcpTools({\n        // biome-ignore lint/style/useNamingConvention: api schema\n        mcp_providers: readyProviders.map(toApiProvider),\n        // biome-ignore lint/style/useNamingConvention: api schema\n        timeout_sec: timeoutSec,\n      });\n      const nextToolsByProvider = Object.fromEntries(\n        response.providers\n          .filter((provider) => provider.name.trim())\n          .map((provider) => [provider.name.trim(), provider.tools]),\n      );\n      setToolsByProvider(nextToolsByProvider);\n      onUpdate({\n        // biome-ignore lint/style/useNamingConvention: ui schema\n        fetched_tools_by_provider: nextToolsByProvider,\n      });\n      setProviderErrors(\n        Object.fromEntries(\n          response.providers\n            .filter((provider) => provider.name.trim() && provider.error)\n            .map((provider) => [provider.name.trim(), provider.error ?? \"Failed to load tools.\"]),\n        ),\n      );\n      setDuplicateTools(response.duplicate_tools ?? {});\n    } catch (error) {\n      toastError(\n        \"Couldn't load tools\",\n        error instanceof Error ? error.message : \"We couldn't load the tools for these servers.\",\n      );\n    } finally {\n      setLoadingTools(false);\n    }\n  }\n\n  const providerNames = useMemo(\n    () =>\n      Array.from(\n        new Set(providers.map((provider) => provider.name.trim()).filter(Boolean)),\n      ),\n    [providers],\n  );\n  const availableTools = useMemo(\n    () => collectToolSuggestions(providerNames, toolsByProvider),\n    [providerNames, toolsByProvider],\n  );\n  const hasProviders = providers.length > 0;\n\n  useEffect(() => {\n    if (!hasProviders && activeTab === \"profile\") {\n      setActiveTab(\"servers\");\n    }\n  }, [activeTab, hasProviders]);\n\n  return (\n    <Tabs\n      value={activeTab}\n      onValueChange={(value) =>\n        setActiveTab(value === \"servers\" ? \"servers\" : \"profile\")\n      }\n      className=\"w-full\"\n    >\n      <TabsList className=\"w-full\">\n        <TabsTrigger value=\"servers\">1. Add servers</TabsTrigger>\n        <TabsTrigger value=\"profile\">2. Choose tools</TabsTrigger>\n      </TabsList>\n\n      <TabsContent value=\"profile\" className=\"space-y-4 pt-3\">\n        <NameField\n          label=\"Tool access name\"\n          value={config.name}\n          onChange={(value) => onUpdate({ name: value })}\n        />\n\n        {!hasProviders ? (\n          <div className=\"space-y-3\">\n            <EmptyState\n              title=\"Add a server to start choosing tools\"\n              description=\"Set up a server first, then come back here to choose which tools this step can use.\"\n            />\n            <Button\n              type=\"button\"\n              variant=\"outline\"\n              onClick={() => setActiveTab(\"servers\")}\n            >\n              Add servers first\n            </Button>\n          </div>\n        ) : (\n          <>\n            <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3\">\n              <p className=\"text-sm font-semibold text-foreground\">\n                Pick which tools this setup may use\n              </p>\n              <p className=\"mt-1 text-xs text-muted-foreground\">\n                1. Load tool names from your servers. 2. 
Leave the list empty to\n                allow all tools, or add only the ones this step should use.\n              </p>\n            </div>\n            <div className=\"space-y-3 rounded-2xl border border-border/60 bg-muted/10 p-4\">\n              <div className=\"flex items-start justify-between gap-3\">\n                <div>\n                  <p className=\"text-sm font-semibold text-foreground\">\n                    Available tools\n                  </p>\n                  <p className=\"text-xs text-muted-foreground\">\n                    Load tool names so you can pick from a list instead of guessing.\n                  </p>\n                </div>\n                <Button\n                  type=\"button\"\n                  size=\"xs\"\n                  variant=\"outline\"\n                  disabled={loadingTools}\n                  onClick={() => {\n                    void loadTools();\n                  }}\n                >\n                  {loadingTools ? \"Loading...\" : \"Load tools\"}\n                </Button>\n              </div>\n\n              {Object.keys(toolsByProvider).length === 0 &&\n                Object.keys(providerErrors).length === 0 && (\n                  <p className=\"text-xs text-muted-foreground\">\n                    Load tools to browse what's available.\n                  </p>\n                )}\n\n              {Object.entries(toolsByProvider).map(([providerName, toolNames]) => (\n                <div key={providerName} className=\"space-y-2\">\n                  <div className=\"flex items-center gap-2\">\n                    <p className=\"text-xs font-semibold uppercase text-muted-foreground\">\n                      {providerName}\n                    </p>\n                    <Badge variant=\"outline\" className=\"rounded-full text-[10px]\">\n                      {toolNames.length}\n                    </Badge>\n                  </div>\n                  <div className=\"flex flex-wrap gap-2\">\n                    {toolNames.map((toolName) => (\n                      <Badge key={`${providerName}-${toolName}`} variant=\"secondary\">\n                        {toolName}\n                      </Badge>\n                    ))}\n                  </div>\n                </div>\n              ))}\n\n              {Object.entries(duplicateTools).length > 0 && (\n                <div className=\"rounded-xl border border-amber-500/30 bg-amber-500/5 px-3 py-2 text-xs text-amber-700 dark:text-amber-300\">\n                  Some tool names appear on more than one server:\n                  {\" \"}\n                  {Object.entries(duplicateTools)\n                    .map(([toolName, providerList]) => `${toolName} (${providerList.join(\", \")})`)\n                    .join(\"; \")}\n                </div>\n              )}\n            </div>\n\n            <div className=\"grid gap-2\">\n              <FieldLabel\n                label=\"Tools this setup may use\"\n                hint=\"Leave this empty to allow every tool from these servers.\"\n              />\n              <ChipInput\n                values={config.allow_tools ?? []}\n                suggestions={availableTools}\n                onAdd={(value) =>\n                  onUpdate({\n                    // biome-ignore lint/style/useNamingConvention: api schema\n                    allow_tools: addUnique(config.allow_tools ?? 
[], value),\n                  })\n                }\n                onRemove={(toolIndex) =>\n                  onUpdate({\n                    // biome-ignore lint/style/useNamingConvention: api schema\n                    allow_tools: (config.allow_tools ?? []).filter(\n                      (_, currentIndex) => currentIndex !== toolIndex,\n                    ),\n                  })\n                }\n                placeholder=\"Type tool name and press Enter\"\n              />\n            </div>\n\n            <Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>\n              <CollapsibleTrigger asChild={true}>\n                <CollapsibleSectionTriggerButton\n                  label=\"Tool-call limits\"\n                  open={advancedOpen}\n                />\n              </CollapsibleTrigger>\n              <CollapsibleContent className=\"mt-3\">\n                <div className=\"grid gap-3 sm:grid-cols-2\">\n                  <div className=\"grid gap-2\">\n                    <FieldLabel\n                      label=\"Max tool-use turns\"\n                      hint=\"How many back-and-forth tool calls an AI step can make.\"\n                    />\n                    <Input\n                      className=\"nodrag\"\n                      value={config.max_tool_call_turns ?? \"\"}\n                      onChange={(event) =>\n                        onUpdate({\n                          // biome-ignore lint/style/useNamingConvention: api schema\n                          max_tool_call_turns: event.target.value,\n                        })\n                      }\n                    />\n                  </div>\n                  <div className=\"grid gap-2\">\n                    <FieldLabel\n                      label=\"Timeout (seconds)\"\n                      hint=\"How long to wait when loading or calling tools.\"\n                    />\n                    <Input\n                      className=\"nodrag\"\n                      value={config.timeout_sec ?? \"\"}\n                      onChange={(event) =>\n                        onUpdate({\n                          // biome-ignore lint/style/useNamingConvention: api schema\n                          timeout_sec: event.target.value,\n                        })\n                      }\n                    />\n                  </div>\n                </div>\n              </CollapsibleContent>\n            </Collapsible>\n          </>\n        )}\n      </TabsContent>\n\n      <TabsContent value=\"servers\" className=\"space-y-4 pt-3\">\n        <div className=\"rounded-2xl border border-border/60 bg-muted/10 px-4 py-3\">\n          <p className=\"text-sm font-semibold text-foreground\">\n            Add one or more tool servers\n          </p>\n          <p className=\"mt-1 text-xs text-muted-foreground\">\n            After your servers are ready, switch to Choose tools to load names\n            and decide which ones this setup should allow.\n          </p>\n        </div>\n        <div className=\"flex items-center justify-between gap-3\">\n          <FieldLabel\n            label=\"Tool servers\"\n            hint=\"These servers belong to this tool access setup and can be reused by linked AI steps.\"\n          />\n          <Button type=\"button\" size=\"xs\" variant=\"outline\" onClick={addProvider}>\n            <HugeiconsIcon icon={PlusSignIcon} className=\"size-3.5\" />\n            Add server\n          </Button>\n        </div>\n\n        {!hasProviders ? 
(\n          <EmptyState\n            title=\"No tool servers yet\"\n            description=\"Add one or more servers here, then switch to Choose tools to load and pick tools.\"\n          />\n        ) : (\n          <div className=\"space-y-3\">\n            {providers.map((provider, index) => (\n              <McpServerCard\n                key={provider.id}\n                provider={provider}\n                index={index}\n                toolsCount={\n                  provider.name.trim()\n                    ? (toolsByProvider[provider.name.trim()] ?? []).length\n                    : undefined\n                }\n                error={provider.name.trim() ? providerErrors[provider.name.trim()] : undefined}\n                open={openProviders[provider.id] ?? !isProviderConfigured(provider)}\n                onOpenChange={(open) =>\n                  setOpenProviders((current) => ({\n                    ...current,\n                    [provider.id]: open,\n                  }))\n                }\n                onUpdateProviderAt={updateProviderAt}\n                onRemoveProvider={removeProvider}\n                onAddProviderArg={addProviderArg}\n                onUpdateProviderArg={updateProviderArg}\n                onRemoveProviderArg={removeProviderArg}\n                onAddProviderEnv={addProviderEnv}\n                onUpdateProviderEnv={updateProviderEnv}\n                onRemoveProviderEnv={removeProviderEnv}\n              />\n            ))}\n          </div>\n        )}\n      </TabsContent>\n    </Tabs>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/dialogs/validators/validator-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { type ReactElement, useMemo, useRef } from \"react\";\nimport { useRecipeStudioStore } from \"../../stores/recipe-studio\";\nimport type { ValidatorConfig } from \"../../types\";\nimport {\n  isValidatorCodeLang,\n  VALIDATOR_OXC_CODE_LANGS,\n  VALIDATOR_SQL_CODE_LANGS,\n} from \"../../utils/validators/code-lang\";\nimport {\n  OXC_CODE_SHAPES,\n  normalizeOxcCodeShape,\n} from \"../../utils/validators/oxc-code-shape\";\nimport {\n  OXC_VALIDATION_MODES,\n  normalizeOxcValidationMode,\n} from \"../../utils/validators/oxc-mode\";\nimport { CollapsibleSectionTriggerButton } from \"../shared/collapsible-section-trigger\";\nimport { FieldLabel } from \"../shared/field-label\";\nimport { NameField } from \"../shared/name-field\";\n\ntype ValidatorDialogProps = {\n  config: ValidatorConfig;\n  onUpdate: (patch: Partial<ValidatorConfig>) => void;\n};\n\nconst NONE_VALUE = \"__none__\";\n\nexport function ValidatorDialog({\n  config,\n  onUpdate,\n}: ValidatorDialogProps): ReactElement {\n  const configs = useRecipeStudioStore((state) => state.configs);\n  const targetColumnId = `${config.id}-target-column`;\n  const oxcModeId = `${config.id}-oxc-mode`;\n  const oxcCodeShapeId = `${config.id}-oxc-code-shape`;\n  const batchSizeId = `${config.id}-batch-size`;\n  const oxcModeAnchorRef = useRef<HTMLDivElement>(null);\n  const oxcCodeShapeAnchorRef = useRef<HTMLDivElement>(null);\n  const advancedOpen = config.advancedOpen === true;\n  const selectedOxcMode = normalizeOxcValidationMode(config.oxc_validation_mode);\n  const selectedOxcCodeShape = normalizeOxcCodeShape(config.oxc_code_shape);\n  const codeOptions = useMemo(\n    () =>\n      Object.values(configs)\n        .flatMap((item) => {\n          if (!(item.kind === \"llm\" && item.llm_type === \"code\")) {\n            return [];\n          }\n          if (config.validator_type === \"oxc\") {\n            const lang = item.code_lang?.trim() ?? \"\";\n            if (!VALIDATOR_OXC_CODE_LANGS.includes(lang as typeof config.code_lang)) {\n              return [];\n            }\n          } else {\n            const lang = item.code_lang?.trim() ?? \"\";\n            if (\n              !(\n                lang === \"python\" ||\n                VALIDATOR_SQL_CODE_LANGS.includes(lang as typeof config.code_lang)\n              )\n            ) {\n              return [];\n            }\n          }\n          return [\n            {\n              name: item.name,\n              codeLang: item.code_lang?.trim() ?? \"\",\n            },\n          ];\n        })\n        .filter((item) => item.name.trim())\n        .sort((a, b) => a.name.localeCompare(b.name)),\n    [configs],\n  );\n  const currentTarget = config.target_columns[0] ?? 
\"\";\n\n  return (\n    <div className=\"space-y-4\">\n      <NameField\n        label=\"Check name\"\n        hint=\"Name used for this check in the canvas and run results.\"\n        value={config.name}\n        onChange={(value) => onUpdate({ name: value })}\n      />\n      <div className=\"grid gap-2\">\n        <FieldLabel\n          label=\"Code to check\"\n          htmlFor={targetColumnId}\n          hint=\"Choose the AI code step this check should review.\"\n        />\n        <Select\n          value={currentTarget || NONE_VALUE}\n          onValueChange={(value) => {\n            if (value === NONE_VALUE) {\n              onUpdate({\n                // biome-ignore lint/style/useNamingConvention: api schema\n                target_columns: [],\n              });\n              return;\n            }\n            const targetConfig = codeOptions.find((item) => item.name === value);\n            const nextCodeLang = targetConfig?.codeLang?.trim();\n            onUpdate({\n              // biome-ignore lint/style/useNamingConvention: api schema\n              target_columns: [value],\n              // biome-ignore lint/style/useNamingConvention: api schema\n              code_lang:\n                nextCodeLang && isValidatorCodeLang(nextCodeLang)\n                  ? nextCodeLang\n                  : config.code_lang,\n            });\n          }}\n        >\n          <SelectTrigger className=\"nodrag w-full\" id={targetColumnId}>\n            <SelectValue placeholder=\"Select code column\" />\n          </SelectTrigger>\n          <SelectContent>\n            <SelectItem value={NONE_VALUE}>None</SelectItem>\n            {codeOptions.map((item) => (\n              <SelectItem key={item.name} value={item.name}>\n                {item.name}\n              </SelectItem>\n            ))}\n          </SelectContent>\n        </Select>\n        {codeOptions.length === 0 && (\n              <p className=\"text-xs text-muted-foreground\">\n                {config.validator_type === \"oxc\"\n                  ? 
\"Add an AI code step that generates JavaScript or TypeScript first.\"\n                  : \"Add an AI code step first.\"}\n              </p>\n        )}\n      </div>\n      {config.validator_type === \"oxc\" && (\n        <div className=\"grid gap-3\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Check mode\"\n              htmlFor={oxcModeId}\n              hint=\"Choose whether to check syntax, lint rules, or both.\"\n            />\n            <div ref={oxcModeAnchorRef}>\n              <Combobox\n                items={OXC_VALIDATION_MODES}\n                filteredItems={OXC_VALIDATION_MODES}\n                filter={null}\n                value={selectedOxcMode}\n                onValueChange={(value) =>\n                  onUpdate({\n                    oxc_validation_mode: normalizeOxcValidationMode(value),\n                  })\n                }\n                itemToStringValue={(value) => value}\n                autoHighlight={true}\n              >\n                <ComboboxInput\n                  id={oxcModeId}\n                  className=\"nodrag w-full\"\n                  placeholder=\"Select validation mode\"\n                  readOnly={true}\n                />\n                <ComboboxContent anchor={oxcModeAnchorRef}>\n                  <ComboboxEmpty>No modes available</ComboboxEmpty>\n                  <ComboboxList>\n                    {(mode: string) => (\n                      <ComboboxItem key={mode} value={mode}>\n                        {mode}\n                      </ComboboxItem>\n                    )}\n                  </ComboboxList>\n                </ComboboxContent>\n              </Combobox>\n            </div>\n          </div>\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Code shape\"\n              htmlFor={oxcCodeShapeId}\n              hint=\"Choose whether the code should be treated like a full file or a smaller snippet.\"\n            />\n            <div ref={oxcCodeShapeAnchorRef}>\n              <Combobox\n                items={OXC_CODE_SHAPES}\n                filteredItems={OXC_CODE_SHAPES}\n                filter={null}\n                value={selectedOxcCodeShape}\n                onValueChange={(value) =>\n                  onUpdate({\n                    oxc_code_shape: normalizeOxcCodeShape(value),\n                  })\n                }\n                itemToStringValue={(value) => value}\n                autoHighlight={true}\n              >\n                <ComboboxInput\n                  id={oxcCodeShapeId}\n                  className=\"nodrag w-full\"\n                  placeholder=\"Select code shape\"\n                  readOnly={true}\n                />\n                <ComboboxContent anchor={oxcCodeShapeAnchorRef}>\n                  <ComboboxEmpty>No code-shape options</ComboboxEmpty>\n                  <ComboboxList>\n                    {(shape: string) => (\n                      <ComboboxItem key={shape} value={shape}>\n                        {shape}\n                      </ComboboxItem>\n                    )}\n                  </ComboboxList>\n                </ComboboxContent>\n              </Combobox>\n            </div>\n          </div>\n        </div>\n      )}\n      <Collapsible\n        open={advancedOpen}\n        onOpenChange={(open) => onUpdate({ advancedOpen: open })}\n      >\n        <CollapsibleTrigger asChild={true}>\n          <CollapsibleSectionTriggerButton\n            label=\"Advanced 
check settings\"\n            open={advancedOpen}\n          />\n        </CollapsibleTrigger>\n        <CollapsibleContent className=\"mt-3\">\n          <div className=\"grid gap-2\">\n            <FieldLabel\n              label=\"Batch size\"\n              htmlFor={batchSizeId}\n              hint=\"How many records to check at a time.\"\n            />\n            <Input\n              id={batchSizeId}\n              className=\"nodrag\"\n              value={config.batch_size}\n              onChange={(event) => onUpdate({ batch_size: event.target.value })}\n            />\n          </div>\n        </CollapsibleContent>\n      </Collapsible>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/execution-types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type RecipeStudioView = \"editor\" | \"executions\";\n\nexport type RecipeExecutionKind = \"preview\" | \"full\";\n\nexport type RecipeExecutionStatus =\n  | \"pending\"\n  | \"running\"\n  | \"active\"\n  | \"cancelling\"\n  | \"cancelled\"\n  | \"completed\"\n  | \"error\";\n\nexport type RecipeExecutionProgress = {\n  done?: number | null;\n  total?: number | null;\n  percent?: number | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  eta_sec?: number | null;\n  rate?: number | null;\n  ok?: number | null;\n  failed?: number | null;\n};\n\nexport type RecipeExecutionBatch = {\n  idx?: number | null;\n  total?: number | null;\n};\n\nexport type RecipeExecutionAnalysis = {\n  num_records?: number;\n  target_num_records?: number;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  column_statistics?: Record<string, unknown>[];\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  side_effect_column_names?: string[] | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  column_profiles?: Record<string, unknown>[] | null;\n} & Record<string, unknown>;\n\nexport type RecipeExecutionRecord = {\n  id: string;\n  recipeId: string;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  jobId: string | null;\n  kind: RecipeExecutionKind;\n  // ui-only display label for full runs\n  // biome-ignore lint/style/useNamingConvention: ui schema\n  run_name: string | null;\n  status: RecipeExecutionStatus;\n  rows: number;\n  createdAt: number;\n  finishedAt: number | null;\n  recipeSignature: string;\n  stage: string | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  current_column: string | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  completed_columns: string[];\n  progress: RecipeExecutionProgress | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  column_progress: RecipeExecutionProgress | null;\n  batch: RecipeExecutionBatch | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  model_usage: Record<string, unknown> | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  lastEventId: number | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  artifact_path: string | null;\n  // biome-ignore lint/style/useNamingConvention: backend schema\n  log_lines: string[];\n  dataset: Record<string, unknown>[];\n  datasetTotal: number;\n  datasetPage: number;\n  datasetPageSize: number;\n  analysis: RecipeExecutionAnalysis | null;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  processor_artifacts: Record<string, unknown> | null;\n  error: string | null;\n};\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/executions/execution-helpers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  RecipeExecutionAnalysis,\n  RecipeExecutionRecord,\n  RecipeExecutionStatus,\n} from \"../execution-types\";\nimport type { RecipePayload } from \"../utils/payload/types\";\n\nexport const DATASET_PAGE_SIZE = 20;\n\nexport function buildSignature(name: string, payload: RecipePayload): string {\n  return JSON.stringify({ name, payload });\n}\n\nexport function formatSavedLabel(savedAt: number | null): string {\n  if (!savedAt) {\n    return \"Not saved yet\";\n  }\n  const time = new Date(savedAt).toLocaleTimeString([], {\n    hour: \"numeric\",\n    minute: \"2-digit\",\n  });\n  return `Saved ${time}`;\n}\n\nexport function toErrorMessage(error: unknown, fallback: string): string {\n  if (error instanceof Error) {\n    return error.message;\n  }\n  return fallback;\n}\n\nexport function normalizeDatasetRows(value: unknown): Record<string, unknown>[] {\n  if (!Array.isArray(value)) {\n    return [];\n  }\n  return value.filter(\n    (row): row is Record<string, unknown> =>\n      typeof row === \"object\" && row !== null && !Array.isArray(row),\n  );\n}\n\nexport function normalizeObject(value: unknown): Record<string, unknown> | null {\n  if (typeof value !== \"object\" || value === null || Array.isArray(value)) {\n    return null;\n  }\n  return value as Record<string, unknown>;\n}\n\nexport function normalizeAnalysis(value: unknown): RecipeExecutionAnalysis | null {\n  const normalized = normalizeObject(value);\n  if (!normalized) {\n    return null;\n  }\n  return normalized as RecipeExecutionAnalysis;\n}\n\nexport function mapJobStatus(status: string): RecipeExecutionStatus {\n  if (status === \"active\") {\n    return \"active\";\n  }\n  if (status === \"pending\") {\n    return \"pending\";\n  }\n  if (status === \"cancelling\") {\n    return \"cancelling\";\n  }\n  if (status === \"cancelled\") {\n    return \"cancelled\";\n  }\n  if (status === \"completed\") {\n    return \"completed\";\n  }\n  if (status === \"error\") {\n    return \"error\";\n  }\n  return \"running\";\n}\n\nexport function isExecutionInProgress(status: RecipeExecutionStatus): boolean {\n  return (\n    status === \"running\" ||\n    status === \"active\" ||\n    status === \"pending\" ||\n    status === \"cancelling\"\n  );\n}\n\nexport function executionLabel(kind: \"preview\" | \"full\"): string {\n  return kind === \"preview\" ? \"Preview\" : \"Full run\";\n}\n\nexport function normalizeRunName(value: unknown): string | null {\n  if (typeof value !== \"string\") {\n    return null;\n  }\n  const trimmed = value.trim();\n  return trimmed.length > 0 ? trimmed : null;\n}\n\nfunction executionSortWeight(status: RecipeExecutionStatus): number {\n  if (isExecutionInProgress(status)) {\n    return 0;\n  }\n  if (status === \"error\" || status === \"cancelled\") {\n    return 2;\n  }\n  return 1;\n}\n\nexport function sortExecutions(records: RecipeExecutionRecord[]): RecipeExecutionRecord[] {\n  const next = [...records];\n  next.sort((a, b) => {\n    const statusDelta = executionSortWeight(a.status) - executionSortWeight(b.status);\n    if (statusDelta !== 0) {\n      return statusDelta;\n    }\n    return b.createdAt - a.createdAt;\n  });\n  return next;\n}\n\nexport function withExecutionDefaults(\n  record: RecipeExecutionRecord,\n): RecipeExecutionRecord {\n  const dataset = Array.isArray(record.dataset) ? 
record.dataset : [];\n  const logLines = Array.isArray(record.log_lines)\n    ? record.log_lines.filter((line): line is string => typeof line === \"string\")\n    : [];\n  const datasetPageSize =\n    typeof record.datasetPageSize === \"number\" && record.datasetPageSize > 0\n      ? record.datasetPageSize\n      : DATASET_PAGE_SIZE;\n  const datasetPage =\n    typeof record.datasetPage === \"number\" && record.datasetPage > 0\n      ? record.datasetPage\n      : 1;\n  const datasetTotal =\n    typeof record.datasetTotal === \"number\" && record.datasetTotal >= 0\n      ? record.datasetTotal\n      : dataset.length;\n\n  return {\n    ...record,\n    run_name: normalizeRunName(record.run_name),\n    dataset,\n    log_lines: logLines,\n    datasetTotal,\n    datasetPage,\n    datasetPageSize,\n    completed_columns: Array.isArray(record.completed_columns)\n      ? record.completed_columns.filter(\n          (value): value is string => typeof value === \"string\" && value.trim().length > 0,\n        )\n      : [],\n    column_progress: record.column_progress ?? null,\n    batch: record.batch ?? null,\n  };\n}\n\nexport function delay(ms: number): Promise<void> {\n  return new Promise((resolve) => {\n    window.setTimeout(resolve, ms);\n  });\n}\n\nexport async function copyTextToClipboard(text: string): Promise<boolean> {\n  try {\n    if (navigator.clipboard?.writeText) {\n      await navigator.clipboard.writeText(text);\n      return true;\n    }\n  } catch {\n    // fallthrough to legacy path\n  }\n\n  try {\n    const textarea = document.createElement(\"textarea\");\n    textarea.value = text;\n    textarea.setAttribute(\"readonly\", \"\");\n    textarea.style.position = \"fixed\";\n    textarea.style.top = \"0\";\n    textarea.style.left = \"-9999px\";\n    document.body.appendChild(textarea);\n    textarea.select();\n    const ok = document.execCommand(\"copy\");\n    document.body.removeChild(textarea);\n    return ok;\n  } catch {\n    return false;\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/executions/hydration.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { listRecipeExecutions } from \"../data/executions-db\";\nimport type { RecipeExecutionRecord } from \"../execution-types\";\nimport {\n  isExecutionInProgress,\n  sortExecutions,\n  withExecutionDefaults,\n} from \"./execution-helpers\";\n\nexport async function loadSortedRecipeExecutions(\n  recipeId: string,\n): Promise<RecipeExecutionRecord[]> {\n  const records = await listRecipeExecutions(recipeId);\n  return sortExecutions(records.map(withExecutionDefaults));\n}\n\nexport function findResumableExecution(\n  records: RecipeExecutionRecord[],\n): RecipeExecutionRecord | null {\n  return (\n    records.find((record) => record.jobId && isExecutionInProgress(record.status)) ?? null\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/executions/run-settings.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipeExecutionKind } from \"../execution-types\";\nimport type { RecipeRunSettings } from \"../stores/recipe-executions\";\nimport type { RecipePayload } from \"../utils/payload/types\";\n\nfunction toPositiveInt(\n  value: number,\n  fallback: number,\n  min = 1,\n  max = Number.MAX_SAFE_INTEGER,\n): number {\n  if (!Number.isFinite(value)) {\n    return fallback;\n  }\n  const next = Math.floor(value);\n  if (next < min) {\n    return min;\n  }\n  if (next > max) {\n    return max;\n  }\n  return next;\n}\n\nfunction toNonNegativeInt(\n  value: number,\n  fallback: number,\n  max = Number.MAX_SAFE_INTEGER,\n): number {\n  if (!Number.isFinite(value)) {\n    return fallback;\n  }\n  const next = Math.floor(value);\n  if (next < 0) {\n    return 0;\n  }\n  if (next > max) {\n    return max;\n  }\n  return next;\n}\n\nfunction toRatio(value: number, fallback: number): number {\n  if (!Number.isFinite(value)) {\n    return fallback;\n  }\n  if (value < 0) {\n    return 0;\n  }\n  if (value > 1) {\n    return 1;\n  }\n  return value;\n}\n\nexport function sanitizeExecutionRows(\n  rows: number,\n  kind: RecipeExecutionKind,\n): number {\n  return toPositiveInt(rows, kind === \"preview\" ? 5 : 1000);\n}\n\nexport function normalizeRunSettings(settings: RecipeRunSettings): RecipeRunSettings {\n  return {\n    batchSize: toPositiveInt(settings.batchSize, 1000, 1, 200_000),\n    batchEnabled: Boolean(settings.batchEnabled),\n    mergeBatches: Boolean(settings.mergeBatches),\n    llmParallelRequests:\n      typeof settings.llmParallelRequests === \"number\"\n        ? toPositiveInt(settings.llmParallelRequests, 4, 1, 2048)\n        : null,\n    nonInferenceWorkers: toPositiveInt(\n      settings.nonInferenceWorkers,\n      4,\n      1,\n      2048,\n    ),\n    maxConversationRestarts: toNonNegativeInt(\n      settings.maxConversationRestarts,\n      5,\n      100,\n    ),\n    maxConversationCorrectionSteps: toNonNegativeInt(\n      settings.maxConversationCorrectionSteps,\n      0,\n      100,\n    ),\n    disableEarlyShutdown: Boolean(settings.disableEarlyShutdown),\n    shutdownErrorRate: toRatio(settings.shutdownErrorRate, 0.5),\n    shutdownErrorWindow: toPositiveInt(settings.shutdownErrorWindow, 10, 1, 10_000),\n  };\n}\n\nfunction buildRunConfigPayload(\n  settings: RecipeRunSettings,\n  rows: number,\n  kind: RecipeExecutionKind,\n): Record<string, unknown> {\n  const useBatching = kind === \"full\" && settings.batchEnabled;\n  return {\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    buffer_size: useBatching ? 
settings.batchSize : toPositiveInt(rows, 1000, 1, 200_000),\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    non_inference_max_parallel_workers: settings.nonInferenceWorkers,\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    max_conversation_restarts: settings.maxConversationRestarts,\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    max_conversation_correction_steps: settings.maxConversationCorrectionSteps,\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    disable_early_shutdown: settings.disableEarlyShutdown,\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    shutdown_error_rate: settings.shutdownErrorRate,\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    shutdown_error_window: settings.shutdownErrorWindow,\n  };\n}\n\nfunction applyGlobalParallelismOverride(\n  payload: RecipePayload,\n  llmParallelRequests: number | null,\n): RecipePayload {\n  if (typeof llmParallelRequests !== \"number\") {\n    return payload;\n  }\n\n  const modelConfigs = payload.recipe.model_configs.map((modelConfig) => {\n    const nextModelConfig = { ...modelConfig };\n    const inferenceRaw = modelConfig.inference_parameters;\n    const inference =\n      inferenceRaw &&\n      typeof inferenceRaw === \"object\" &&\n      !Array.isArray(inferenceRaw)\n        ? { ...(inferenceRaw as Record<string, unknown>) }\n        : {};\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    inference.max_parallel_requests = llmParallelRequests;\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    nextModelConfig.inference_parameters = inference;\n    return nextModelConfig;\n  });\n\n  return {\n    ...payload,\n    recipe: {\n      ...payload.recipe,\n      // biome-ignore lint/style/useNamingConvention: backend schema\n      model_configs: modelConfigs,\n    },\n  };\n}\n\nexport function buildExecutionPayload(input: {\n  payload: RecipePayload;\n  kind: RecipeExecutionKind;\n  rows: number;\n  settings: RecipeRunSettings;\n  runName?: string | null;\n}): RecipePayload {\n  const normalizedSettings = normalizeRunSettings(input.settings);\n  const payloadWithParallelism = applyGlobalParallelismOverride(\n    input.payload,\n    normalizedSettings.llmParallelRequests,\n  );\n  return {\n    ...payloadWithParallelism,\n    run: {\n      ...payloadWithParallelism.run,\n      rows: input.rows,\n      // biome-ignore lint/style/useNamingConvention: backend schema\n      execution_type: input.kind,\n      // biome-ignore lint/style/useNamingConvention: backend schema\n      run_config: buildRunConfigPayload(normalizedSettings, input.rows, input.kind),\n      // biome-ignore lint/style/useNamingConvention: backend schema\n      merge_batches:\n        input.kind === \"full\" &&\n        normalizedSettings.batchEnabled &&\n        normalizedSettings.mergeBatches,\n      // biome-ignore lint/style/useNamingConvention: backend schema\n      run_name: input.kind === \"full\" ? (input.runName ?? null) : null,\n    },\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/executions/runtime.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { JobEvent, JobStatusResponse } from \"../api\";\nimport type {\n  RecipeExecutionBatch,\n  RecipeExecutionKind,\n  RecipeExecutionRecord,\n} from \"../execution-types\";\nimport {\n  DATASET_PAGE_SIZE,\n  mapJobStatus,\n  normalizeObject,\n} from \"./execution-helpers\";\n\nconst MAX_LOG_LINES = 1500;\n\nfunction formatEventTime(ts: unknown): string {\n  if (typeof ts !== \"number\" || !Number.isFinite(ts)) {\n    return new Date().toLocaleTimeString();\n  }\n  const ms = ts > 10_000_000_000 ? ts : ts * 1000;\n  return new Date(ms).toLocaleTimeString();\n}\n\nexport function appendExecutionLogLine(lines: string[], nextLine: string): string[] {\n  const next = [...lines, nextLine];\n  if (next.length <= MAX_LOG_LINES) {\n    return next;\n  }\n  return next.slice(next.length - MAX_LOG_LINES);\n}\n\nexport function toExecutionLogLine(event: JobEvent): string | null {\n  const eventType =\n    typeof event.payload.type === \"string\" ? event.payload.type : event.event;\n  const ts = formatEventTime(event.payload.ts);\n\n  if (eventType === \"log\") {\n    const message =\n      typeof event.payload.message === \"string\" ? event.payload.message.trim() : \"\";\n    if (!message) {\n      return null;\n    }\n    const level =\n      typeof event.payload.level === \"string\" && event.payload.level.length > 0\n        ? event.payload.level.toUpperCase()\n        : \"INFO\";\n    return `[${ts}] [${level}] ${message}`;\n  }\n\n  if (eventType === \"job.started\") {\n    return `[${ts}] [INFO] Job started`;\n  }\n  if (eventType === \"job.completed\") {\n    return `[${ts}] [INFO] Job completed`;\n  }\n  if (eventType === \"job.cancelling\") {\n    return `[${ts}] [WARN] Cancellation requested`;\n  }\n  if (eventType === \"job.cancelled\") {\n    return `[${ts}] [WARN] Job cancelled`;\n  }\n  if (eventType === \"job.error\") {\n    const error =\n      typeof event.payload.error === \"string\" && event.payload.error.length > 0\n        ? event.payload.error\n        : \"Job failed\";\n    return `[${ts}] [ERROR] ${error}`;\n  }\n\n  return null;\n}\n\nexport function applyExecutionStatusSnapshot(\n  execution: RecipeExecutionRecord,\n  status: JobStatusResponse,\n): RecipeExecutionRecord {\n  const mappedStatus = mapJobStatus(status.status);\n  const batchRaw = normalizeObject(status.batch);\n  const batch: RecipeExecutionBatch | null = batchRaw\n    ? {\n        idx: typeof batchRaw.idx === \"number\" ? batchRaw.idx : null,\n        total: typeof batchRaw.total === \"number\" ? batchRaw.total : null,\n      }\n    : null;\n  return {\n    ...execution,\n    status: mappedStatus,\n    rows: status.rows ?? execution.rows,\n    stage: status.stage ?? execution.stage,\n    current_column: status.current_column ?? null,\n    completed_columns: Array.isArray(status.completed_columns)\n      ? status.completed_columns.filter(\n          (value): value is string => typeof value === \"string\" && value.trim().length > 0,\n        )\n      : execution.completed_columns,\n    progress: (normalizeObject(status.progress) as RecipeExecutionRecord[\"progress\"]) ?? null,\n    column_progress:\n      (normalizeObject(status.column_progress) as RecipeExecutionRecord[\"column_progress\"]) ??\n      null,\n    batch,\n    model_usage: normalizeObject(status.model_usage),\n    artifact_path: status.artifact_path ?? 
execution.artifact_path,\n    error: status.error ?? null,\n    finishedAt:\n      mappedStatus === \"completed\" ||\n      mappedStatus === \"error\" ||\n      mappedStatus === \"cancelled\"\n        ? Date.now()\n        : null,\n  };\n}\n\nexport function createBaseExecutionRecord(input: {\n  recipeId: string;\n  kind: RecipeExecutionKind;\n  rows: number;\n  currentSignature: string;\n  runName?: string | null;\n}): RecipeExecutionRecord {\n  const createdAt = Date.now();\n  return {\n    id: crypto.randomUUID(),\n    recipeId: input.recipeId,\n    jobId: null,\n    kind: input.kind,\n    run_name: input.runName ?? null,\n    status: \"pending\",\n    rows: input.rows,\n    createdAt,\n    finishedAt: null,\n    recipeSignature: input.currentSignature,\n    stage: \"pending\",\n    current_column: null,\n    completed_columns: [],\n    progress: null,\n    column_progress: null,\n    batch: null,\n    model_usage: null,\n    lastEventId: null,\n    artifact_path: null,\n    log_lines: [],\n    dataset: [],\n    datasetTotal: 0,\n    datasetPage: 1,\n    datasetPageSize: DATASET_PAGE_SIZE,\n    analysis: null,\n    processor_artifacts: null,\n    error: null,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/executions/tracker.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { toastError, toastSuccess } from \"@/shared/toast\";\nimport {\n  getRecipeJobAnalysis,\n  getRecipeJobDataset,\n  getRecipeJobStatus,\n  streamRecipeJobEvents,\n} from \"../api\";\nimport type {\n  RecipeExecutionKind,\n  RecipeExecutionProgress,\n  RecipeExecutionRecord,\n  RecipeExecutionStatus,\n} from \"../execution-types\";\nimport {\n  DATASET_PAGE_SIZE,\n  delay,\n  mapJobStatus,\n  normalizeAnalysis,\n  normalizeDatasetRows,\n  toErrorMessage,\n} from \"./execution-helpers\";\nimport {\n  appendExecutionLogLine,\n  applyExecutionStatusSnapshot,\n  toExecutionLogLine,\n} from \"./runtime\";\n\ntype TrackRecipeExecutionParams = {\n  label: string;\n  kind: RecipeExecutionKind;\n  rows: number;\n  jobId: string;\n  initialExecution: RecipeExecutionRecord;\n  notify: boolean;\n  onUpsert: (record: RecipeExecutionRecord) => void;\n  onSetPreviewErrors: (errors: string[]) => void;\n  onPreviewSuccess?: () => void;\n};\n\nfunction isTerminalStatus(status: RecipeExecutionStatus): boolean {\n  return status === \"completed\" || status === \"error\" || status === \"cancelled\";\n}\n\nfunction normalizeCompletedProgress(input: {\n  latestExecution: RecipeExecutionRecord;\n  rows: number;\n}): {\n  progress: RecipeExecutionProgress;\n  columnProgress: RecipeExecutionProgress | null;\n} {\n  const { latestExecution, rows } = input;\n  const progressTotal =\n    typeof latestExecution.progress?.total === \"number\" && latestExecution.progress.total > 0\n      ? latestExecution.progress.total\n      : latestExecution.rows > 0\n        ? latestExecution.rows\n        : rows;\n\n  const progress: RecipeExecutionProgress = {\n    ...(latestExecution.progress ?? {}),\n    done: progressTotal,\n    total: progressTotal,\n    percent: 100,\n    eta_sec: 0,\n  };\n\n  const columnProgress =\n    latestExecution.column_progress &&\n    typeof latestExecution.column_progress.total === \"number\" &&\n    latestExecution.column_progress.total > 0\n      ? 
{\n          ...latestExecution.column_progress,\n          done: latestExecution.column_progress.total,\n          percent: 100,\n          eta_sec: 0,\n        }\n      : latestExecution.column_progress;\n\n  return { progress, columnProgress };\n}\n\nexport async function trackRecipeExecution({\n  label,\n  kind,\n  rows,\n  jobId,\n  initialExecution,\n  notify,\n  onUpsert,\n  onSetPreviewErrors,\n  onPreviewSuccess,\n}: TrackRecipeExecutionParams): Promise<boolean> {\n  let done = false;\n  let lastStatus: RecipeExecutionStatus = initialExecution.status;\n  let completedEventPayload: Record<string, unknown> | null = null;\n  let latestExecution: RecipeExecutionRecord = initialExecution;\n\n  const eventsAbortController = new AbortController();\n  void streamRecipeJobEvents({\n    jobId,\n    signal: eventsAbortController.signal,\n    lastEventId: latestExecution.lastEventId,\n    onEvent: (event) => {\n      let changed = false;\n\n      if (typeof event.id === \"number\") {\n        latestExecution = {\n          ...latestExecution,\n          lastEventId: event.id,\n        };\n        changed = true;\n      }\n\n      const logLine = toExecutionLogLine(event);\n      if (logLine) {\n        latestExecution = {\n          ...latestExecution,\n          log_lines: appendExecutionLogLine(latestExecution.log_lines, logLine),\n        };\n        changed = true;\n      }\n\n      const eventType =\n        typeof event.payload.type === \"string\" ? event.payload.type : event.event;\n\n      if (eventType === \"job.started\") {\n        latestExecution = {\n          ...latestExecution,\n          status: \"active\",\n        };\n        onUpsert(latestExecution);\n        return;\n      }\n\n      if (eventType === \"job.completed\") {\n        lastStatus = \"completed\";\n        completedEventPayload = event.payload;\n        done = true;\n        latestExecution = {\n          ...latestExecution,\n          status: \"completed\",\n          finishedAt: Date.now(),\n          artifact_path:\n            typeof event.payload.artifact_path === \"string\"\n              ? event.payload.artifact_path\n              : latestExecution.artifact_path,\n          error: null,\n        };\n        onUpsert(latestExecution);\n        return;\n      }\n\n      if (eventType === \"job.error\") {\n        lastStatus = \"error\";\n        done = true;\n        latestExecution = {\n          ...latestExecution,\n          status: \"error\",\n          finishedAt: Date.now(),\n          error:\n            typeof event.payload.error === \"string\"\n              ? event.payload.error\n              : latestExecution.error ?? 
`${label} failed.`,\n        };\n        onUpsert(latestExecution);\n        return;\n      }\n\n      if (eventType === \"job.cancelling\") {\n        latestExecution = {\n          ...latestExecution,\n          status: \"cancelling\",\n        };\n        onUpsert(latestExecution);\n        return;\n      }\n\n      if (changed) {\n        onUpsert(latestExecution);\n      }\n    },\n  }).catch(() => {\n    // polling is fallback source of truth\n  });\n\n  try {\n    while (!done) {\n      const status = await getRecipeJobStatus(jobId);\n      const mappedStatus = mapJobStatus(status.status);\n      lastStatus = mappedStatus;\n      latestExecution = applyExecutionStatusSnapshot(latestExecution, status);\n      onUpsert(latestExecution);\n\n      done = isTerminalStatus(mappedStatus);\n      if (!done) {\n        await delay(1200);\n      }\n    }\n  } catch (error) {\n    const message = toErrorMessage(error, `${label} failed.`);\n    latestExecution = {\n      ...latestExecution,\n      status: \"error\",\n      error: message,\n      finishedAt: Date.now(),\n    };\n    onUpsert(latestExecution);\n    if (notify) {\n      toastError(`${label} failed`, message);\n    }\n    return false;\n  } finally {\n    eventsAbortController.abort();\n  }\n\n  if (lastStatus === \"completed\") {\n    for (let attempt = 0; attempt < 3; attempt += 1) {\n      try {\n        const finalStatus = await getRecipeJobStatus(jobId);\n        latestExecution = applyExecutionStatusSnapshot(latestExecution, finalStatus);\n      } catch {\n        break;\n      }\n      if (attempt < 2) {\n        await delay(250);\n      }\n    }\n\n    const eventAnalysis = completedEventPayload\n      ? completedEventPayload[\"analysis\"]\n      : null;\n    const eventDataset = completedEventPayload\n      ? completedEventPayload[\"dataset\"]\n      : null;\n    const eventProcessorArtifacts =\n      completedEventPayload &&\n      typeof completedEventPayload[\"processor_artifacts\"] === \"object\" &&\n      completedEventPayload[\"processor_artifacts\"] !== null\n        ? (completedEventPayload[\"processor_artifacts\"] as Record<string, unknown>)\n        : null;\n    const shouldFetchPreviewDataset = kind === \"preview\" && !Array.isArray(eventDataset);\n    const shouldFetchAnalysis =\n      !completedEventPayload ||\n      typeof eventAnalysis !== \"object\" ||\n      eventAnalysis === null ||\n      kind === \"full\";\n\n    const [analysisResult, datasetResult] = await Promise.allSettled([\n      shouldFetchAnalysis\n        ? getRecipeJobAnalysis(jobId)\n        : Promise.resolve(eventAnalysis),\n      shouldFetchPreviewDataset || kind === \"full\"\n        ? getRecipeJobDataset(jobId, { limit: DATASET_PAGE_SIZE, offset: 0 })\n        : Promise.resolve({ dataset: eventDataset ?? [], total: rows }),\n    ]);\n\n    const analysis =\n      analysisResult.status === \"fulfilled\"\n        ? normalizeAnalysis(analysisResult.value)\n        : latestExecution.analysis;\n    const datasetResponse =\n      datasetResult.status === \"fulfilled\"\n        ? datasetResult.value\n        : null;\n    const dataset = datasetResponse\n      ? normalizeDatasetRows(datasetResponse.dataset)\n      : latestExecution.dataset;\n    const datasetTotal =\n      datasetResponse && typeof datasetResponse.total === \"number\"\n        ? 
datasetResponse.total\n        : latestExecution.datasetTotal;\n    const completedProgress = normalizeCompletedProgress({ latestExecution, rows });\n\n    latestExecution = {\n      ...latestExecution,\n      status: \"completed\",\n      progress: completedProgress.progress,\n      column_progress: completedProgress.columnProgress,\n      analysis,\n      dataset,\n      datasetTotal,\n      datasetPage: 1,\n      datasetPageSize: DATASET_PAGE_SIZE,\n      error: null,\n      processor_artifacts: eventProcessorArtifacts ?? latestExecution.processor_artifacts,\n      finishedAt: latestExecution.finishedAt ?? Date.now(),\n    };\n    onUpsert(latestExecution);\n\n    if (notify) {\n      if (kind === \"preview\") {\n        onSetPreviewErrors([]);\n        onPreviewSuccess?.();\n        toastSuccess(`Preview generated (${rows} rows).`);\n      } else {\n        toastSuccess(\"Full run completed.\");\n      }\n    }\n    return true;\n  }\n\n  if (lastStatus === \"cancelled\") {\n    latestExecution = {\n      ...latestExecution,\n      status: \"cancelled\",\n      error: latestExecution.error ?? \"Run cancelled.\",\n      finishedAt: latestExecution.finishedAt ?? Date.now(),\n    };\n    onUpsert(latestExecution);\n    if (notify) {\n      toastError(`${label} cancelled`, \"The execution was cancelled.\");\n    }\n    return false;\n  }\n\n  latestExecution = {\n    ...latestExecution,\n    status: \"error\",\n    error: latestExecution.error ?? `${label} failed.`,\n    finishedAt: latestExecution.finishedAt ?? Date.now(),\n  };\n  onUpsert(latestExecution);\n  if (notify) {\n    toastError(`${label} failed`, latestExecution.error ?? \"Execution failed.\");\n  }\n  return false;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-node-connection-status.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useMemo } from \"react\";\nimport { useRecipeStudioStore } from \"../stores/recipe-studio\";\nimport { INFRA_NODE_KINDS, type NodeConfig } from \"../types\";\n\ntype ConnectionStatus = {\n  /** True when the node has zero edges at all. */\n  isDisconnected: boolean;\n  /** True when an LLM node has no incoming data edge (only infra). */\n  missingDataInput: boolean;\n};\n\nexport function useNodeConnectionStatus(\n  nodeId: string,\n  config: NodeConfig | undefined,\n): ConnectionStatus {\n  const edges = useRecipeStudioStore((state) => state.edges);\n  const configs = useRecipeStudioStore((state) => state.configs);\n\n  return useMemo(() => {\n    const empty: ConnectionStatus = {\n      isDisconnected: false,\n      missingDataInput: false,\n    };\n\n    if (!config || config.kind === \"markdown_note\") {\n      return empty;\n    }\n\n    const nodeEdges = edges.filter(\n      (e) => e.source === nodeId || e.target === nodeId,\n    );\n    const isDisconnected = nodeEdges.length === 0;\n\n    let missingDataInput = false;\n    if (config.kind === \"llm\" && !isDisconnected) {\n      const hasDataEdge = nodeEdges.some((e) => {\n        const otherId = e.source === nodeId ? e.target : e.source;\n        const otherConfig = configs[otherId];\n        return otherConfig && !INFRA_NODE_KINDS.has(otherConfig.kind);\n      });\n      missingDataInput = !hasDataEdge;\n    }\n\n    return {\n      isDisconnected,\n      missingDataInput,\n    };\n  }, [nodeId, config, edges, configs]);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-recipe-editor-graph.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  Edge,\n  EdgeChange,\n  Node,\n  NodeChange,\n  ReactFlowInstance,\n  XYPosition,\n} from \"@xyflow/react\";\nimport {\n  type DragEvent as ReactDragEvent,\n  type RefObject,\n  useCallback,\n  useMemo,\n} from \"react\";\nimport { RECIPE_BLOCK_DND_MIME, type RecipeBlockDragPayload } from \"../components/block-sheet\";\nimport type { SeedBlockType } from \"../blocks/registry\";\nimport type {\n  LlmType,\n  NodeConfig,\n  RecipeNode as RecipeBuilderNode,\n  RecipeNodeData,\n  SamplerType,\n} from \"../types\";\nimport { applyAuxNodeChanges, filterEdgeChangesByIds, filterNodeChangesByIds } from \"../utils/reactflow-changes\";\nimport type { RecipeGraphAuxNodeData } from \"../components/recipe-graph-aux-node\";\n\nconst SUPPORTED_DRAG_KINDS: RecipeBlockDragPayload[\"kind\"][] = [\n  \"sampler\",\n  \"seed\",\n  \"llm\",\n  \"validator\",\n  \"expression\",\n  \"note\",\n];\n\nfunction parseRecipeBlockDragPayload(raw: string): RecipeBlockDragPayload | null {\n  try {\n    const parsed = JSON.parse(raw) as {\n      kind?: RecipeBlockDragPayload[\"kind\"];\n      type?: RecipeBlockDragPayload[\"type\"];\n    };\n    if (!parsed.kind || !parsed.type || !SUPPORTED_DRAG_KINDS.includes(parsed.kind)) {\n      return null;\n    }\n    return {\n      kind: parsed.kind,\n      type: parsed.type,\n    };\n  } catch {\n    return null;\n  }\n}\n\ntype UseRecipeEditorGraphArgs = {\n  nodes: RecipeBuilderNode[];\n  edges: Edge[];\n  configs: Record<string, NodeConfig>;\n  reactFlowInstance: ReactFlowInstance<Node<RecipeNodeData | RecipeGraphAuxNodeData>, Edge> | null;\n  flowContainerRef: RefObject<HTMLDivElement | null>;\n  selectConfig: (id: string) => void;\n  openConfig: (id: string) => void;\n  onNodesChange: (changes: NodeChange<RecipeBuilderNode>[]) => void;\n  onEdgesChange: (changes: EdgeChange<Edge>[]) => void;\n  setAuxNodePosition: (id: string, position: XYPosition) => void;\n  addSamplerNode: (type: SamplerType, position?: XYPosition, openDialog?: boolean) => void;\n  addSeedNode: (type: SeedBlockType, position?: XYPosition, openDialog?: boolean) => void;\n  addLlmNode: (type: LlmType, position?: XYPosition, openDialog?: boolean) => void;\n  addModelProviderNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addModelConfigNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addToolProfileNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addExpressionNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addValidatorNode: (\n    type: \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n    position?: XYPosition,\n    openDialog?: boolean,\n  ) => void;\n  addMarkdownNoteNode: (position?: XYPosition, openDialog?: boolean) => void;\n};\n\ntype UseRecipeEditorGraphResult = {\n  handleNodeClick: (_: unknown, node: Node<RecipeNodeData | RecipeGraphAuxNodeData>) => void;\n  handleNodeDoubleClick: (_: unknown, node: Node<RecipeNodeData | RecipeGraphAuxNodeData>) => void;\n  handleNodesChange: (\n    changes: NodeChange<Node<RecipeNodeData | RecipeGraphAuxNodeData>>[],\n  ) => void;\n  handleEdgesChange: (changes: EdgeChange<Edge>[]) => void;\n  handleDragOver: (event: ReactDragEvent<HTMLDivElement>) => void;\n  handleDrop: (event: ReactDragEvent<HTMLDivElement>) => void;\n  handleAddSamplerFromSheet: (type: SamplerType) => void;\n  handleAddSeedFromSheet: (type: 
SeedBlockType) => void;\n  handleAddLlmFromSheet: (type: LlmType) => void;\n  handleAddModelProviderFromSheet: () => void;\n  handleAddModelConfigFromSheet: () => void;\n  handleAddToolProfileFromSheet: () => void;\n  handleAddExpressionFromSheet: () => void;\n  handleAddValidatorFromSheet: (\n    type: \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n  ) => void;\n  handleAddMarkdownNoteFromSheet: () => void;\n};\n\nexport function useRecipeEditorGraph({\n  nodes,\n  edges,\n  configs,\n  reactFlowInstance,\n  flowContainerRef,\n  selectConfig,\n  openConfig,\n  onNodesChange,\n  onEdgesChange,\n  setAuxNodePosition,\n  addSamplerNode,\n  addSeedNode,\n  addLlmNode,\n  addModelProviderNode,\n  addModelConfigNode,\n  addToolProfileNode,\n  addExpressionNode,\n  addValidatorNode,\n  addMarkdownNoteNode,\n}: UseRecipeEditorGraphArgs): UseRecipeEditorGraphResult {\n  const baseNodeIds = useMemo(() => new Set(nodes.map((node) => node.id)), [nodes]);\n  const baseEdgeIds = useMemo(() => new Set(edges.map((edge) => edge.id)), [edges]);\n\n  const handleNodeClick = useCallback(\n    (_: unknown, node: Node<RecipeNodeData | RecipeGraphAuxNodeData>) => {\n      if (node.type !== \"builder\") {\n        return;\n      }\n      selectConfig(node.id);\n    },\n    [selectConfig],\n  );\n\n  const handleNodeDoubleClick = useCallback(\n    (_: unknown, node: Node<RecipeNodeData | RecipeGraphAuxNodeData>) => {\n      if (node.type !== \"builder\") {\n        return;\n      }\n      const nodeConfig = configs[node.id];\n      if (nodeConfig?.kind === \"markdown_note\") {\n        openConfig(node.id);\n      }\n    },\n    [configs, openConfig],\n  );\n\n  const handleNodesChange = useCallback(\n    (changes: NodeChange<Node<RecipeNodeData | RecipeGraphAuxNodeData>>[]) => {\n      applyAuxNodeChanges(changes, { setAuxNodePosition });\n      const next = filterNodeChangesByIds(\n        changes as NodeChange<RecipeBuilderNode>[],\n        baseNodeIds,\n      );\n      if (next.length) {\n        onNodesChange(next);\n      }\n    },\n    [baseNodeIds, onNodesChange, setAuxNodePosition],\n  );\n\n  const handleEdgesChange = useCallback(\n    (changes: EdgeChange<Edge>[]) => {\n      const next = filterEdgeChangesByIds(changes, baseEdgeIds);\n      if (next.length) {\n        onEdgesChange(next);\n      }\n    },\n    [baseEdgeIds, onEdgesChange],\n  );\n\n  const handleDragOver = useCallback((event: ReactDragEvent<HTMLDivElement>) => {\n    if (\n      !event.dataTransfer.types.includes(RECIPE_BLOCK_DND_MIME) &&\n      !event.dataTransfer.types.includes(\"text/plain\")\n    ) {\n      return;\n    }\n    event.preventDefault();\n    event.dataTransfer.dropEffect = \"copy\";\n  }, []);\n\n  const handleDrop = useCallback(\n    (event: ReactDragEvent<HTMLDivElement>) => {\n      if (!reactFlowInstance) {\n        return;\n      }\n      const raw =\n        event.dataTransfer.getData(RECIPE_BLOCK_DND_MIME) ||\n        event.dataTransfer.getData(\"text/plain\");\n      if (!raw) {\n        return;\n      }\n      const payload = parseRecipeBlockDragPayload(raw);\n      if (!payload) {\n        return;\n      }\n      event.preventDefault();\n      const position = reactFlowInstance.screenToFlowPosition({\n        x: event.clientX,\n        y: event.clientY,\n      });\n\n      if (payload.kind === \"sampler\") {\n        addSamplerNode(payload.type as SamplerType, position, false);\n        return;\n      }\n      if (payload.kind === \"seed\") {\n        addSeedNode(payload.type as SeedBlockType, 
position, false);\n        return;\n      }\n      if (payload.kind === \"expression\") {\n        addExpressionNode(position, false);\n        return;\n      }\n      if (payload.kind === \"validator\") {\n        addValidatorNode(\n          payload.type as \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n          position,\n          false,\n        );\n        return;\n      }\n      if (payload.kind === \"note\") {\n        addMarkdownNoteNode(position, false);\n        return;\n      }\n      if (payload.type === \"model_provider\") {\n        addModelProviderNode(position, false);\n        return;\n      }\n      if (payload.type === \"model_config\") {\n        addModelConfigNode(position, false);\n        return;\n      }\n      if (payload.type === \"tool_config\") {\n        addToolProfileNode(position, false);\n        return;\n      }\n      addLlmNode(payload.type as LlmType, position, false);\n    },\n    [\n      addExpressionNode,\n      addLlmNode,\n      addMarkdownNoteNode,\n      addModelConfigNode,\n      addModelProviderNode,\n      addToolProfileNode,\n      addSamplerNode,\n      addSeedNode,\n      addValidatorNode,\n      reactFlowInstance,\n    ],\n  );\n\n  const getViewportCenterPosition = useCallback(() => {\n    if (!reactFlowInstance || !flowContainerRef.current) {\n      return undefined;\n    }\n    const rect = flowContainerRef.current.getBoundingClientRect();\n    return reactFlowInstance.screenToFlowPosition({\n      x: rect.left + rect.width / 2,\n      y: rect.top + rect.height / 2,\n    });\n  }, [flowContainerRef, reactFlowInstance]);\n\n  const handleAddSamplerFromSheet = useCallback(\n    (type: SamplerType) => {\n      addSamplerNode(type, getViewportCenterPosition());\n    },\n    [addSamplerNode, getViewportCenterPosition],\n  );\n\n  const handleAddSeedFromSheet = useCallback(\n    (type: SeedBlockType) => {\n      addSeedNode(type, getViewportCenterPosition());\n    },\n    [addSeedNode, getViewportCenterPosition],\n  );\n\n  const handleAddLlmFromSheet = useCallback(\n    (type: LlmType) => {\n      addLlmNode(type, getViewportCenterPosition());\n    },\n    [addLlmNode, getViewportCenterPosition],\n  );\n\n  const handleAddModelProviderFromSheet = useCallback(() => {\n    addModelProviderNode(getViewportCenterPosition());\n  }, [addModelProviderNode, getViewportCenterPosition]);\n\n  const handleAddModelConfigFromSheet = useCallback(() => {\n    addModelConfigNode(getViewportCenterPosition());\n  }, [addModelConfigNode, getViewportCenterPosition]);\n\n  const handleAddExpressionFromSheet = useCallback(() => {\n    addExpressionNode(getViewportCenterPosition());\n  }, [addExpressionNode, getViewportCenterPosition]);\n\n  const handleAddToolProfileFromSheet = useCallback(() => {\n    addToolProfileNode(getViewportCenterPosition());\n  }, [addToolProfileNode, getViewportCenterPosition]);\n\n  const handleAddValidatorFromSheet = useCallback(\n    (type: \"validator_python\" | \"validator_sql\" | \"validator_oxc\") => {\n      addValidatorNode(type, getViewportCenterPosition());\n    },\n    [addValidatorNode, getViewportCenterPosition],\n  );\n\n  const handleAddMarkdownNoteFromSheet = useCallback(() => {\n    addMarkdownNoteNode(getViewportCenterPosition());\n  }, [addMarkdownNoteNode, getViewportCenterPosition]);\n\n  return {\n    handleNodeClick,\n    handleNodeDoubleClick,\n    handleNodesChange,\n    handleEdgesChange,\n    handleDragOver,\n    handleDrop,\n    handleAddSamplerFromSheet,\n    handleAddSeedFromSheet,\n    
handleAddLlmFromSheet,\n    handleAddModelProviderFromSheet,\n    handleAddModelConfigFromSheet,\n    handleAddToolProfileFromSheet,\n    handleAddExpressionFromSheet,\n    handleAddValidatorFromSheet,\n    handleAddMarkdownNoteFromSheet,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-recipe-executions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { toastError } from \"@/shared/toast\";\nimport {\n  cancelRecipeJob,\n  createRecipeJob,\n  getRecipeJobDataset,\n  validateRecipe,\n} from \"../api\";\nimport { saveRecipeExecution } from \"../data/executions-db\";\nimport type {\n  RecipeExecutionKind,\n  RecipeExecutionRecord,\n} from \"../execution-types\";\nimport {\n  DATASET_PAGE_SIZE,\n  executionLabel,\n  normalizeRunName,\n  normalizeDatasetRows,\n  toErrorMessage,\n  withExecutionDefaults,\n} from \"../executions/execution-helpers\";\nimport {\n  findResumableExecution,\n  loadSortedRecipeExecutions,\n} from \"../executions/hydration\";\nimport { createBaseExecutionRecord } from \"../executions/runtime\";\nimport {\n  buildExecutionPayload,\n  sanitizeExecutionRows,\n} from \"../executions/run-settings\";\nimport { trackRecipeExecution } from \"../executions/tracker\";\nimport {\n  type RecipeRunSettings,\n  useRecipeExecutionsStore,\n} from \"../stores/recipe-executions\";\nimport type { RecipePayload, RecipePayloadResult } from \"../utils/payload/types\";\n\ntype UseRecipeExecutionsParams = {\n  recipeId: string;\n  currentSignature: string;\n  payloadResult: RecipePayloadResult;\n  onExecutionStart?: () => void;\n  onPreviewSuccess?: () => void;\n};\n\ntype UseRecipeExecutionsResult = {\n  runDialogOpen: boolean;\n  runDialogKind: RecipeExecutionKind;\n  setRunDialogKind: (kind: RecipeExecutionKind) => void;\n  setRunDialogOpen: (open: boolean) => void;\n  previewRows: number;\n  fullRows: number;\n  fullRunName: string;\n  setPreviewRows: (rows: number) => void;\n  setFullRows: (rows: number) => void;\n  setFullRunName: (name: string) => void;\n  runErrors: string[];\n  runSettings: RecipeRunSettings;\n  setRunSettings: (patch: Partial<RecipeRunSettings>) => void;\n  previewLoading: boolean;\n  fullLoading: boolean;\n  executions: RecipeExecutionRecord[];\n  selectedExecutionId: string | null;\n  setSelectedExecutionId: (id: string) => void;\n  openRunDialog: (kind: RecipeExecutionKind) => void;\n  runFromDialog: () => Promise<boolean>;\n  validateFromDialog: () => Promise<boolean>;\n  validateLoading: boolean;\n  validateResult: {\n    valid: boolean;\n    errors: string[];\n    rawDetail: string | null;\n  } | null;\n  runPreview: () => Promise<boolean>;\n  runFull: () => Promise<boolean>;\n  cancelExecution: (id: string) => Promise<void>;\n  loadExecutionDatasetPage: (id: string, page: number) => Promise<void>;\n};\n\nfunction formatValidationMessages(input: {\n  errors: Array<{ message: string; path?: string | null; code?: string | null }>;\n}): string[] {\n  return input.errors.map((item) => {\n    const path = item.path?.trim();\n    const code = item.code?.trim();\n    const prefix = [\n      code ? code.toUpperCase() : null,\n      path ? `column ${path}` : null,\n    ]\n      .filter(Boolean)\n      .join(\" · \");\n    return prefix ? 
`${prefix}: ${item.message}` : item.message;\n  });\n}\n\nexport function useRecipeExecutions({\n  recipeId,\n  currentSignature,\n  payloadResult,\n  onExecutionStart,\n  onPreviewSuccess,\n}: UseRecipeExecutionsParams): UseRecipeExecutionsResult {\n  const [validateLoading, setValidateLoading] = useState(false);\n  const [validateResult, setValidateResult] = useState<{\n    valid: boolean;\n    errors: string[];\n    rawDetail: string | null;\n  } | null>(null);\n  const {\n    runDialogOpen,\n    runDialogKind,\n    previewRows,\n    fullRows,\n    fullRunName,\n    runErrors,\n    runSettings,\n    previewLoading,\n    fullLoading,\n    executions,\n    selectedExecutionId,\n    setRunDialogOpen,\n    setRunDialogKind,\n    setPreviewRows,\n    setFullRows,\n    setFullRunName,\n    setRunErrors,\n    setRunSettings,\n    setPreviewLoading,\n    setFullLoading,\n    setExecutions,\n    upsertExecution,\n    selectExecution,\n    resetForRecipe,\n  } = useRecipeExecutionsStore(\n    useShallow((state) => ({\n      runDialogOpen: state.runDialogOpen,\n      runDialogKind: state.runDialogKind,\n      previewRows: state.previewRows,\n      fullRows: state.fullRows,\n      fullRunName: state.fullRunName,\n      runErrors: state.runErrors,\n      runSettings: state.runSettings,\n      previewLoading: state.previewLoading,\n      fullLoading: state.fullLoading,\n      executions: state.executions,\n      selectedExecutionId: state.selectedExecutionId,\n      setRunDialogOpen: state.setRunDialogOpen,\n      setRunDialogKind: state.setRunDialogKind,\n      setPreviewRows: state.setPreviewRows,\n      setFullRows: state.setFullRows,\n      setFullRunName: state.setFullRunName,\n      setRunErrors: state.setRunErrors,\n      setRunSettings: state.setRunSettings,\n      setPreviewLoading: state.setPreviewLoading,\n      setFullLoading: state.setFullLoading,\n      setExecutions: state.setExecutions,\n      upsertExecution: state.upsertExecution,\n      selectExecution: state.selectExecution,\n      resetForRecipe: state.resetForRecipe,\n    })),\n  );\n  const payloadErrorMessage = payloadResult.errors[0] ?? 
\"Invalid payload.\";\n\n  const upsertAndPersist = useCallback(\n    (record: RecipeExecutionRecord): void => {\n      const normalizedRecord = withExecutionDefaults(record);\n      upsertExecution(normalizedRecord);\n      void saveRecipeExecution(normalizedRecord).catch((error) => {\n        console.error(\"Save recipe execution failed:\", error);\n      });\n    },\n    [upsertExecution],\n  );\n\n  useEffect(() => {\n    let cancelled = false;\n\n    resetForRecipe();\n\n    async function hydrate(): Promise<void> {\n      try {\n        const records = await loadSortedRecipeExecutions(recipeId);\n        if (cancelled) {\n          return;\n        }\n\n        setExecutions(records);\n        const resumable = findResumableExecution(records);\n        if (!resumable?.jobId) {\n          return;\n        }\n\n        void trackRecipeExecution({\n          label: executionLabel(resumable.kind),\n          kind: resumable.kind,\n          rows: resumable.rows,\n          jobId: resumable.jobId,\n          initialExecution: resumable,\n          notify: false,\n          onUpsert: upsertAndPersist,\n          onSetPreviewErrors: setRunErrors,\n          onPreviewSuccess,\n        });\n      } catch (error) {\n        console.error(\"Load recipe executions failed:\", error);\n      }\n    }\n\n    void hydrate();\n\n    return () => {\n      cancelled = true;\n    };\n  }, [\n    onPreviewSuccess,\n    recipeId,\n    resetForRecipe,\n    setExecutions,\n    setRunErrors,\n    upsertAndPersist,\n  ]);\n\n  const readPayload = useCallback((): RecipePayload | null => {\n    if (payloadResult.errors.length === 0) {\n      return payloadResult.payload;\n    }\n    return null;\n  }, [payloadResult.errors.length, payloadResult.payload]);\n\n  const readExecutablePayload = useCallback((): RecipePayload | null => {\n    const payload = readPayload();\n    if (payload) {\n      return payload;\n    }\n\n    setRunErrors(payloadResult.errors);\n    toastError(\"Invalid recipe payload\", payloadErrorMessage);\n    return null;\n  }, [payloadErrorMessage, payloadResult.errors, readPayload, setRunErrors]);\n\n  const runExecution = useCallback(\n    async (input: {\n      kind: RecipeExecutionKind;\n      payload: RecipePayload;\n      rows: number;\n      settings: RecipeRunSettings;\n      runName: string | null;\n    }): Promise<boolean> => {\n      const { kind, payload, rows, settings, runName } = input;\n      const setLoading = kind === \"preview\" ? 
setPreviewLoading : setFullLoading;\n      const label = executionLabel(kind);\n\n      setLoading(true);\n      const baseExecution = createBaseExecutionRecord({\n        recipeId,\n        kind,\n        rows,\n        currentSignature,\n        runName,\n      });\n\n      upsertAndPersist(baseExecution);\n      onExecutionStart?.();\n      setRunDialogOpen(false);\n\n      try {\n        const jobPayload = buildExecutionPayload({\n          payload,\n          kind,\n          rows,\n          settings,\n          runName,\n        });\n        const createdJob = await createRecipeJob(jobPayload);\n        const executionWithJob = {\n          ...baseExecution,\n          jobId: createdJob.job_id,\n        };\n        upsertAndPersist(executionWithJob);\n\n        return await trackRecipeExecution({\n          label,\n          kind,\n          rows,\n          jobId: createdJob.job_id,\n          initialExecution: executionWithJob,\n          notify: true,\n          onUpsert: upsertAndPersist,\n          onSetPreviewErrors: setRunErrors,\n          onPreviewSuccess,\n        });\n      } catch (error) {\n        const message = toErrorMessage(error, `${label} request failed.`);\n        upsertAndPersist({\n          ...baseExecution,\n          status: \"error\",\n          error: message,\n          finishedAt: Date.now(),\n        });\n        setRunErrors([message]);\n        toastError(`${label} failed`, message);\n        return false;\n      } finally {\n        setLoading(false);\n      }\n    },\n    [\n      currentSignature,\n      onExecutionStart,\n      onPreviewSuccess,\n      recipeId,\n      setFullLoading,\n      setPreviewLoading,\n      setRunDialogOpen,\n      setRunErrors,\n      upsertAndPersist,\n    ],\n  );\n\n  const runWithValidation = useCallback(\n    async (\n      kind: RecipeExecutionKind,\n      rows: number,\n      runName: string | null,\n    ): Promise<boolean> => {\n      const trimmedRunName = typeof runName === \"string\" ? runName.trim() : \"\";\n      if (kind === \"full\" && !trimmedRunName) {\n        const message = \"Run name required for full runs.\";\n        setRunErrors([message]);\n        toastError(\"Run name required\", message);\n        return false;\n      }\n\n      const payload = readExecutablePayload();\n      if (!payload) {\n        return false;\n      }\n\n      const normalizedRows = sanitizeExecutionRows(rows, kind);\n      const executionPayload = buildExecutionPayload({\n        payload,\n        kind,\n        rows: normalizedRows,\n        settings: runSettings,\n        runName,\n      });\n\n      try {\n        const validation = await validateRecipe(executionPayload);\n        if (!validation.valid) {\n          const errors = formatValidationMessages({ errors: validation.errors });\n          const fallback = validation.raw_detail ?? \"Validation failed.\";\n          const nextErrors = errors.length > 0 ? 
errors : [fallback];\n          setRunErrors(nextErrors);\n          toastError(\"Validation failed\", nextErrors[0]);\n          return false;\n        }\n      } catch (error) {\n        const message = toErrorMessage(error, \"Validation failed.\");\n        setRunErrors([message]);\n        toastError(\"Validation failed\", message);\n        return false;\n      }\n\n      return runExecution({\n        kind,\n        payload,\n        rows: normalizedRows,\n        settings: runSettings,\n        runName,\n      });\n    },\n    [readExecutablePayload, runExecution, runSettings, setRunErrors],\n  );\n\n  const runPreview = useCallback(async (): Promise<boolean> => {\n    return runWithValidation(\"preview\", previewRows, null);\n  }, [previewRows, runWithValidation]);\n\n  const runFull = useCallback(async (): Promise<boolean> => {\n    return runWithValidation(\"full\", fullRows, fullRunName);\n  }, [fullRows, fullRunName, runWithValidation]);\n\n  const runFromDialog = useCallback(async (): Promise<boolean> => {\n    setValidateResult(null);\n    if (runDialogKind === \"preview\") {\n      return runPreview();\n    }\n    return runFull();\n  }, [runDialogKind, runFull, runPreview]);\n\n  const validateFromDialog = useCallback(async (): Promise<boolean> => {\n    setRunErrors([]);\n    const payload = readPayload();\n    if (!payload) {\n      const nextErrors = payloadResult.errors.length > 0\n        ? payloadResult.errors\n        : [payloadErrorMessage];\n      setValidateResult({\n        valid: false,\n        errors: nextErrors,\n        rawDetail: null,\n      });\n      return false;\n    }\n\n    const rows = runDialogKind === \"preview\" ? previewRows : fullRows;\n    const normalizedRows = sanitizeExecutionRows(rows, runDialogKind);\n    const executionPayload = buildExecutionPayload({\n      payload,\n      kind: runDialogKind,\n      rows: normalizedRows,\n      settings: runSettings,\n      runName: runDialogKind === \"full\" ? normalizeRunName(fullRunName) : null,\n    });\n\n    setValidateLoading(true);\n    try {\n      const validation = await validateRecipe(executionPayload);\n      const errors = formatValidationMessages({ errors: validation.errors });\n      setValidateResult({\n        valid: validation.valid,\n        errors,\n        rawDetail: validation.raw_detail ?? 
null,\n      });\n      return validation.valid;\n    } catch (error) {\n      const message = toErrorMessage(error, \"Validation failed.\");\n      setValidateResult({\n        valid: false,\n        errors: [message],\n        rawDetail: null,\n      });\n      return false;\n    } finally {\n      setValidateLoading(false);\n    }\n  }, [\n    fullRunName,\n    fullRows,\n    payloadErrorMessage,\n    payloadResult.errors,\n    previewRows,\n    readPayload,\n    runDialogKind,\n    runSettings,\n    setRunErrors,\n  ]);\n\n  const openRunDialog = useCallback(\n    (kind: RecipeExecutionKind): void => {\n      setRunErrors([]);\n      setValidateResult(null);\n      setRunDialogKind(kind);\n      if (kind === \"full\") {\n        const payload = readPayload();\n        const payloadRows = Number(payload?.run?.rows);\n        if (Number.isFinite(payloadRows) && payloadRows > 0) {\n          setFullRows(Math.floor(payloadRows));\n        }\n      }\n      setRunDialogOpen(true);\n    },\n    [\n      readPayload,\n      setFullRows,\n      setRunDialogKind,\n      setRunDialogOpen,\n      setRunErrors,\n    ],\n  );\n\n  const cancelExecution = useCallback(\n    async (id: string): Promise<void> => {\n      const execution = executions.find((entry) => entry.id === id);\n      if (!execution?.jobId) {\n        return;\n      }\n      try {\n        await cancelRecipeJob(execution.jobId);\n        upsertAndPersist({\n          ...execution,\n          status: \"cancelling\",\n        });\n      } catch (error) {\n        const message = toErrorMessage(error, \"Could not cancel execution.\");\n        toastError(\"Cancel failed\", message);\n      }\n    },\n    [executions, upsertAndPersist],\n  );\n\n  const loadExecutionDatasetPage = useCallback(\n    async (id: string, page: number): Promise<void> => {\n      const execution = executions.find((entry) => entry.id === id);\n      if (!execution || execution.kind !== \"full\" || !execution.jobId || page < 1) {\n        return;\n      }\n\n      const pageSize = execution.datasetPageSize || DATASET_PAGE_SIZE;\n      const offset = (page - 1) * pageSize;\n      try {\n        const response = await getRecipeJobDataset(execution.jobId, {\n          limit: pageSize,\n          offset,\n        });\n        const dataset = normalizeDatasetRows(response.dataset);\n        const total =\n          typeof response.total === \"number\" ? 
response.total : execution.datasetTotal;\n        upsertAndPersist({\n          ...execution,\n          dataset,\n          datasetTotal: total,\n          datasetPage: page,\n        });\n      } catch (error) {\n        const message = toErrorMessage(error, \"Could not load dataset page.\");\n        toastError(\"Dataset page failed\", message);\n      }\n    },\n    [executions, upsertAndPersist],\n  );\n\n  const setSelectedExecutionId = useCallback(\n    (id: string): void => {\n      selectExecution(id);\n    },\n    [selectExecution],\n  );\n\n  return {\n    runDialogOpen,\n    runDialogKind,\n    setRunDialogKind,\n    setRunDialogOpen,\n    previewRows,\n    fullRows,\n    fullRunName,\n    setPreviewRows,\n    setFullRows,\n    setFullRunName,\n    runErrors,\n    runSettings,\n    setRunSettings,\n    previewLoading,\n    fullLoading,\n    executions,\n    selectedExecutionId,\n    setSelectedExecutionId,\n    openRunDialog,\n    runFromDialog,\n    validateFromDialog,\n    validateLoading,\n    validateResult,\n    runPreview,\n    runFull,\n    cancelExecution,\n    loadExecutionDatasetPage,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-recipe-persistence.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useMemo, useState } from \"react\";\nimport { toastError, toastSuccess } from \"@/shared/toast\";\nimport { normalizeNonEmptyName } from \"@/utils\";\nimport {\n  buildSignature,\n  copyTextToClipboard,\n  formatSavedLabel,\n} from \"../executions/execution-helpers\";\nimport { importRecipePayload, type RecipeSnapshot } from \"../utils/import\";\nimport type { RecipePayloadResult } from \"../utils/payload/types\";\n\ntype SaveTone = \"success\" | \"error\";\n\ntype PersistRecipeFn = (input: {\n  id: string | null;\n  name: string;\n  payload: RecipePayloadResult[\"payload\"];\n}) => Promise<{\n  id: string;\n  updatedAt: number;\n}>;\n\ntype UseRecipePersistenceParams = {\n  recipeId: string;\n  initialRecipeName: string;\n  initialPayload: RecipePayloadResult[\"payload\"];\n  initialSavedAt: number;\n  payloadResult: RecipePayloadResult;\n  onPersistRecipe: PersistRecipeFn;\n  resetRecipe: () => void;\n  loadRecipe: (snapshot: RecipeSnapshot) => void;\n  getCurrentPayloadFromStore: () => RecipePayloadResult[\"payload\"];\n};\n\ntype UseRecipePersistenceResult = {\n  initialRecipeReady: boolean;\n  workflowName: string;\n  setWorkflowName: (value: string) => void;\n  saveLoading: boolean;\n  saveTone: SaveTone;\n  savedAtLabel: string;\n  copied: boolean;\n  importOpen: boolean;\n  setImportOpen: (open: boolean) => void;\n  currentSignature: string;\n  persistRecipe: () => Promise<void>;\n  copyRecipe: () => Promise<void>;\n  importRecipe: (value: string) => string | null;\n};\n\nfunction stripApiKeys(value: unknown): unknown {\n  if (Array.isArray(value)) {\n    return value.map(stripApiKeys);\n  }\n  if (!value || typeof value !== \"object\") {\n    return value;\n  }\n  const output: Record<string, unknown> = {};\n  for (const [key, entry] of Object.entries(value)) {\n    if (key === \"api_key\") {\n      continue;\n    }\n    output[key] = stripApiKeys(entry);\n  }\n  if (\n    output.provider_type === \"stdio\" &&\n    output.env &&\n    typeof output.env === \"object\" &&\n    !Array.isArray(output.env)\n  ) {\n    output.env = Object.fromEntries(\n      Object.keys(output.env as Record<string, unknown>).map((envKey) => [envKey, \"\"]),\n    );\n  }\n  return output;\n}\n\nfunction inferHfRepoIdFromPath(pathValue: unknown): string {\n  if (typeof pathValue !== \"string\") {\n    return \"\";\n  }\n  const parts = pathValue\n    .trim()\n    .split(\"/\")\n    .filter(Boolean);\n  if (parts.length >= 3 && parts[0] === \"datasets\") {\n    return `${parts[1]}/${parts[2]}`;\n  }\n  if (parts.length >= 2) {\n    return `${parts[0]}/${parts[1]}`;\n  }\n  return \"\";\n}\n\nfunction sanitizeSeedForShare(payload: unknown): unknown {\n  if (!payload || typeof payload !== \"object\") {\n    return payload;\n  }\n  const root = payload as Record<string, unknown>;\n  const recipe =\n    root.recipe && typeof root.recipe === \"object\"\n      ? (root.recipe as Record<string, unknown>)\n      : null;\n  const ui =\n    root.ui && typeof root.ui === \"object\"\n      ? (root.ui as Record<string, unknown>)\n      : null;\n\n  const seedConfig =\n    recipe?.seed_config && typeof recipe.seed_config === \"object\"\n      ? (recipe.seed_config as Record<string, unknown>)\n      : null;\n  const source =\n    seedConfig?.source && typeof seedConfig.source === \"object\"\n      ? 
(seedConfig.source as Record<string, unknown>)\n      : null;\n\n  if (source && \"token\" in source) {\n    delete source.token;\n  }\n\n  const uiSourceType =\n    typeof ui?.seed_source_type === \"string\" ? ui.seed_source_type : null;\n  const sourceType =\n    typeof source?.seed_type === \"string\" ? source.seed_type : null;\n  const shouldResetHfState =\n    sourceType === \"hf\" || uiSourceType === \"hf\";\n  const shouldResetLocalState =\n    sourceType === \"local\" ||\n    sourceType === \"unstructured\" ||\n    uiSourceType === \"local\" ||\n    uiSourceType === \"unstructured\";\n\n  if (shouldResetHfState) {\n    const repoId = inferHfRepoIdFromPath(source?.path);\n    if (source && \"path\" in source) {\n      source.path = repoId;\n    }\n    if (ui) {\n      ui.seed_columns = [];\n      ui.seed_drop_columns = [];\n      ui.seed_preview_rows = [];\n      ui.local_file_name = \"\";\n      ui.unstructured_file_name = \"\";\n    }\n  }\n\n  if (shouldResetLocalState) {\n    if (source && \"path\" in source) {\n      source.path = \"\";\n    }\n    if (ui) {\n      ui.seed_columns = [];\n      ui.seed_drop_columns = [];\n      ui.seed_preview_rows = [];\n      ui.local_file_name = \"\";\n      ui.unstructured_file_name = \"\";\n    }\n  }\n\n  return root;\n}\n\nexport function useRecipePersistence({\n  recipeId,\n  initialRecipeName,\n  initialPayload,\n  initialSavedAt,\n  payloadResult,\n  onPersistRecipe,\n  resetRecipe,\n  loadRecipe,\n  getCurrentPayloadFromStore,\n}: UseRecipePersistenceParams): UseRecipePersistenceResult {\n  const [initialRecipeReady, setInitialRecipeReady] = useState(false);\n  const [workflowName, setWorkflowName] = useState(\"Unnamed\");\n  const [lastSavedAt, setLastSavedAt] = useState<number | null>(null);\n  const [savedSignature, setSavedSignature] = useState(\"\");\n  const [saveLoading, setSaveLoading] = useState(false);\n  const [copied, setCopied] = useState(false);\n  const [importOpen, setImportOpen] = useState(false);\n\n  const normalizedWorkflowName = useMemo(\n    () => normalizeNonEmptyName(workflowName, \"Unnamed\"),\n    [workflowName],\n  );\n  const currentPayload = payloadResult.payload;\n  const currentSignature = useMemo(\n    () => buildSignature(normalizedWorkflowName, currentPayload),\n    [currentPayload, normalizedWorkflowName],\n  );\n  const isDirty = savedSignature.length > 0 && currentSignature !== savedSignature;\n  const saveTone: SaveTone = !isDirty && Boolean(lastSavedAt) ? 
\"success\" : \"error\";\n  const savedAtLabel = formatSavedLabel(lastSavedAt);\n\n  useEffect(() => {\n    setInitialRecipeReady(false);\n    const nextName = normalizeNonEmptyName(initialRecipeName, \"Unnamed\");\n    resetRecipe();\n    setWorkflowName(nextName);\n    setLastSavedAt(initialSavedAt);\n    setCopied(false);\n\n    const parsed = importRecipePayload(JSON.stringify(initialPayload));\n    if (parsed.snapshot) {\n      loadRecipe(parsed.snapshot);\n    } else {\n      console.error(\"Failed to load recipe payload.\", parsed.errors);\n    }\n\n    const payload = getCurrentPayloadFromStore();\n    setSavedSignature(buildSignature(nextName, payload));\n    setInitialRecipeReady(true);\n  }, [\n    getCurrentPayloadFromStore,\n    initialPayload,\n    initialRecipeName,\n    initialSavedAt,\n    loadRecipe,\n    recipeId,\n    resetRecipe,\n  ]);\n\n  const persistRecipe = useCallback(async (): Promise<void> => {\n    if (saveLoading) {\n      return;\n    }\n    const nextName = normalizeNonEmptyName(workflowName, \"Unnamed\");\n    if (nextName !== workflowName) {\n      setWorkflowName(nextName);\n    }\n\n    setSaveLoading(true);\n    try {\n      const result = await onPersistRecipe({\n        id: recipeId,\n        name: nextName,\n        payload: currentPayload,\n      });\n      setLastSavedAt(result.updatedAt);\n      setSavedSignature(buildSignature(nextName, currentPayload));\n    } catch (error) {\n      console.error(\"Save recipe failed:\", error);\n      toastError(\"Save failed\", \"Could not save recipe.\");\n    } finally {\n      setSaveLoading(false);\n    }\n  }, [currentPayload, onPersistRecipe, recipeId, saveLoading, workflowName]);\n\n  useEffect(() => {\n    if (!isDirty || saveLoading) {\n      return;\n    }\n    const timeoutId = window.setTimeout(() => {\n      void persistRecipe();\n    }, 800);\n    return () => window.clearTimeout(timeoutId);\n  }, [isDirty, persistRecipe, saveLoading]);\n\n  const copyRecipe = useCallback(async (): Promise<void> => {\n    setCopied(false);\n    try {\n      const safePayload = sanitizeSeedForShare(stripApiKeys(payloadResult.payload));\n      const ok = await copyTextToClipboard(JSON.stringify(safePayload, null, 2));\n      if (!ok) {\n        throw new Error(\"Clipboard not available.\");\n      }\n      setCopied(true);\n      window.setTimeout(() => setCopied(false), 1500);\n      toastSuccess(\"👨‍🍳 Recipe copied\");\n    } catch (error) {\n      console.error(\"Copy failed:\", error);\n      toastError(\"Copy failed\", \"Could not copy payload.\");\n    }\n  }, [payloadResult.payload]);\n\n  const importRecipe = useCallback(\n    (value: string): string | null => {\n      const result = importRecipePayload(value);\n      if (result.errors.length > 0 || !result.snapshot) {\n        return result.errors[0] ?? \"Invalid payload.\";\n      }\n      loadRecipe(result.snapshot);\n      toastSuccess(\"Recipe imported\");\n      return null;\n    },\n    [loadRecipe],\n  );\n\n  return {\n    initialRecipeReady,\n    workflowName,\n    setWorkflowName,\n    saveLoading,\n    saveTone,\n    savedAtLabel,\n    copied,\n    importOpen,\n    setImportOpen,\n    currentSignature,\n    persistRecipe,\n    copyRecipe,\n    importRecipe,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-recipe-runtime-visuals.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  BalanceScaleIcon,\n  Clock01Icon,\n  CodeIcon,\n  CodeSimpleIcon,\n  DiceFaces03Icon,\n  EqualSignIcon,\n  FingerPrintIcon,\n  FunctionIcon,\n  Plug01Icon,\n  Parabola02Icon,\n  PencilEdit02Icon,\n  Plant01Icon,\n  Shield02Icon,\n  Tag01Icon,\n  TagsIcon,\n  UserAccountIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { useMemo } from \"react\";\nimport type { Edge } from \"@xyflow/react\";\nimport { deriveDisplayGraph } from \"../utils/graph/derive-display-graph\";\nimport {\n  deriveGraphRuntimeVisualState,\n  pickLatestActiveExecution,\n} from \"../utils/graph/runtime-visual-state\";\nimport type {\n  LayoutDirection,\n  LlmType,\n  NodeConfig,\n  RecipeNode as RecipeBuilderNode,\n  SamplerType,\n} from \"../types\";\nimport type { RecipeExecutionRecord } from \"../execution-types\";\n\ntype IconType = typeof CodeIcon;\n\nconst SAMPLER_ICONS: Record<SamplerType, IconType> = {\n  category: Tag01Icon,\n  subcategory: TagsIcon,\n  uniform: EqualSignIcon,\n  gaussian: Parabola02Icon,\n  bernoulli: EqualSignIcon,\n  datetime: Clock01Icon,\n  timedelta: Clock01Icon,\n  uuid: FingerPrintIcon,\n  person: UserAccountIcon,\n  person_from_faker: UserAccountIcon,\n};\n\nconst LLM_ICONS: Record<LlmType, IconType> = {\n  text: PencilEdit02Icon,\n  structured: CodeIcon,\n  code: CodeSimpleIcon,\n  judge: BalanceScaleIcon,\n};\n\nfunction resolveExecutionColumnIcon(config: NodeConfig | null): IconType {\n  if (!config) {\n    return DiceFaces03Icon;\n  }\n  if (config.kind === \"sampler\") {\n    return SAMPLER_ICONS[config.sampler_type];\n  }\n  if (config.kind === \"llm\") {\n    return LLM_ICONS[config.llm_type];\n  }\n  if (config.kind === \"expression\") {\n    return FunctionIcon;\n  }\n  if (config.kind === \"validator\") {\n    return Shield02Icon;\n  }\n  if (config.kind === \"seed\") {\n    return Plant01Icon;\n  }\n  if (config.kind === \"model_provider\") {\n    return Shield02Icon;\n  }\n  if (config.kind === \"model_config\") {\n    return Plant01Icon;\n  }\n  if (config.kind === \"tool_config\") {\n    return Plug01Icon;\n  }\n  return PencilEdit02Icon;\n}\n\ntype UseRecipeRuntimeVisualsArgs = {\n  executions: RecipeExecutionRecord[];\n  configs: Record<string, NodeConfig>;\n  nodes: RecipeBuilderNode[];\n  edges: Edge[];\n  layoutDirection: LayoutDirection;\n  auxNodePositions: Record<string, { x: number; y: number }>;\n  llmAuxVisibility: Record<string, boolean>;\n};\n\ntype UseRecipeRuntimeVisualsResult = {\n  activeExecution: RecipeExecutionRecord | null;\n  runtimeVisualState: ReturnType<typeof deriveGraphRuntimeVisualState>;\n  displayGraph: ReturnType<typeof deriveDisplayGraph>;\n  displayNodeIds: string[];\n  currentColumnIcon: IconType;\n};\n\nexport function useRecipeRuntimeVisuals({\n  executions,\n  configs,\n  nodes,\n  edges,\n  layoutDirection,\n  auxNodePositions,\n  llmAuxVisibility,\n}: UseRecipeRuntimeVisualsArgs): UseRecipeRuntimeVisualsResult {\n  const activeExecution = useMemo(\n    () => pickLatestActiveExecution(executions),\n    [executions],\n  );\n\n  const runtimeVisualState = useMemo(\n    () =>\n      deriveGraphRuntimeVisualState({\n        activeExecution,\n        configs,\n        edges,\n      }),\n    [activeExecution, configs, edges],\n  );\n\n  const displayGraph = useMemo(\n    () =>\n      deriveDisplayGraph({\n        nodes,\n        edges,\n        configs,\n        
layoutDirection,\n        auxNodePositions,\n        llmAuxVisibility,\n        runtime: runtimeVisualState,\n      }),\n    [\n      auxNodePositions,\n      configs,\n      edges,\n      layoutDirection,\n      llmAuxVisibility,\n      nodes,\n      runtimeVisualState,\n    ],\n  );\n\n  const currentColumnConfig = useMemo(() => {\n    const columnName = activeExecution?.current_column?.trim();\n    if (!columnName) {\n      return null;\n    }\n    for (const config of Object.values(configs)) {\n      if (config.name.trim() === columnName) {\n        return config;\n      }\n    }\n    return null;\n  }, [activeExecution?.current_column, configs]);\n\n  const currentColumnIcon = useMemo(\n    () => resolveExecutionColumnIcon(currentColumnConfig),\n    [currentColumnConfig],\n  );\n\n  const displayNodeIds = useMemo(\n    () => displayGraph.nodes.map((node) => node.id),\n    [displayGraph.nodes],\n  );\n\n  return {\n    activeExecution,\n    runtimeVisualState,\n    displayGraph,\n    displayNodeIds,\n    currentColumnIcon,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/hooks/use-recipe-studio-actions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useRecipeExecutions } from \"./use-recipe-executions\";\nimport { useRecipePersistence } from \"./use-recipe-persistence\";\nimport type {\n  RecipeExecutionKind,\n  RecipeExecutionRecord,\n} from \"../execution-types\";\nimport type { RecipeRunSettings } from \"../stores/recipe-executions\";\nimport type { RecipeSnapshot } from \"../utils/import\";\nimport type { RecipePayload, RecipePayloadResult } from \"../utils/payload/types\";\n\ntype SaveTone = \"success\" | \"error\";\n\ntype PersistRecipeFn = (input: {\n  id: string | null;\n  name: string;\n  payload: RecipePayload;\n}) => Promise<{\n  id: string;\n  updatedAt: number;\n}>;\n\ntype UseRecipeStudioActionsParams = {\n  recipeId: string;\n  initialRecipeName: string;\n  initialPayload: RecipePayload;\n  initialSavedAt: number;\n  payloadResult: RecipePayloadResult;\n  onPersistRecipe: PersistRecipeFn;\n  resetRecipe: () => void;\n  loadRecipe: (snapshot: RecipeSnapshot) => void;\n  getCurrentPayloadFromStore: () => RecipePayload;\n  onExecutionStart?: () => void;\n  onPreviewSuccess?: () => void;\n};\n\ntype UseRecipeStudioActionsResult = {\n  initialRecipeReady: boolean;\n  workflowName: string;\n  setWorkflowName: (value: string) => void;\n  saveLoading: boolean;\n  saveTone: SaveTone;\n  savedAtLabel: string;\n  copied: boolean;\n  importOpen: boolean;\n  setImportOpen: (open: boolean) => void;\n  runDialogOpen: boolean;\n  runDialogKind: RecipeExecutionKind;\n  setRunDialogKind: (kind: RecipeExecutionKind) => void;\n  setRunDialogOpen: (open: boolean) => void;\n  previewRows: number;\n  fullRows: number;\n  fullRunName: string;\n  setPreviewRows: (rows: number) => void;\n  setFullRows: (rows: number) => void;\n  setFullRunName: (name: string) => void;\n  runErrors: string[];\n  runSettings: RecipeRunSettings;\n  setRunSettings: (patch: Partial<RecipeRunSettings>) => void;\n  previewLoading: boolean;\n  fullLoading: boolean;\n  currentSignature: string;\n  executions: RecipeExecutionRecord[];\n  selectedExecutionId: string | null;\n  setSelectedExecutionId: (id: string) => void;\n  persistRecipe: () => Promise<void>;\n  openRunDialog: (kind: RecipeExecutionKind) => void;\n  runFromDialog: () => Promise<boolean>;\n  validateFromDialog: () => Promise<boolean>;\n  validateLoading: boolean;\n  validateResult: {\n    valid: boolean;\n    errors: string[];\n    rawDetail: string | null;\n  } | null;\n  runPreview: () => Promise<boolean>;\n  runFull: () => Promise<boolean>;\n  cancelExecution: (id: string) => Promise<void>;\n  loadExecutionDatasetPage: (id: string, page: number) => Promise<void>;\n  copyRecipe: () => Promise<void>;\n  importRecipe: (value: string) => string | null;\n};\n\nexport function useRecipeStudioActions({\n  recipeId,\n  initialRecipeName,\n  initialPayload,\n  initialSavedAt,\n  payloadResult,\n  onPersistRecipe,\n  resetRecipe,\n  loadRecipe,\n  getCurrentPayloadFromStore,\n  onExecutionStart,\n  onPreviewSuccess,\n}: UseRecipeStudioActionsParams): UseRecipeStudioActionsResult {\n  const persistence = useRecipePersistence({\n    recipeId,\n    initialRecipeName,\n    initialPayload,\n    initialSavedAt,\n    payloadResult,\n    onPersistRecipe,\n    resetRecipe,\n    loadRecipe,\n    getCurrentPayloadFromStore,\n  });\n\n  const executions = useRecipeExecutions({\n    recipeId,\n    currentSignature: persistence.currentSignature,\n    
payloadResult,\n    onExecutionStart,\n    onPreviewSuccess,\n  });\n\n  return {\n    initialRecipeReady: persistence.initialRecipeReady,\n    workflowName: persistence.workflowName,\n    setWorkflowName: persistence.setWorkflowName,\n    saveLoading: persistence.saveLoading,\n    saveTone: persistence.saveTone,\n    savedAtLabel: persistence.savedAtLabel,\n    copied: persistence.copied,\n    importOpen: persistence.importOpen,\n    setImportOpen: persistence.setImportOpen,\n    runDialogOpen: executions.runDialogOpen,\n    runDialogKind: executions.runDialogKind,\n    setRunDialogKind: executions.setRunDialogKind,\n    setRunDialogOpen: executions.setRunDialogOpen,\n    previewRows: executions.previewRows,\n    fullRows: executions.fullRows,\n    fullRunName: executions.fullRunName,\n    setPreviewRows: executions.setPreviewRows,\n    setFullRows: executions.setFullRows,\n    setFullRunName: executions.setFullRunName,\n    runErrors: executions.runErrors,\n    runSettings: executions.runSettings,\n    setRunSettings: executions.setRunSettings,\n    previewLoading: executions.previewLoading,\n    fullLoading: executions.fullLoading,\n    currentSignature: persistence.currentSignature,\n    executions: executions.executions,\n    selectedExecutionId: executions.selectedExecutionId,\n    setSelectedExecutionId: executions.setSelectedExecutionId,\n    persistRecipe: persistence.persistRecipe,\n    openRunDialog: executions.openRunDialog,\n    runFromDialog: executions.runFromDialog,\n    validateFromDialog: executions.validateFromDialog,\n    validateLoading: executions.validateLoading,\n    validateResult: executions.validateResult,\n    runPreview: executions.runPreview,\n    runFull: executions.runFull,\n    cancelExecution: executions.cancelExecution,\n    loadExecutionDatasetPage: executions.loadExecutionDatasetPage,\n    copyRecipe: persistence.copyRecipe,\n    importRecipe: persistence.importRecipe,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { RecipeStudioPage } from \"./recipe-studio-page\";\nexport type {\n  PersistRecipeInput,\n  PersistRecipeResult,\n  RecipeStudioPageProps,\n} from \"./recipe-studio-page\";\nexport type { RecipePayload } from \"./utils/payload/types\";\nexport { createEmptyRecipePayload } from \"./utils/payload/empty\";\n"
  },
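  {
    "path": "studio/frontend/src/features/recipe-studio/recipe-studio-page.example.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n// NOTE: illustrative sketch only. This file is not part of the feature itself; it shows one\n// possible way a host page could wire up the public RecipeStudioPage API (RecipeStudioPageProps,\n// PersistRecipeInput, PersistRecipeResult, createEmptyRecipePayload) re-exported from ./index.\n// The persistence handler and all names below are placeholder assumptions, not an existing backend.\n\nimport type { ReactElement } from \"react\";\nimport {\n  RecipeStudioPage,\n  type PersistRecipeInput,\n  type PersistRecipeResult,\n} from \"./recipe-studio-page\";\nimport { createEmptyRecipePayload } from \"./utils/payload/empty\";\n\n// Hypothetical persistence handler: a real host would send the payload to its own backend and\n// return the stored recipe id plus the server-side updatedAt timestamp.\nasync function persistRecipe(\n  input: PersistRecipeInput,\n): Promise<PersistRecipeResult> {\n  return {\n    id: input.id ?? \"example-recipe\",\n    updatedAt: Date.now(),\n  };\n}\n\n// Minimal host page: a brand-new recipe starts from an empty payload and an unsaved timestamp of 0.\nexport function ExampleRecipeStudioHost(): ReactElement {\n  return (\n    <RecipeStudioPage\n      recipeId=\"example-recipe\"\n      initialRecipeName=\"Untitled recipe\"\n      initialPayload={createEmptyRecipePayload()}\n      initialSavedAt={0}\n      onPersistRecipe={persistRecipe}\n    />\n  );\n}\n"
  },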
  {
    "path": "studio/frontend/src/features/recipe-studio/recipe-studio-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  DocumentAttachmentIcon,\n  PlusSignIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport {\n  Background,\n  BackgroundVariant,\n  type Edge,\n  type EdgeTypes,\n  type Node,\n  type NodeTypes,\n  Panel,\n  ReactFlow,\n  type ReactFlowInstance,\n} from \"@xyflow/react\";\nimport {\n  type ReactElement,\n  useCallback,\n  useEffect,\n  useMemo,\n  useRef,\n  useState,\n} from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport \"@xyflow/react/dist/style.css\";\nimport { Button } from \"@/components/ui/button\";\nimport { BlockSheet } from \"./components/block-sheet\";\nimport { LayoutControls } from \"./components/controls/layout-controls\";\nimport { RunValidateFloatingControls } from \"./components/controls/run-validate-floating-controls\";\nimport { ViewportControls } from \"./components/controls/viewport-controls\";\nimport { ExecutionsView } from \"./components/executions/executions-view\";\nimport { InternalsSync } from \"./components/graph/internals-sync\";\nimport {\n  RecipeGraphAuxNode,\n  type RecipeGraphAuxNodeData,\n} from \"./components/recipe-graph-aux-node\";\nimport { RecipeNode } from \"./components/recipe-graph-node\";\nimport { RecipeGraphSemanticEdge } from \"./components/recipe-graph-semantic-edge\";\nimport { RecipeStudioHeader } from \"./components/recipe-studio-header\";\nimport { DataEdge } from \"./components/rf-ui/data-edge\";\nimport { ExecutionProgressIsland } from \"./components/runtime/execution-progress-island\";\nimport { ConfigDialog } from \"./dialogs/config-dialog\";\nimport { ImportDialog } from \"./dialogs/import-dialog\";\nimport { RunDialog } from \"./dialogs/preview-dialog\";\nimport { ProcessorsDialog } from \"./dialogs/processors-dialog\";\nimport type {\n  RecipeExecutionRecord,\n  RecipeStudioView,\n} from \"./execution-types\";\nimport { isExecutionInProgress } from \"./executions/execution-helpers\";\nimport { useRecipeEditorGraph } from \"./hooks/use-recipe-editor-graph\";\nimport { useRecipeRuntimeVisuals } from \"./hooks/use-recipe-runtime-visuals\";\nimport { useRecipeStudioActions } from \"./hooks/use-recipe-studio-actions\";\nimport { useRecipeStudioStore } from \"./stores/recipe-studio\";\nimport type { RecipeNodeData } from \"./types\";\nimport { getGraphWarnings } from \"./utils/graph-warnings\";\nimport { getFitNodeIdsIgnoringNotes } from \"./utils/graph/fit-view\";\nimport { buildRecipePayload } from \"./utils/payload\";\nimport type { RecipePayload } from \"./utils/payload/types\";\nimport { buildDefaultSchemaTransform } from \"./utils/processors\";\nimport { buildDialogOptions } from \"./utils/recipe-studio-view\";\n\nconst NODE_TYPES: NodeTypes = { builder: RecipeNode, aux: RecipeGraphAuxNode };\nconst EDGE_TYPES: EdgeTypes = {\n  canvas: DataEdge,\n  semantic: RecipeGraphSemanticEdge,\n};\nconst COMPLETE_ISLAND_VISIBLE_MS = 7_000;\nconst TAB_SWITCH_FIT_DELAY_MS = 110;\nconst FIT_ANIMATION_MS = 340;\n\nexport type PersistRecipeInput = {\n  id: string | null;\n  name: string;\n  payload: RecipePayload;\n};\n\nexport type PersistRecipeResult = {\n  id: string;\n  updatedAt: number;\n};\n\nexport type RecipeStudioPageProps = {\n  recipeId: string;\n  initialRecipeName: string;\n  initialPayload: RecipePayload;\n  initialSavedAt: number;\n  onPersistRecipe: (input: PersistRecipeInput) => 
Promise<PersistRecipeResult>;\n};\n\nexport function RecipeStudioPage({\n  recipeId,\n  initialRecipeName,\n  initialPayload,\n  initialSavedAt,\n  onPersistRecipe,\n}: RecipeStudioPageProps): ReactElement {\n  const {\n    nodes,\n    edges,\n    auxNodePositions,\n    llmAuxVisibility,\n    configs,\n    processors,\n    sheetOpen,\n    sheetView,\n    activeConfigId,\n    dialogOpen,\n    layoutDirection,\n    fitViewTick,\n    onNodesChange,\n    onEdgesChange,\n    onConnect,\n    addSamplerNode,\n    addSeedNode,\n    addLlmNode,\n    addModelProviderNode,\n    addModelConfigNode,\n    addToolProfileNode,\n    addExpressionNode,\n    addValidatorNode,\n    addMarkdownNoteNode,\n    selectConfig,\n    openConfig,\n    updateConfig,\n    isValidConnection,\n    setSheetOpen,\n    setSheetView,\n    setProcessors,\n    setDialogOpen,\n    resetRecipe,\n    loadRecipe,\n    setLayoutDirection,\n    applyLayout,\n    setAuxNodePosition,\n    setExecutionLocked,\n  } = useRecipeStudioStore(\n    useShallow((state) => ({\n      nodes: state.nodes,\n      edges: state.edges,\n      auxNodePositions: state.auxNodePositions,\n      llmAuxVisibility: state.llmAuxVisibility,\n      configs: state.configs,\n      processors: state.processors,\n      sheetOpen: state.sheetOpen,\n      sheetView: state.sheetView,\n      activeConfigId: state.activeConfigId,\n      dialogOpen: state.dialogOpen,\n      layoutDirection: state.layoutDirection,\n      fitViewTick: state.fitViewTick,\n      onNodesChange: state.onNodesChange,\n      onEdgesChange: state.onEdgesChange,\n      onConnect: state.onConnect,\n      addSamplerNode: state.addSamplerNode,\n      addSeedNode: state.addSeedNode,\n      addLlmNode: state.addLlmNode,\n      addModelProviderNode: state.addModelProviderNode,\n      addModelConfigNode: state.addModelConfigNode,\n      addToolProfileNode: state.addToolProfileNode,\n      addExpressionNode: state.addExpressionNode,\n      addValidatorNode: state.addValidatorNode,\n      addMarkdownNoteNode: state.addMarkdownNoteNode,\n      selectConfig: state.selectConfig,\n      openConfig: state.openConfig,\n      updateConfig: state.updateConfig,\n      isValidConnection: state.isValidConnection,\n      setSheetOpen: state.setSheetOpen,\n      setSheetView: state.setSheetView,\n      setProcessors: state.setProcessors,\n      setDialogOpen: state.setDialogOpen,\n      resetRecipe: state.resetRecipe,\n      loadRecipe: state.loadRecipe,\n      setLayoutDirection: state.setLayoutDirection,\n      applyLayout: state.applyLayout,\n      setAuxNodePosition: state.setAuxNodePosition,\n      setExecutionLocked: state.setExecutionLocked,\n    })),\n  );\n  const [sheetContainer, setSheetContainer] = useState<HTMLDivElement | null>(\n    null,\n  );\n  const flowContainerRef = useRef<HTMLDivElement | null>(null);\n  const [activeView, setActiveView] = useState<RecipeStudioView>(\"editor\");\n  const [processorsOpen, setProcessorsOpen] = useState(false);\n  const [interactive, setInteractive] = useState(true);\n  const [runtimeIslandMinimized, setRuntimeIslandMinimized] = useState(false);\n  const [recentCompletedExecution, setRecentCompletedExecution] =\n    useState<RecipeExecutionRecord | null>(null);\n  const [reactFlowInstance, setReactFlowInstance] = useState<ReactFlowInstance<\n    Node<RecipeNodeData | RecipeGraphAuxNodeData>,\n    Edge\n  > | null>(null);\n  const lastProcessedFitTickRef = useRef(0);\n  const previousActiveViewRef = useRef<RecipeStudioView>(\"editor\");\n  const 
previousActiveExecutionIdRef = useRef<string | null>(null);\n  const pendingEditorTabFitRef = useRef(false);\n  const forceEditorTabFitRef = useRef(false);\n  const viewportMovedSinceAutoFitRef = useRef(true);\n  const {\n    handleNodeClick,\n    handleNodeDoubleClick,\n    handleNodesChange,\n    handleEdgesChange,\n    handleDragOver,\n    handleDrop,\n    handleAddSamplerFromSheet,\n    handleAddSeedFromSheet,\n    handleAddLlmFromSheet,\n    handleAddModelProviderFromSheet,\n    handleAddModelConfigFromSheet,\n    handleAddToolProfileFromSheet,\n    handleAddExpressionFromSheet,\n    handleAddValidatorFromSheet,\n    handleAddMarkdownNoteFromSheet,\n  } = useRecipeEditorGraph({\n    nodes,\n    edges,\n    configs,\n    reactFlowInstance,\n    flowContainerRef,\n    selectConfig,\n    openConfig,\n    onNodesChange,\n    onEdgesChange,\n    setAuxNodePosition,\n    addSamplerNode,\n    addSeedNode,\n    addLlmNode,\n    addModelProviderNode,\n    addModelConfigNode,\n    addToolProfileNode,\n    addExpressionNode,\n    addValidatorNode,\n    addMarkdownNoteNode,\n  });\n\n  const configList = useMemo(() => Object.values(configs), [configs]);\n  const config = activeConfigId ? configs[activeConfigId] : null;\n  const dialogOptions = useMemo(\n    () => buildDialogOptions(configList),\n    [configList],\n  );\n\n  const handleToggleDirection = useCallback(() => {\n    setLayoutDirection(layoutDirection === \"LR\" ? \"TB\" : \"LR\");\n  }, [layoutDirection, setLayoutDirection]);\n\n  const payloadResult = useMemo(\n    () =>\n      buildRecipePayload(\n        configs,\n        nodes,\n        edges,\n        processors,\n        layoutDirection,\n        auxNodePositions,\n      ),\n    [auxNodePositions, configs, edges, layoutDirection, nodes, processors],\n  );\n  const getCurrentPayloadFromStore = useCallback((): RecipePayload => {\n    const state = useRecipeStudioStore.getState();\n    return buildRecipePayload(\n      state.configs,\n      state.nodes,\n      state.edges,\n      state.processors,\n      state.layoutDirection,\n      state.auxNodePositions,\n    ).payload;\n  }, []);\n  const {\n    initialRecipeReady,\n    workflowName,\n    setWorkflowName,\n    saveLoading,\n    saveTone,\n    savedAtLabel,\n    copied,\n    importOpen,\n    setImportOpen,\n    runDialogOpen,\n    runDialogKind,\n    setRunDialogKind,\n    setRunDialogOpen,\n    previewRows,\n    fullRows,\n    fullRunName,\n    setPreviewRows,\n    setFullRows,\n    setFullRunName,\n    runErrors,\n    runSettings,\n    setRunSettings,\n    previewLoading,\n    fullLoading,\n    currentSignature,\n    executions,\n    selectedExecutionId,\n    setSelectedExecutionId,\n    persistRecipe,\n    openRunDialog,\n    runFromDialog,\n    validateFromDialog,\n    validateLoading,\n    validateResult,\n    cancelExecution,\n    loadExecutionDatasetPage,\n    copyRecipe,\n    importRecipe,\n  } = useRecipeStudioActions({\n    recipeId,\n    initialRecipeName,\n    initialPayload,\n    initialSavedAt,\n    payloadResult,\n    onPersistRecipe,\n    resetRecipe,\n    loadRecipe,\n    getCurrentPayloadFromStore,\n  });\n  const {\n    activeExecution,\n    runtimeVisualState,\n    displayGraph,\n    displayNodeIds,\n    currentColumnIcon,\n  } = useRecipeRuntimeVisuals({\n    executions,\n    configs,\n    nodes,\n    edges,\n    layoutDirection,\n    auxNodePositions,\n    llmAuxVisibility,\n  });\n  const executionLocked = runtimeVisualState.executionLocked;\n  const canvasInteractive = interactive && !executionLocked;\n  
const runBusy = previewLoading || fullLoading || executionLocked;\n  const islandExecution = activeExecution ?? recentCompletedExecution;\n\n  const toggleInteractive = useCallback(() => {\n    if (executionLocked) {\n      return;\n    }\n    setInteractive((value) => !value);\n  }, [executionLocked]);\n\n  useEffect(() => {\n    setExecutionLocked(executionLocked);\n  }, [executionLocked, setExecutionLocked]);\n\n  useEffect(() => {\n    const activeExecutionId = activeExecution?.id ?? null;\n    if (\n      activeExecutionId &&\n      activeExecutionId !== previousActiveExecutionIdRef.current\n    ) {\n      setRuntimeIslandMinimized(false);\n    }\n    previousActiveExecutionIdRef.current = activeExecutionId;\n  }, [activeExecution?.id]);\n\n  useEffect(() => {\n    if (activeExecution) {\n      setRecentCompletedExecution(null);\n      return;\n    }\n    const latestCompleted = executions.find(\n      (execution) =>\n        execution.status === \"completed\" &&\n        typeof execution.finishedAt === \"number\",\n    );\n    if (!latestCompleted || typeof latestCompleted.finishedAt !== \"number\") {\n      setRecentCompletedExecution(null);\n      return;\n    }\n    const elapsedMs = Date.now() - latestCompleted.finishedAt;\n    if (elapsedMs >= COMPLETE_ISLAND_VISIBLE_MS) {\n      setRecentCompletedExecution(null);\n      return;\n    }\n    setRecentCompletedExecution(latestCompleted);\n    const hideTimer = window.setTimeout(() => {\n      setRecentCompletedExecution(null);\n      setActiveView((currentView) =>\n        currentView === \"editor\" ? \"executions\" : currentView,\n      );\n    }, COMPLETE_ISLAND_VISIBLE_MS - elapsedMs);\n    return () => {\n      window.clearTimeout(hideTimer);\n    };\n  }, [activeExecution, executions]);\n\n  const openProcessorsFromSheet = useCallback(() => {\n    if (\n      !processors.some(\n        (processor) => processor.processor_type === \"schema_transform\",\n      )\n    ) {\n      setProcessors([...processors, buildDefaultSchemaTransform()]);\n    }\n    setProcessorsOpen(true);\n  }, [processors, setProcessors]);\n\n  const openRootBlockSheet = useCallback(() => {\n    setSheetView(\"root\");\n    setSheetOpen(true);\n  }, [setSheetOpen, setSheetView]);\n  const openSourceBlockSheet = useCallback(() => {\n    setSheetView(\"seed\");\n    setSheetOpen(true);\n  }, [setSheetOpen, setSheetView]);\n  const runDialogRows = runDialogKind === \"preview\" ? previewRows : fullRows;\n  const runDialogLoading =\n    runDialogKind === \"preview\" ? 
previewLoading : fullLoading;\n\n  const scheduleFitView = useCallback(\n    ({ delayMs = 0 }: { delayMs?: number } = {}) => {\n      if (!reactFlowInstance) {\n        return () => {};\n      }\n\n      let timeoutId = 0;\n      let frameId = 0;\n      let retryFrameId = 0;\n\n      const fitWithCurrentNodes = () => {\n        const targetNodes = getFitNodeIdsIgnoringNotes(\n          reactFlowInstance.getNodes(),\n        );\n        if (targetNodes.length === 0) {\n          return false;\n        }\n        viewportMovedSinceAutoFitRef.current = false;\n        reactFlowInstance.fitView({\n          duration: FIT_ANIMATION_MS,\n          nodes: targetNodes,\n        });\n        return true;\n      };\n\n      const runFit = () => {\n        if (fitWithCurrentNodes()) {\n          return;\n        }\n\n        retryFrameId = window.requestAnimationFrame(() => {\n          fitWithCurrentNodes();\n        });\n      };\n\n      const start = () => {\n        frameId = window.requestAnimationFrame(runFit);\n      };\n\n      if (delayMs > 0) {\n        timeoutId = window.setTimeout(start, delayMs);\n      } else {\n        start();\n      }\n\n      return () => {\n        if (timeoutId) {\n          window.clearTimeout(timeoutId);\n        }\n        if (frameId) {\n          window.cancelAnimationFrame(frameId);\n        }\n        if (retryFrameId) {\n          window.cancelAnimationFrame(retryFrameId);\n        }\n      };\n    },\n    [reactFlowInstance],\n  );\n\n  useEffect(() => {\n    if (\n      previousActiveViewRef.current !== activeView &&\n      activeView === \"editor\"\n    ) {\n      pendingEditorTabFitRef.current = true;\n      forceEditorTabFitRef.current =\n        previousActiveViewRef.current === \"executions\";\n    }\n    previousActiveViewRef.current = activeView;\n  }, [activeView]);\n\n  useEffect(() => {\n    if (activeView !== \"editor\" && reactFlowInstance) {\n      setReactFlowInstance(null);\n    }\n  }, [activeView, reactFlowInstance]);\n\n  useEffect(() => {\n    if (\n      !reactFlowInstance ||\n      activeView !== \"editor\" ||\n      !pendingEditorTabFitRef.current\n    ) {\n      return;\n    }\n    pendingEditorTabFitRef.current = false;\n    const forceFit = forceEditorTabFitRef.current;\n    forceEditorTabFitRef.current = false;\n    if (!(forceFit || viewportMovedSinceAutoFitRef.current)) {\n      return;\n    }\n    return scheduleFitView({ delayMs: TAB_SWITCH_FIT_DELAY_MS });\n  }, [activeView, reactFlowInstance, scheduleFitView]);\n\n  useEffect(() => {\n    if (!reactFlowInstance || fitViewTick === 0 || activeView !== \"editor\") {\n      return;\n    }\n    if (lastProcessedFitTickRef.current === fitViewTick) {\n      return;\n    }\n    lastProcessedFitTickRef.current = fitViewTick;\n    return scheduleFitView();\n  }, [activeView, fitViewTick, reactFlowInstance, scheduleFitView]);\n\n  let editorContent: ReactElement;\n  if (initialRecipeReady) {\n    editorContent = (\n      <ReactFlow<Node<RecipeNodeData | RecipeGraphAuxNodeData>, Edge>\n        onInit={setReactFlowInstance}\n        onDragOver={handleDragOver}\n        onDrop={handleDrop}\n        nodes={displayGraph.nodes}\n        edges={displayGraph.edges}\n        proOptions={{ hideAttribution: true }}\n        nodeTypes={NODE_TYPES}\n        edgeTypes={EDGE_TYPES}\n        defaultEdgeOptions={{\n          type: \"canvas\",\n          data: { path: \"smoothstep\" },\n        }}\n        onNodesChange={handleNodesChange}\n        onEdgesChange={handleEdgesChange}\n        
onConnect={onConnect}\n        onNodeClick={handleNodeClick}\n        onNodeDoubleClick={handleNodeDoubleClick}\n        isValidConnection={isValidConnection}\n        onMoveEnd={(event) => {\n          if (event) {\n            viewportMovedSinceAutoFitRef.current = true;\n          }\n        }}\n        nodesDraggable={canvasInteractive}\n        nodesConnectable={canvasInteractive}\n        elementsSelectable={canvasInteractive}\n        fitView={false}\n        className=\"h-full w-full rounded-t-none\"\n      >\n        <LayoutControls\n          direction={layoutDirection}\n          onLayout={applyLayout}\n          onToggleDirection={handleToggleDirection}\n        />\n        <InternalsSync nodeIds={displayNodeIds} />\n        <Background\n          variant={BackgroundVariant.Dots}\n          gap={18}\n          size={1}\n          color=\"#d4d4d8\"\n        />\n        {nodes.length === 0 && (\n          <div className=\"pointer-events-none absolute inset-0 z-10 flex items-center justify-center p-4\">\n            <div className=\"pointer-events-auto w-full max-w-md rounded-2xl border border-dashed border-border/70 bg-background/80 px-6 py-6 text-center shadow-border backdrop-blur-[1px]\">\n              <div className=\"mx-auto flex size-12 items-center justify-center corner-squircle rounded-xl border border-border/70 bg-muted/40\">\n                <HugeiconsIcon\n                  icon={DocumentAttachmentIcon}\n                  className=\"size-6 text-muted-foreground\"\n                />\n              </div>\n              <div className=\"mt-4 space-y-2\">\n                <p className=\"text-[11px] font-semibold uppercase tracking-wide text-primary\">\n                  Best place to start\n                </p>\n                <p className=\"text-sm font-semibold text-foreground\">\n                  Start with source data\n                </p>\n                <p className=\"text-xs text-muted-foreground\">\n                  Most synthetic-data recipes begin with a document, dataset, or\n                  file before adding generation and checks.\n                </p>\n              </div>\n              <div className=\"mt-5 flex flex-col justify-center gap-2 sm:flex-row\">\n                <Button\n                  type=\"button\"\n                  className=\"corner-squircle\"\n                  onClick={openSourceBlockSheet}\n                >\n                  <HugeiconsIcon\n                    icon={DocumentAttachmentIcon}\n                    className=\"size-4\"\n                  />\n                  Start with source data\n                </Button>\n                <Button\n                  type=\"button\"\n                  variant=\"outline\"\n                  className=\"corner-squircle\"\n                  onClick={openRootBlockSheet}\n                >\n                  <HugeiconsIcon icon={PlusSignIcon} className=\"size-4\" />\n                  Browse all steps\n                </Button>\n              </div>\n            </div>\n          </div>\n        )}\n        <Panel position=\"top-right\" className=\"m-3\">\n          <BlockSheet\n            container={sheetContainer}\n            sheetView={sheetView}\n            onViewChange={setSheetView}\n            open={sheetOpen}\n            onOpenChange={setSheetOpen}\n            onAddSampler={handleAddSamplerFromSheet}\n            onAddSeed={handleAddSeedFromSheet}\n            onAddLlm={handleAddLlmFromSheet}\n            onAddModelProvider={handleAddModelProviderFromSheet}\n           
 onAddModelConfig={handleAddModelConfigFromSheet}\n            onAddToolProfile={handleAddToolProfileFromSheet}\n            onAddExpression={handleAddExpressionFromSheet}\n            onAddValidator={handleAddValidatorFromSheet}\n            onAddMarkdownNote={handleAddMarkdownNoteFromSheet}\n            onOpenProcessors={openProcessorsFromSheet}\n            copied={copied}\n            onCopy={copyRecipe}\n            onImport={() => setImportOpen(true)}\n          />\n        </Panel>\n        <ViewportControls\n          interactive={canvasInteractive}\n          lockDisabled={executionLocked}\n          onToggleInteractive={toggleInteractive}\n        />\n        {islandExecution &&\n          (isExecutionInProgress(islandExecution.status) ||\n            islandExecution.status === \"completed\") && (\n            <Panel position=\"top-center\" className=\"!m-0\">\n              <ExecutionProgressIsland\n                execution={islandExecution}\n                currentColumnIcon={currentColumnIcon}\n                minimized={runtimeIslandMinimized}\n                onMinimizedChange={setRuntimeIslandMinimized}\n                onViewExecutions={() => setActiveView(\"executions\")}\n              />\n            </Panel>\n          )}\n        <RunValidateFloatingControls\n          runBusy={runBusy}\n          runDialogKind={runDialogKind}\n          validateLoading={validateLoading}\n          executionLocked={executionLocked}\n          onOpenRunDialog={openRunDialog}\n          onValidate={() => {\n            openRunDialog(runDialogKind);\n            void validateFromDialog();\n          }}\n        />\n      </ReactFlow>\n    );\n  } else {\n    editorContent = (\n      <div className=\"flex h-full items-center justify-center px-6\">\n        <div className=\"rounded-2xl border border-border/70 bg-background/80 px-5 py-4 text-center shadow-border backdrop-blur-[1px]\">\n          <p className=\"text-sm font-medium text-foreground\">Loading recipe</p>\n          <p className=\"mt-1 text-xs text-muted-foreground\">\n            Restoring the studio graph and saved settings.\n          </p>\n        </div>\n      </div>\n    );\n  }\n\n  return (\n    <div className=\"min-h-screen bg-background\">\n      <main className=\"w-full px-6 py-8\">\n        <div\n          className=\"relative w-full overflow-hidden rounded-2xl corner-squircle border\"\n          ref={setSheetContainer}\n        >\n          <RecipeStudioHeader\n            activeView={activeView}\n            saveLoading={saveLoading}\n            saveTone={saveTone}\n            savedAtLabel={savedAtLabel}\n            workflowName={workflowName}\n            warnings={getGraphWarnings(configs, edges)}\n            onWorkflowNameChange={setWorkflowName}\n            onViewChange={setActiveView}\n            onSaveRecipe={() => {\n              void persistRecipe();\n            }}\n          />\n          <div\n            className=\"h-[75vh] w-full rounded-t-none\"\n            ref={flowContainerRef}\n          >\n            {activeView === \"editor\" ? 
(\n              editorContent\n            ) : (\n              <ExecutionsView\n                executions={executions}\n                selectedExecutionId={selectedExecutionId}\n                currentSignature={currentSignature}\n                onSelectExecution={setSelectedExecutionId}\n                onCancelExecution={(executionId) => {\n                  void cancelExecution(executionId);\n                }}\n                onLoadDatasetPage={(executionId, page) => {\n                  void loadExecutionDatasetPage(executionId, page);\n                }}\n              />\n            )}\n          </div>\n        </div>\n      </main>\n      <ConfigDialog\n        open={dialogOpen}\n        onOpenChange={setDialogOpen}\n        config={config}\n        readOnly={executionLocked}\n        categoryOptions={dialogOptions.categoryOptions}\n        modelConfigAliases={dialogOptions.modelConfigAliases}\n        modelProviderOptions={dialogOptions.modelProviderOptions}\n        toolProfileAliases={dialogOptions.toolProfileAliases}\n        datetimeOptions={dialogOptions.datetimeOptions}\n        onUpdate={updateConfig}\n        container={sheetContainer}\n      />\n      <ImportDialog\n        open={importOpen}\n        onOpenChange={setImportOpen}\n        onImport={importRecipe}\n        container={sheetContainer}\n      />\n      <ProcessorsDialog\n        open={processorsOpen}\n        onOpenChange={setProcessorsOpen}\n        processors={processors}\n        onProcessorsChange={setProcessors}\n        container={sheetContainer}\n      />\n      <RunDialog\n        open={runDialogOpen}\n        onOpenChange={setRunDialogOpen}\n        kind={runDialogKind}\n        onKindChange={setRunDialogKind}\n        rows={runDialogRows}\n        fullRunName={fullRunName}\n        onFullRunNameChange={setFullRunName}\n        onRowsChange={(rows) => {\n          if (runDialogKind === \"preview\") {\n            setPreviewRows(rows);\n            return;\n          }\n          setFullRows(rows);\n        }}\n        settings={runSettings}\n        onSettingsChange={setRunSettings}\n        loading={runDialogLoading}\n        validateLoading={validateLoading}\n        validateResult={validateResult}\n        errors={runErrors}\n        onValidate={() => {\n          void validateFromDialog();\n        }}\n        onRun={() => {\n          void runFromDialog();\n        }}\n        container={sheetContainer}\n      />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/helpers/edge-sync.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type Edge, addEdge } from \"@xyflow/react\";\nimport type {\n  LayoutDirection,\n  ModelConfig,\n  NodeConfig,\n  SamplerConfig,\n  ValidatorConfig,\n} from \"../../types\";\nimport { applyRecipeConnection } from \"../../utils/graph\";\nimport { isCategoryConfig, isSubcategoryConfig } from \"../../utils\";\nimport { HANDLE_IDS } from \"../../utils/handles\";\n\nfunction findNodeIdByName(\n  configs: Record<string, NodeConfig>,\n  name: string,\n): string | null {\n  const entry = Object.entries(configs).find(\n    ([, config]) => config.name === name,\n  );\n  return entry ? entry[0] : null;\n}\n\nfunction addRecipeEdge(edges: Edge[], source: string, target: string): Edge[] {\n  return addEdge(\n    {\n      source,\n      target,\n      sourceHandle: HANDLE_IDS.dataOut,\n      targetHandle: HANDLE_IDS.dataIn,\n      type: \"canvas\",\n    },\n    edges,\n  );\n}\n\nfunction addValidatorSemanticEdge(\n  edges: Edge[],\n  source: string,\n  target: string,\n): Edge[] {\n  return addEdge(\n    {\n      source,\n      target,\n      sourceHandle: HANDLE_IDS.dataOut,\n      targetHandle: HANDLE_IDS.dataIn,\n      type: \"semantic\",\n    },\n    edges,\n  );\n}\n\nfunction removeTargetEdges(edges: Edge[], targetId: string): Edge[] {\n  return edges.filter((edge) => edge.target !== targetId);\n}\n\nfunction removeTargetEdgesBySource(\n  edges: Edge[],\n  configs: Record<string, NodeConfig>,\n  targetId: string,\n  shouldRemove: (source: NodeConfig | undefined) => boolean,\n): Edge[] {\n  return edges.filter((edge) => {\n    if (edge.target !== targetId) {\n      return true;\n    }\n    return !shouldRemove(configs[edge.source]);\n  });\n}\n\nexport function syncEdgesForConfigPatch(\n  current: NodeConfig,\n  patch: Partial<NodeConfig>,\n  configs: Record<string, NodeConfig>,\n  edges: Edge[],\n  layoutDirection: LayoutDirection,\n): Edge[] {\n  let nextEdges = edges;\n\n  const hasParentPatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"subcategory_parent\",\n  );\n  if (isSubcategoryConfig(current) && hasParentPatch) {\n    const nextParent = (patch as Partial<SamplerConfig>).subcategory_parent ?? \"\";\n    const parentId = nextParent ? findNodeIdByName(configs, nextParent) : null;\n    nextEdges = removeTargetEdges(nextEdges, current.id);\n    if (parentId) {\n      nextEdges = addRecipeEdge(nextEdges, parentId, current.id);\n    }\n  }\n\n  const hasProviderPatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"provider\",\n  );\n  if (current.kind === \"model_config\" && hasProviderPatch) {\n    const nextProvider = (patch as Partial<ModelConfig>).provider ?? 
\"\";\n    if (nextProvider.trim() === current.provider.trim()) {\n      return nextEdges;\n    }\n    nextEdges = removeTargetEdgesBySource(\n      nextEdges,\n      configs,\n      current.id,\n      (source) => Boolean(source && source.kind === \"model_provider\"),\n    );\n    if (nextProvider) {\n      const providerId = findNodeIdByName(configs, nextProvider);\n      if (providerId) {\n        const result = applyRecipeConnection(\n          {\n            source: providerId,\n            sourceHandle: HANDLE_IDS.semanticOut,\n            target: current.id,\n            targetHandle: HANDLE_IDS.semanticIn,\n          },\n          configs,\n          nextEdges,\n          layoutDirection,\n        );\n        nextEdges = result.edges;\n      }\n    }\n  }\n\n  const hasReferencePatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"reference_column_name\",\n  );\n  if (\n    current.kind === \"sampler\" &&\n    current.sampler_type === \"timedelta\" &&\n    hasReferencePatch\n  ) {\n    const nextReference =\n      (patch as Partial<SamplerConfig>).reference_column_name ?? \"\";\n    nextEdges = removeTargetEdgesBySource(\n      nextEdges,\n      configs,\n      current.id,\n      (source) =>\n        Boolean(\n          source &&\n            source.kind === \"sampler\" &&\n            source.sampler_type === \"datetime\",\n        ),\n    );\n    if (nextReference) {\n      const referenceId = findNodeIdByName(configs, nextReference);\n      const source = referenceId ? configs[referenceId] : null;\n      if (\n        referenceId &&\n        source &&\n        source.kind === \"sampler\" &&\n        source.sampler_type === \"datetime\"\n      ) {\n        nextEdges = addRecipeEdge(nextEdges, referenceId, current.id);\n      }\n    }\n  }\n\n  const hasModelAliasPatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"model_alias\",\n  );\n  if (current.kind === \"llm\" && hasModelAliasPatch) {\n    const nextAlias =\n      (patch as Partial<NodeConfig> & { model_alias?: string }).model_alias ?? \"\";\n    if (nextAlias.trim() === current.model_alias.trim()) {\n      return nextEdges;\n    }\n    nextEdges = removeTargetEdgesBySource(\n      nextEdges,\n      configs,\n      current.id,\n      (source) => Boolean(source && source.kind === \"model_config\"),\n    );\n    if (nextAlias) {\n      const modelConfigId = findNodeIdByName(configs, nextAlias);\n      if (modelConfigId) {\n        const result = applyRecipeConnection(\n          {\n            source: modelConfigId,\n            sourceHandle: HANDLE_IDS.semanticOut,\n            target: current.id,\n            targetHandle: HANDLE_IDS.semanticIn,\n          },\n          configs,\n          nextEdges,\n          layoutDirection,\n        );\n        nextEdges = result.edges;\n      }\n    }\n  }\n\n  const hasToolAliasPatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"tool_alias\",\n  );\n  if (current.kind === \"llm\" && hasToolAliasPatch) {\n    const nextAlias =\n      (patch as Partial<NodeConfig> & { tool_alias?: string }).tool_alias ?? \"\";\n    if (nextAlias.trim() === (current.tool_alias ?? 
\"\").trim()) {\n      return nextEdges;\n    }\n    nextEdges = removeTargetEdgesBySource(\n      nextEdges,\n      configs,\n      current.id,\n      (source) => Boolean(source && source.kind === \"tool_config\"),\n    );\n    if (nextAlias) {\n      const toolConfigId = findNodeIdByName(configs, nextAlias);\n      if (toolConfigId) {\n        const result = applyRecipeConnection(\n          {\n            source: toolConfigId,\n            sourceHandle: HANDLE_IDS.semanticOut,\n            target: current.id,\n            targetHandle: HANDLE_IDS.semanticIn,\n          },\n          configs,\n          nextEdges,\n          layoutDirection,\n        );\n        nextEdges = result.edges;\n      }\n    }\n  }\n\n  const hasValidatorTargetsPatch = Object.prototype.hasOwnProperty.call(\n    patch,\n    \"target_columns\",\n  );\n  if (current.kind === \"validator\" && hasValidatorTargetsPatch) {\n    const nextTargets =\n      ((patch as Partial<ValidatorConfig>).target_columns ?? [])\n        .map((value) => value.trim())\n        .filter(Boolean);\n    nextEdges = nextEdges.filter((edge) => {\n      if (edge.source !== current.id && edge.target !== current.id) {\n        return true;\n      }\n      const otherId = edge.source === current.id ? edge.target : edge.source;\n      const other = configs[otherId];\n      return !(\n        other &&\n        other.kind === \"llm\" &&\n        other.llm_type === \"code\"\n      );\n    });\n    const nextTargetName = nextTargets[0];\n    if (nextTargetName) {\n      const targetId = findNodeIdByName(configs, nextTargetName);\n      const target = targetId ? configs[targetId] : null;\n      if (\n        targetId &&\n        target &&\n        target.kind === \"llm\" &&\n        target.llm_type === \"code\"\n      ) {\n        nextEdges = addValidatorSemanticEdge(nextEdges, targetId, current.id);\n      }\n    }\n  }\n\n  return nextEdges;\n}\n\nexport function syncSubcategoryConfigsForCategoryUpdate(\n  current: NodeConfig,\n  next: NodeConfig,\n  configs: Record<string, NodeConfig>,\n  oldName: string,\n  newName: string,\n  nameChanged: boolean,\n): Record<string, NodeConfig> {\n  if (!isCategoryConfig(current)) {\n    return configs;\n  }\n  const nextCategory = isCategoryConfig(next) ? next : current;\n  const oldValues = current.values ?? [];\n  const newValues = nextCategory.values ?? [];\n  const valuesChanged =\n    oldValues.length !== newValues.length ||\n    oldValues.some((value, index) => value !== newValues[index]);\n\n  let nextConfigs = configs;\n  for (const config of Object.values(configs)) {\n    if (!isSubcategoryConfig(config)) {\n      continue;\n    }\n    if (config.subcategory_parent !== oldName) {\n      continue;\n    }\n    const mapping = config.subcategory_mapping ?? {};\n    const nextMapping: Record<string, string[]> = {};\n    for (const value of newValues) {\n      nextMapping[value] = mapping[value] ?? [];\n    }\n    const updated: NodeConfig = {\n      ...config,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: nameChanged ? newName : config.subcategory_parent,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_mapping: valuesChanged ? nextMapping : mapping,\n    };\n    nextConfigs = { ...nextConfigs, [config.id]: updated };\n  }\n  return nextConfigs;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/helpers/model-infra-layout.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge, XYPosition } from \"@xyflow/react\";\nimport { DEFAULT_NODE_HEIGHT, DEFAULT_NODE_WIDTH } from \"../../constants\";\nimport type { LayoutDirection, NodeConfig, RecipeNode } from \"../../types\";\nimport { HANDLE_IDS, normalizeRecipeHandleId } from \"../../utils/handles\";\nimport { readNodeHeight, readNodeWidth } from \"../../utils/rf-node-dimensions\";\n\ntype Rect = {\n  x: number;\n  y: number;\n  width: number;\n  height: number;\n};\n\ntype Bounds = {\n  minX: number;\n  maxX: number;\n  minY: number;\n  maxY: number;\n};\n\nfunction toRect(node: RecipeNode): Rect {\n  return {\n    x: node.position.x,\n    y: node.position.y,\n    width: readNodeWidth(node) ?? DEFAULT_NODE_WIDTH,\n    height: readNodeHeight(node) ?? DEFAULT_NODE_HEIGHT,\n  };\n}\n\nfunction intersects(a: Rect, b: Rect, pad = 18): boolean {\n  return !(\n    a.x + a.width + pad <= b.x ||\n    b.x + b.width + pad <= a.x ||\n    a.y + a.height + pad <= b.y ||\n    b.y + b.height + pad <= a.y\n  );\n}\n\nfunction findNonOverlappingPosition(\n  preferred: XYPosition,\n  width: number,\n  height: number,\n  occupied: Rect[],\n): XYPosition {\n  const step = 24;\n  for (let ring = 0; ring <= 16; ring += 1) {\n    for (let dx = -ring; dx <= ring; dx += 1) {\n      for (let dy = -ring; dy <= ring; dy += 1) {\n        if (ring > 0 && Math.max(Math.abs(dx), Math.abs(dy)) !== ring) {\n          continue;\n        }\n        const candidate = {\n          x: preferred.x + dx * step,\n          y: preferred.y + dy * step,\n        };\n        const rect = {\n          x: candidate.x,\n          y: candidate.y,\n          width,\n          height,\n        };\n        if (!occupied.some((item) => intersects(rect, item))) {\n          return candidate;\n        }\n      }\n    }\n  }\n  return preferred;\n}\n\nfunction isProviderToConfigEdge(\n  edge: Edge,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  const source = configs[edge.source];\n  const target = configs[edge.target];\n  return source?.kind === \"model_provider\" && target?.kind === \"model_config\";\n}\n\nfunction isConfigToLlmEdge(\n  edge: Edge,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  const source = configs[edge.source];\n  const target = configs[edge.target];\n  return source?.kind === \"model_config\" && target?.kind === \"llm\";\n}\n\nfunction isToolConfigToLlmEdge(\n  edge: Edge,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  const source = configs[edge.source];\n  const target = configs[edge.target];\n  return source?.kind === \"tool_config\" && target?.kind === \"llm\";\n}\n\nfunction usageKey(nodeId: string, handleId: string): string {\n  return `${nodeId}::${handleId}`;\n}\n\nfunction incrementUsage(\n  map: Map<string, number>,\n  nodeId: string,\n  handleId: string,\n): void {\n  const key = usageKey(nodeId, handleId);\n  map.set(key, (map.get(key) ?? 0) + 1);\n}\n\nfunction decrementUsage(\n  map: Map<string, number>,\n  nodeId: string,\n  handleId: string,\n): void {\n  const key = usageKey(nodeId, handleId);\n  map.set(key, Math.max(0, (map.get(key) ?? 0) - 1));\n}\n\nfunction getUsage(\n  map: Map<string, number>,\n  nodeId: string,\n  handleId: string,\n): number {\n  return map.get(usageKey(nodeId, handleId)) ?? 
0;\n}\n\nfunction pickHandleByUsage(\n  candidates: string[],\n  nodeId: string,\n  usageMap: Map<string, number>,\n): string {\n  const free = candidates.filter(\n    (handleId) => getUsage(usageMap, nodeId, handleId) === 0,\n  );\n  if (free.length > 0) {\n    return free[0];\n  }\n  let bestHandle = candidates[0];\n  let bestCount = Number.POSITIVE_INFINITY;\n  for (const handleId of candidates) {\n    const count = getUsage(usageMap, nodeId, handleId);\n    if (count < bestCount) {\n      bestHandle = handleId;\n      bestCount = count;\n    }\n  }\n  return bestHandle;\n}\n\nfunction applyEdgeWithHandles(\n  edge: Edge,\n  sourceHandle: string,\n  targetHandle: string,\n  sourceUsage: Map<string, number>,\n  targetUsage: Map<string, number>,\n): Edge {\n  incrementUsage(sourceUsage, edge.source, sourceHandle);\n  incrementUsage(targetUsage, edge.target, targetHandle);\n  return { ...edge, sourceHandle, targetHandle, type: \"semantic\" };\n}\n\nfunction getNodeCenter(node: RecipeNode): { x: number; y: number } {\n  const width = readNodeWidth(node) ?? DEFAULT_NODE_WIDTH;\n  const height = readNodeHeight(node) ?? DEFAULT_NODE_HEIGHT;\n  return {\n    x: node.position.x + width / 2,\n    y: node.position.y + height / 2,\n  };\n}\n\nfunction collectBounds(\n  ids: string[],\n  nodesById: Map<string, RecipeNode>,\n): Bounds | null {\n  const rects = ids\n    .map((id) => nodesById.get(id))\n    .flatMap((node) => (node ? [toRect(node)] : []));\n  if (rects.length === 0) {\n    return null;\n  }\n  return rects.reduce<Bounds>(\n    (acc, rect) => ({\n      minX: Math.min(acc.minX, rect.x),\n      maxX: Math.max(acc.maxX, rect.x + rect.width),\n      minY: Math.min(acc.minY, rect.y),\n      maxY: Math.max(acc.maxY, rect.y + rect.height),\n    }),\n    {\n      minX: rects[0].x,\n      maxX: rects[0].x + rects[0].width,\n      minY: rects[0].y,\n      maxY: rects[0].y + rects[0].height,\n    },\n  );\n}\n\nfunction sortPreferredLlmTargetHandles(\n  direction: LayoutDirection,\n  sourceNode: RecipeNode | undefined,\n  targetNode: RecipeNode | undefined,\n): string[] {\n  const sourceCenter = sourceNode ? getNodeCenter(sourceNode) : { x: 0, y: 0 };\n  const targetCenter = targetNode ? getNodeCenter(targetNode) : { x: 0, y: 0 };\n\n  if (direction === \"TB\") {\n    const horizontalFirst =\n      sourceCenter.x <= targetCenter.x\n        ? [HANDLE_IDS.dataIn, HANDLE_IDS.dataInRight]\n        : [HANDLE_IDS.dataInRight, HANDLE_IDS.dataIn];\n    return [...horizontalFirst, HANDLE_IDS.dataInTop, HANDLE_IDS.dataInBottom];\n  }\n\n  const verticalFirst =\n    sourceCenter.y <= targetCenter.y\n      ? [HANDLE_IDS.dataInTop, HANDLE_IDS.dataInBottom]\n      : [HANDLE_IDS.dataInBottom, HANDLE_IDS.dataInTop];\n  return [...verticalFirst, HANDLE_IDS.dataIn, HANDLE_IDS.dataInRight];\n}\n\nfunction getProviderSourceHandleCandidates(\n  direction: LayoutDirection,\n): string[] {\n  return direction === \"TB\"\n    ? [HANDLE_IDS.semanticOut, HANDLE_IDS.semanticOutBottom]\n    : [HANDLE_IDS.semanticOutBottom, HANDLE_IDS.semanticOut];\n}\n\nfunction getProviderTargetHandleCandidates(\n  direction: LayoutDirection,\n): string[] {\n  return direction === \"TB\"\n    ? [HANDLE_IDS.semanticIn, HANDLE_IDS.semanticInTop]\n    : [HANDLE_IDS.semanticInTop, HANDLE_IDS.semanticIn];\n}\n\nfunction getConfigSourceHandleCandidates(direction: LayoutDirection): string[] {\n  return direction === \"TB\"\n    ? 
[HANDLE_IDS.semanticOut]\n    : [HANDLE_IDS.semanticOutBottom];\n}\n\nexport function optimizeModelInfraEdgeHandles(\n  edges: Edge[],\n  nodes: RecipeNode[],\n  configs: Record<string, NodeConfig>,\n  direction: LayoutDirection,\n): Edge[] {\n  const nodesById = new Map(nodes.map((node) => [node.id, node] as const));\n  const sourceUsage = new Map<string, number>();\n  const targetUsage = new Map<string, number>();\n\n  for (const edge of edges) {\n    const sourceHandle = normalizeRecipeHandleId(edge.sourceHandle);\n    const targetHandle = normalizeRecipeHandleId(edge.targetHandle);\n    if (sourceHandle) {\n      incrementUsage(sourceUsage, edge.source, sourceHandle);\n    }\n    if (targetHandle) {\n      incrementUsage(targetUsage, edge.target, targetHandle);\n    }\n  }\n\n  const nextEdges: Edge[] = [];\n  for (const edge of edges) {\n    const source = configs[edge.source];\n    const target = configs[edge.target];\n    if (!(source && target)) {\n      nextEdges.push(edge);\n      continue;\n    }\n\n    const sourceHandleBefore = normalizeRecipeHandleId(edge.sourceHandle);\n    const targetHandleBefore = normalizeRecipeHandleId(edge.targetHandle);\n    const isSemanticInfra =\n      isProviderToConfigEdge(edge, configs) ||\n      isConfigToLlmEdge(edge, configs) ||\n      isToolConfigToLlmEdge(edge, configs);\n    if (!isSemanticInfra) {\n      nextEdges.push(edge);\n      continue;\n    }\n\n    if (sourceHandleBefore) {\n      decrementUsage(sourceUsage, edge.source, sourceHandleBefore);\n    }\n    if (targetHandleBefore) {\n      decrementUsage(targetUsage, edge.target, targetHandleBefore);\n    }\n\n    if (isProviderToConfigEdge(edge, configs)) {\n      const sourceCandidates = getProviderSourceHandleCandidates(direction);\n      const targetCandidates = getProviderTargetHandleCandidates(direction);\n      const sourceHandle = pickHandleByUsage(\n        sourceCandidates,\n        edge.source,\n        sourceUsage,\n      );\n      const targetHandle = pickHandleByUsage(\n        targetCandidates,\n        edge.target,\n        targetUsage,\n      );\n      nextEdges.push(\n        applyEdgeWithHandles(\n          edge,\n          sourceHandle,\n          targetHandle,\n          sourceUsage,\n          targetUsage,\n        ),\n      );\n      continue;\n    }\n\n    const targetCandidates = sortPreferredLlmTargetHandles(\n      direction,\n      nodesById.get(edge.source),\n      nodesById.get(edge.target),\n    );\n    const sourceCandidates = getConfigSourceHandleCandidates(direction);\n    const sourceHandle = pickHandleByUsage(\n      sourceCandidates,\n      edge.source,\n      sourceUsage,\n    );\n    const targetHandle = pickHandleByUsage(\n      targetCandidates,\n      edge.target,\n      targetUsage,\n    );\n    nextEdges.push(\n      applyEdgeWithHandles(\n        edge,\n        sourceHandle,\n        targetHandle,\n        sourceUsage,\n        targetUsage,\n      ),\n    );\n  }\n\n  return nextEdges;\n}\n\nexport function centerModelInfraNodes(\n  nodes: RecipeNode[],\n  edges: Edge[],\n  configs: Record<string, NodeConfig>,\n  direction: LayoutDirection,\n): RecipeNode[] {\n  const nodesById = new Map(nodes.map((node) => [node.id, node] as const));\n  const configToLlmIds = new Map<string, string[]>();\n  const toolConfigToLlmIds = new Map<string, string[]>();\n  const providerToConfigIds = new Map<string, string[]>();\n\n  for (const edge of edges) {\n    if (isProviderToConfigEdge(edge, configs)) {\n      const entries = 
providerToConfigIds.get(edge.source) ?? [];\n      if (!entries.includes(edge.target)) {\n        entries.push(edge.target);\n      }\n      providerToConfigIds.set(edge.source, entries);\n      continue;\n    }\n    if (isConfigToLlmEdge(edge, configs)) {\n      const entries = configToLlmIds.get(edge.source) ?? [];\n      if (!entries.includes(edge.target)) {\n        entries.push(edge.target);\n      }\n      configToLlmIds.set(edge.source, entries);\n      continue;\n    }\n    if (isToolConfigToLlmEdge(edge, configs)) {\n      const entries = toolConfigToLlmIds.get(edge.source) ?? [];\n      if (!entries.includes(edge.target)) {\n        entries.push(edge.target);\n      }\n      toolConfigToLlmIds.set(edge.source, entries);\n    }\n  }\n\n  const modelConfigIds = Object.values(configs)\n    .filter(\n      (config) => config.kind === \"model_config\" && nodesById.has(config.id),\n    )\n    .map((config) => config.id);\n  const modelProviderIds = Object.values(configs)\n    .filter(\n      (config) => config.kind === \"model_provider\" && nodesById.has(config.id),\n    )\n    .map((config) => config.id);\n  const toolConfigIds = Object.values(configs)\n    .filter((config) => config.kind === \"tool_config\" && nodesById.has(config.id))\n    .map((config) => config.id);\n\n  const occupiedById = new Map(\n    nodes.map((node) => [node.id, toRect(node)] as const),\n  );\n  const clusterGap = 72;\n\n  const placeNode = (nodeId: string, preferred: XYPosition): void => {\n    const currentNode = nodesById.get(nodeId);\n    if (!currentNode) {\n      return;\n    }\n    const width = readNodeWidth(currentNode) ?? DEFAULT_NODE_WIDTH;\n    const height = readNodeHeight(currentNode) ?? DEFAULT_NODE_HEIGHT;\n    occupiedById.delete(nodeId);\n    const position = findNonOverlappingPosition(\n      preferred,\n      width,\n      height,\n      Array.from(occupiedById.values()),\n    );\n    const nextNode = { ...currentNode, position };\n    nodesById.set(nodeId, nextNode);\n    occupiedById.set(nodeId, {\n      x: position.x,\n      y: position.y,\n      width,\n      height,\n    });\n  };\n\n  for (const modelConfigId of modelConfigIds) {\n    const llmIds = configToLlmIds.get(modelConfigId) ?? [];\n    const targetBounds = collectBounds(llmIds, nodesById);\n    const modelConfigNode = nodesById.get(modelConfigId);\n    if (!(targetBounds && modelConfigNode)) {\n      continue;\n    }\n    const width = readNodeWidth(modelConfigNode) ?? DEFAULT_NODE_WIDTH;\n    const height = readNodeHeight(modelConfigNode) ?? DEFAULT_NODE_HEIGHT;\n    const preferred =\n      direction === \"LR\"\n        ? {\n            x: (targetBounds.minX + targetBounds.maxX) / 2 - width / 2,\n            y: targetBounds.minY - height - clusterGap,\n          }\n        : {\n            x: targetBounds.minX - width - clusterGap,\n            y: (targetBounds.minY + targetBounds.maxY) / 2 - height / 2,\n          };\n    placeNode(modelConfigId, preferred);\n  }\n\n  for (const modelProviderId of modelProviderIds) {\n    const configIds = providerToConfigIds.get(modelProviderId) ?? [];\n    const targetBounds = collectBounds(configIds, nodesById);\n    const modelProviderNode = nodesById.get(modelProviderId);\n    if (!(targetBounds && modelProviderNode)) {\n      continue;\n    }\n    const width = readNodeWidth(modelProviderNode) ?? DEFAULT_NODE_WIDTH;\n    const height = readNodeHeight(modelProviderNode) ?? DEFAULT_NODE_HEIGHT;\n    const preferred =\n      direction === \"LR\"\n        ? 
{\n            x: (targetBounds.minX + targetBounds.maxX) / 2 - width / 2,\n            y: targetBounds.minY - height - clusterGap,\n          }\n        : {\n            x: targetBounds.minX - width - clusterGap,\n            y: (targetBounds.minY + targetBounds.maxY) / 2 - height / 2,\n          };\n    placeNode(modelProviderId, preferred);\n  }\n\n  for (const toolConfigId of toolConfigIds) {\n    const llmIds = toolConfigToLlmIds.get(toolConfigId) ?? [];\n    const targetBounds = collectBounds(llmIds, nodesById);\n    const toolConfigNode = nodesById.get(toolConfigId);\n    if (!(targetBounds && toolConfigNode)) {\n      continue;\n    }\n    const width = readNodeWidth(toolConfigNode) ?? DEFAULT_NODE_WIDTH;\n    const height = readNodeHeight(toolConfigNode) ?? DEFAULT_NODE_HEIGHT;\n    const preferred =\n      direction === \"LR\"\n        ? {\n            x: (targetBounds.minX + targetBounds.maxX) / 2 - width / 2,\n            y: targetBounds.minY - height - clusterGap,\n          }\n        : {\n            x: targetBounds.minX - width - clusterGap,\n            y: (targetBounds.minY + targetBounds.maxY) / 2 - height / 2,\n          };\n    placeNode(toolConfigId, preferred);\n  }\n\n  return nodes.map((node) => nodesById.get(node.id) ?? node);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/helpers/node-updates.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { XYPosition } from \"@xyflow/react\";\nimport { DEFAULT_NODE_WIDTH } from \"../../constants\";\nimport type {\n  RecipeNode,\n  LayoutDirection,\n  NodeConfig,\n} from \"../../types\";\nimport { nodeDataFromConfig } from \"../../utils\";\nimport { getConfigUiMode } from \"../../components/inline/inline-policy\";\n\nexport type NodeUpdateState = {\n  configs: Record<string, NodeConfig>;\n  nodes: RecipeNode[];\n  nextId: number;\n  nextY: number;\n};\n\nexport type NodeUpdateResult = {\n  configs: Record<string, NodeConfig>;\n  nodes: RecipeNode[];\n  nextId: number;\n  nextY: number;\n  activeConfigId: string;\n  dialogOpen: boolean;\n};\n\nexport function updateNodeData(\n  nodes: RecipeNode[],\n  id: string,\n  config: NodeConfig,\n  layoutDirection: LayoutDirection,\n): RecipeNode[] {\n  return nodes.map((node) =>\n    node.id === id\n      ? { ...node, data: nodeDataFromConfig(config, layoutDirection) }\n      : node,\n  );\n}\n\nexport function buildNodeUpdate(\n  state: NodeUpdateState,\n  config: NodeConfig,\n  layoutDirection: LayoutDirection,\n  position?: XYPosition,\n  openDialog = true,\n): NodeUpdateResult {\n  const node: RecipeNode = {\n    id: config.id,\n    type: \"builder\",\n    position: position ?? { x: 0, y: state.nextY },\n    data: nodeDataFromConfig(config, layoutDirection),\n    style: { width: DEFAULT_NODE_WIDTH },\n    selected: true,\n  };\n  const mode = getConfigUiMode(config);\n  return {\n    configs: { ...state.configs, [config.id]: config },\n    nodes: [...state.nodes.map((item) => ({ ...item, selected: false })), node],\n    nextId: state.nextId + 1,\n    nextY: position ? state.nextY : state.nextY + 140,\n    activeConfigId: config.id,\n    dialogOpen: openDialog && mode === \"dialog\",\n  };\n}\n\nexport function applyLayoutDirectionToNodes(\n  nodes: RecipeNode[],\n  configs: Record<string, NodeConfig>,\n  layoutDirection: LayoutDirection,\n): RecipeNode[] {\n  return nodes.map((node) => {\n    const config = configs[node.id];\n    if (config) {\n      return { ...node, data: nodeDataFromConfig(config, layoutDirection) };\n    }\n    return {\n      ...node,\n      data: { ...node.data, layoutDirection },\n    };\n  });\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/helpers/reference-sync.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  LlmConfig,\n  ModelConfig,\n  NodeConfig,\n  SamplerConfig,\n} from \"../../types\";\nimport { removeRef, replaceRef } from \"../../utils/refs\";\n\nfunction updateTemplateFields(\n  config: NodeConfig,\n  updater: (value: string) => string,\n): NodeConfig {\n  if (config.kind === \"llm\") {\n    const nextPrompt = updater(config.prompt);\n    const nextSystem = updater(config.system_prompt);\n    const nextOutput =\n      typeof config.output_format === \"string\"\n        ? updater(config.output_format)\n        : config.output_format;\n    if (\n      nextPrompt === config.prompt &&\n      nextSystem === config.system_prompt &&\n      nextOutput === config.output_format\n    ) {\n      return config;\n    }\n    return {\n      ...config,\n      prompt: nextPrompt,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      system_prompt: nextSystem,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      output_format: nextOutput,\n    };\n  }\n  if (config.kind === \"expression\") {\n    const nextExpr = updater(config.expr);\n    if (nextExpr === config.expr) {\n      return config;\n    }\n    return { ...config, expr: nextExpr };\n  }\n  return config;\n}\n\nexport function applyRenameToConfig(\n  config: NodeConfig,\n  from: string,\n  to: string,\n): NodeConfig {\n  let next = updateTemplateFields(config, (value) =>\n    replaceRef(value, from, to),\n  );\n  if (\n    config.kind === \"sampler\" &&\n    config.sampler_type === \"subcategory\" &&\n    config.subcategory_parent === from\n  ) {\n    const base = next as SamplerConfig;\n    next = {\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: to,\n    };\n  }\n  if (\n    config.kind === \"sampler\" &&\n    config.sampler_type === \"timedelta\" &&\n    config.reference_column_name === from\n  ) {\n    const base = next as SamplerConfig;\n    next = {\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: to,\n    };\n  }\n  if (config.kind === \"model_config\" && config.provider === from) {\n    const base = next as ModelConfig;\n    next = { ...base, provider: to };\n  }\n  if (config.kind === \"llm\" && config.model_alias === from) {\n    const base = next as LlmConfig;\n    next = { ...base, model_alias: to };\n  }\n  if (config.kind === \"llm\" && config.tool_alias === from) {\n    const base = next as LlmConfig;\n    next = { ...base, tool_alias: to };\n  }\n  if (config.kind === \"validator\") {\n    const targets = config.target_columns ?? [];\n    if (targets.includes(from)) {\n      const base = next as typeof config;\n      next = {\n        ...base,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        target_columns: targets.map((target) => (target === from ? 
to : target)),\n      };\n    }\n  }\n  return next;\n}\n\nexport function applyRemovalToConfig(\n  config: NodeConfig,\n  ref: string,\n): NodeConfig {\n  let next = updateTemplateFields(config, (value) => removeRef(value, ref));\n  if (\n    config.kind === \"sampler\" &&\n    config.sampler_type === \"subcategory\" &&\n    config.subcategory_parent === ref\n  ) {\n    const base = next as SamplerConfig;\n    next = {\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_mapping: {},\n    };\n  }\n  if (\n    config.kind === \"sampler\" &&\n    config.sampler_type === \"timedelta\" &&\n    config.reference_column_name === ref\n  ) {\n    const base = next as SamplerConfig;\n    next = {\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: \"\",\n    };\n  }\n  if (config.kind === \"model_config\" && config.provider === ref) {\n    const base = next as ModelConfig;\n    next = { ...base, provider: \"\" };\n  }\n  if (config.kind === \"llm\" && config.model_alias === ref) {\n    const base = next as LlmConfig;\n    next = { ...base, model_alias: \"\" };\n  }\n  if (config.kind === \"llm\" && config.tool_alias === ref) {\n    const base = next as LlmConfig;\n    next = { ...base, tool_alias: \"\" };\n  }\n  if (config.kind === \"validator\") {\n    const targets = (config.target_columns ?? []).filter((target) => target !== ref);\n    if (targets.length !== (config.target_columns ?? []).length) {\n      const base = next as typeof config;\n      next = {\n        ...base,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        target_columns: targets,\n      };\n    }\n  }\n  return next;\n}\n\nfunction applyConfigTransform(\n  configs: Record<string, NodeConfig>,\n  transform: (config: NodeConfig) => NodeConfig,\n): Record<string, NodeConfig> {\n  let next = configs;\n  for (const [id, config] of Object.entries(configs)) {\n    const updated = transform(config);\n    if (updated !== config) {\n      if (next === configs) {\n        next = { ...configs };\n      }\n      next[id] = updated;\n    }\n  }\n  return next;\n}\n\nexport function applyRenameToConfigs(\n  configs: Record<string, NodeConfig>,\n  from: string,\n  to: string,\n): Record<string, NodeConfig> {\n  if (!from || from === to) {\n    return configs;\n  }\n  return applyConfigTransform(configs, (config) =>\n    applyRenameToConfig(config, from, to),\n  );\n}\n\nexport function applyRemovalToConfigs(\n  configs: Record<string, NodeConfig>,\n  ref: string,\n): Record<string, NodeConfig> {\n  if (!ref) {\n    return configs;\n  }\n  return applyConfigTransform(configs, (config) => applyRemovalToConfig(config, ref));\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/helpers/removals.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge } from \"@xyflow/react\";\nimport type { NodeConfig } from \"../../types\";\nimport { isCategoryConfig, isSubcategoryConfig } from \"../../utils\";\nimport { applyRemovalToConfig, applyRemovalToConfigs } from \"../recipe-studio-helpers\";\n\nexport function applyNodeRemovals(\n  input: { edges: Edge[]; configs: Record<string, NodeConfig> },\n  removedIds: string[],\n): { edges: Edge[]; configs: Record<string, NodeConfig> } {\n  if (removedIds.length === 0) {\n    return input;\n  }\n\n  const edges = input.edges.filter(\n    (edge) => !(removedIds.includes(edge.source) || removedIds.includes(edge.target)),\n  );\n  let configs: Record<string, NodeConfig> = { ...input.configs };\n  const removedNames: string[] = [];\n\n  for (const id of removedIds) {\n    const removed = configs[id];\n    delete configs[id];\n    if (removed?.name) {\n      removedNames.push(removed.name);\n    }\n\n    if (isCategoryConfig(removed)) {\n      const removedName = removed.name;\n      for (const config of Object.values(configs)) {\n        if (!isSubcategoryConfig(config)) {\n          continue;\n        }\n        if (config.subcategory_parent !== removedName) {\n          continue;\n        }\n        configs[config.id] = {\n          ...config,\n          // biome-ignore lint/style/useNamingConvention: api schema\n          subcategory_parent: \"\",\n          // biome-ignore lint/style/useNamingConvention: api schema\n          subcategory_mapping: {},\n        };\n      }\n    }\n  }\n\n  for (const name of removedNames) {\n    configs = applyRemovalToConfigs(configs, name);\n  }\n\n  return { edges, configs };\n}\n\nexport function applyEdgeRemovals(\n  configs: Record<string, NodeConfig>,\n  removedEdges: Edge[],\n): Record<string, NodeConfig> {\n  if (removedEdges.length === 0) {\n    return configs;\n  }\n\n  let next = configs;\n  for (const edge of removedEdges) {\n    const source = next[edge.source];\n    const target = next[edge.target];\n    if (!(source && target)) {\n      continue;\n    }\n    const updated = applyRemovalToConfig(target, source.name);\n    if (updated !== target) {\n      if (next === configs) {\n        next = { ...configs };\n      }\n      next[target.id] = updated;\n    }\n    if (\n      source.kind === \"validator\" &&\n      target.kind === \"llm\" &&\n      target.llm_type === \"code\"\n    ) {\n      const sourceUpdated = applyRemovalToConfig(source, target.name);\n      if (sourceUpdated !== source) {\n        if (next === configs) {\n          next = { ...configs };\n        }\n        next[source.id] = sourceUpdated;\n      }\n    }\n  }\n  return next;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/recipe-executions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\nimport type { RecipeExecutionKind } from \"../execution-types\";\nimport type { RecipeExecutionRecord } from \"../execution-types\";\nimport { sortExecutions, withExecutionDefaults } from \"../executions/execution-helpers\";\n\nexport type RecipeRunSettings = {\n  batchSize: number;\n  batchEnabled: boolean;\n  mergeBatches: boolean;\n  llmParallelRequests: number | null;\n  nonInferenceWorkers: number;\n  maxConversationRestarts: number;\n  maxConversationCorrectionSteps: number;\n  disableEarlyShutdown: boolean;\n  shutdownErrorRate: number;\n  shutdownErrorWindow: number;\n};\n\nconst DEFAULT_RUN_SETTINGS: RecipeRunSettings = {\n  batchSize: 1000,\n  batchEnabled: false,\n  mergeBatches: false,\n  llmParallelRequests: null,\n  nonInferenceWorkers: 4,\n  maxConversationRestarts: 5,\n  maxConversationCorrectionSteps: 0,\n  disableEarlyShutdown: true,\n  shutdownErrorRate: 0.5,\n  shutdownErrorWindow: 10,\n};\n\ntype RecipeExecutionsState = {\n  runDialogOpen: boolean;\n  runDialogKind: RecipeExecutionKind;\n  previewRows: number;\n  fullRows: number;\n  fullRunName: string;\n  runErrors: string[];\n  runSettings: RecipeRunSettings;\n  previewLoading: boolean;\n  fullLoading: boolean;\n  executions: RecipeExecutionRecord[];\n  selectedExecutionId: string | null;\n  setRunDialogOpen: (open: boolean) => void;\n  setRunDialogKind: (kind: RecipeExecutionKind) => void;\n  setPreviewRows: (rows: number) => void;\n  setFullRows: (rows: number) => void;\n  setFullRunName: (name: string) => void;\n  setRunErrors: (errors: string[]) => void;\n  setRunSettings: (patch: Partial<RecipeRunSettings>) => void;\n  setPreviewLoading: (loading: boolean) => void;\n  setFullLoading: (loading: boolean) => void;\n  setExecutions: (records: RecipeExecutionRecord[]) => void;\n  upsertExecution: (record: RecipeExecutionRecord) => void;\n  selectExecution: (id: string | null) => void;\n  resetForRecipe: () => void;\n};\n\nconst INITIAL_STATE = {\n  runDialogOpen: false,\n  runDialogKind: \"preview\",\n  previewRows: 5,\n  fullRows: 100,\n  fullRunName: \"\",\n  runErrors: [],\n  runSettings: DEFAULT_RUN_SETTINGS,\n  previewLoading: false,\n  fullLoading: false,\n  executions: [],\n  selectedExecutionId: null,\n} satisfies Pick<\n  RecipeExecutionsState,\n  | \"runDialogOpen\"\n  | \"runDialogKind\"\n  | \"previewRows\"\n  | \"fullRows\"\n  | \"fullRunName\"\n  | \"runErrors\"\n  | \"runSettings\"\n  | \"previewLoading\"\n  | \"fullLoading\"\n  | \"executions\"\n  | \"selectedExecutionId\"\n>;\n\nexport const useRecipeExecutionsStore = create<RecipeExecutionsState>((set) => ({\n  ...INITIAL_STATE,\n  setRunDialogOpen: (open) => set({ runDialogOpen: open }),\n  setRunDialogKind: (kind) =>\n    set((state) => {\n      if (state.runDialogKind === \"preview\" && kind === \"full\") {\n        return {\n          runDialogKind: kind,\n          fullRows: 100,\n          runSettings: {\n            ...state.runSettings,\n            batchEnabled: false,\n          },\n        };\n      }\n      return { runDialogKind: kind };\n    }),\n  setPreviewRows: (rows) =>\n    set({ previewRows: Number.isFinite(rows) && rows > 0 ? Math.floor(rows) : 1 }),\n  setFullRows: (rows) =>\n    set({ fullRows: Number.isFinite(rows) && rows > 0 ? 
Math.floor(rows) : 1 }),\n  setFullRunName: (name) => set({ fullRunName: name }),\n  setRunErrors: (errors) => set({ runErrors: errors }),\n  setRunSettings: (patch) =>\n    set((state) => ({\n      runSettings: {\n        ...state.runSettings,\n        ...patch,\n      },\n    })),\n  setPreviewLoading: (loading) => set({ previewLoading: loading }),\n  setFullLoading: (loading) => set({ fullLoading: loading }),\n  setExecutions: (records) =>\n    set(() => {\n      const normalized = sortExecutions(records.map(withExecutionDefaults));\n      return {\n        executions: normalized,\n        selectedExecutionId: normalized[0]?.id ?? null,\n      };\n    }),\n  upsertExecution: (record) =>\n    set((state) => {\n      const normalized = withExecutionDefaults(record);\n      const withoutCurrent = state.executions.filter((item) => item.id !== normalized.id);\n      return {\n        executions: sortExecutions([normalized, ...withoutCurrent]),\n        selectedExecutionId: normalized.id,\n      };\n    }),\n  selectExecution: (id) => set({ selectedExecutionId: id }),\n  resetForRecipe: () => set(INITIAL_STATE),\n}));\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/recipe-studio-helpers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport {\n  applyLayoutDirectionToNodes,\n  buildNodeUpdate,\n  type NodeUpdateResult,\n  type NodeUpdateState,\n  updateNodeData,\n} from \"./helpers/node-updates\";\nexport {\n  syncEdgesForConfigPatch,\n  syncSubcategoryConfigsForCategoryUpdate,\n} from \"./helpers/edge-sync\";\nexport {\n  applyRemovalToConfig,\n  applyRemovalToConfigs,\n  applyRenameToConfig,\n  applyRenameToConfigs,\n} from \"./helpers/reference-sync\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/stores/recipe-studio.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  type Connection,\n  type Edge,\n  type EdgeChange,\n  type IsValidConnection,\n  type NodeChange,\n  type XYPosition,\n  applyEdgeChanges,\n  applyNodeChanges,\n} from \"@xyflow/react\";\nimport { create } from \"zustand\";\nimport type {\n  RecipeNode,\n  RecipeProcessorConfig,\n  LayoutDirection,\n  LlmType,\n  NodeConfig,\n  SeedSourceType,\n  SamplerType,\n} from \"../types\";\nimport {\n  getBlockDefinition,\n  type BlockKind,\n  type BlockType,\n  type SeedBlockType,\n} from \"../blocks/registry\";\nimport { deriveDisplayGraph } from \"../utils/graph/derive-display-graph\";\nimport { applyRecipeConnection, isValidRecipeConnection } from \"../utils/graph\";\nimport {\n  HANDLE_IDS,\n  normalizeRecipeHandleId,\n  remapRecipeEdgeHandlesForLayout,\n} from \"../utils/handles\";\nimport type { RecipeSnapshot } from \"../utils/import\";\nimport { getLayoutedElements } from \"../utils/layout\";\nimport {\n  centerModelInfraNodes,\n  optimizeModelInfraEdgeHandles,\n} from \"./helpers/model-infra-layout\";\nimport { applyEdgeRemovals, applyNodeRemovals } from \"./helpers/removals\";\nimport {\n  applyRenameToConfigs,\n  applyLayoutDirectionToNodes,\n  buildNodeUpdate,\n  syncEdgesForConfigPatch,\n  syncSubcategoryConfigsForCategoryUpdate,\n  updateNodeData,\n} from \"./recipe-studio-helpers\";\n\ntype SheetView =\n  | \"root\"\n  | \"sampler\"\n  | \"seed\"\n  | \"llm\"\n  | \"validator\"\n  | \"expression\"\n  | \"note\"\n  | \"processor\";\n\ntype RecipeStudioState = {\n  nodes: RecipeNode[];\n  edges: Edge[];\n  auxNodePositions: Record<string, XYPosition>;\n  llmAuxVisibility: Record<string, boolean>;\n  configs: Record<string, NodeConfig>;\n  processors: RecipeProcessorConfig[];\n  sheetOpen: boolean;\n  sheetView: SheetView;\n  activeConfigId: string | null;\n  dialogOpen: boolean;\n  layoutDirection: LayoutDirection;\n  executionLocked: boolean;\n  nextId: number;\n  nextY: number;\n  fitViewTick: number;\n  setSheetOpen: (open: boolean) => void;\n  setSheetView: (view: SheetView) => void;\n  setProcessors: (processors: RecipeProcessorConfig[]) => void;\n  setDialogOpen: (open: boolean) => void;\n  setExecutionLocked: (locked: boolean) => void;\n  resetRecipe: () => void;\n  selectConfig: (id: string) => void;\n  openConfig: (id: string) => void;\n  setLayoutDirection: (direction: LayoutDirection) => void;\n  applyLayout: () => void;\n  setLlmAuxVisibility: (id: string, visible: boolean) => void;\n  addSamplerNode: (\n    type: SamplerType,\n    position?: XYPosition,\n    openDialog?: boolean,\n  ) => void;\n  addSeedNode: (\n    type: SeedBlockType,\n    position?: XYPosition,\n    openDialog?: boolean,\n  ) => void;\n  addLlmNode: (type: LlmType, position?: XYPosition, openDialog?: boolean) => void;\n  addModelProviderNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addModelConfigNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addToolProfileNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addExpressionNode: (position?: XYPosition, openDialog?: boolean) => void;\n  addValidatorNode: (\n    type: \"validator_python\" | \"validator_sql\" | \"validator_oxc\",\n    position?: XYPosition,\n    openDialog?: boolean,\n  ) => void;\n  addMarkdownNoteNode: (position?: XYPosition, openDialog?: boolean) => void;\n  updateConfig: (id: string, patch: Partial<NodeConfig>) => 
void;\n  loadRecipe: (snapshot: RecipeSnapshot) => void;\n  setAuxNodePosition: (id: string, position: XYPosition) => void;\n  onNodesChange: (changes: NodeChange<RecipeNode>[]) => void;\n  onEdgesChange: (changes: EdgeChange<Edge>[]) => void;\n  onConnect: (connection: Connection) => void;\n  isValidConnection: IsValidConnection;\n};\n\nconst INITIAL_STATE = {\n  nodes: [],\n  edges: [],\n  auxNodePositions: {},\n  llmAuxVisibility: {},\n  configs: {},\n  processors: [],\n  sheetOpen: false,\n  sheetView: \"root\",\n  activeConfigId: null,\n  dialogOpen: false,\n  layoutDirection: \"LR\",\n  executionLocked: false,\n  nextId: 3,\n  nextY: 280,\n  fitViewTick: 0,\n} satisfies Pick<\n  RecipeStudioState,\n  | \"nodes\"\n  | \"edges\"\n  | \"auxNodePositions\"\n  | \"llmAuxVisibility\"\n  | \"configs\"\n  | \"processors\"\n  | \"sheetOpen\"\n  | \"sheetView\"\n  | \"activeConfigId\"\n  | \"dialogOpen\"\n  | \"layoutDirection\"\n  | \"executionLocked\"\n  | \"nextId\"\n  | \"nextY\"\n  | \"fitViewTick\"\n>;\n\nfunction buildAddedNodeState(\n  state: RecipeStudioState,\n  kind: BlockKind,\n  type: BlockType,\n  position?: XYPosition,\n  openDialog = true,\n): Partial<RecipeStudioState> | RecipeStudioState {\n  const id = `n${state.nextId}`;\n  const existing = Object.values(state.configs);\n  const definition = getBlockDefinition(kind, type);\n  if (!definition) {\n    return state;\n  }\n  const config = definition.createConfig(id, existing);\n  return buildNodeUpdate(\n    state,\n    config,\n    state.layoutDirection,\n    position,\n    openDialog,\n  );\n}\n\nfunction getAddedNodeContext(\n  update: Partial<RecipeStudioState> | RecipeStudioState,\n): {\n  nodes: RecipeNode[];\n  configs: Record<string, NodeConfig>;\n  newNodeId: string;\n} | null {\n  const nodes = \"nodes\" in update ? update.nodes : null;\n  const configs = \"configs\" in update ? update.configs : null;\n  const newNodeId = \"activeConfigId\" in update ? update.activeConfigId : null;\n  if (!(nodes && configs && newNodeId)) {\n    return null;\n  }\n  return { nodes, configs, newNodeId };\n}\n\nfunction placeNodeNear(\n  nodes: RecipeNode[],\n  nodeId: string,\n  anchorId: string,\n  direction: LayoutDirection,\n  relation: \"before\" | \"after\",\n): RecipeNode[] {\n  const anchor = nodes.find((node) => node.id === anchorId);\n  if (!anchor) {\n    return nodes;\n  }\n  const primaryOffset = relation === \"before\" ? -440 : 440;\n  return nodes.map((node) => {\n    if (node.id !== nodeId) {\n      return node;\n    }\n    if (direction === \"TB\") {\n      return {\n        ...node,\n        position: {\n          x: anchor.position.x,\n          y: anchor.position.y + primaryOffset,\n        },\n      };\n    }\n    return {\n      ...node,\n      position: {\n        x: anchor.position.x + primaryOffset,\n        y: anchor.position.y,\n      },\n    };\n  });\n}\n\nfunction connectSemantic(\n  edges: Edge[],\n  configs: Record<string, NodeConfig>,\n  sourceId: string,\n  targetId: string,\n  layoutDirection: LayoutDirection,\n): { edges: Edge[]; configs: Record<string, NodeConfig> } {\n  const result = applyRecipeConnection(\n    {\n      source: sourceId,\n      sourceHandle: HANDLE_IDS.semanticOut,\n      target: targetId,\n      targetHandle: HANDLE_IDS.semanticIn,\n    },\n    configs,\n    edges,\n    layoutDirection,\n  );\n  return {\n    edges: result.edges,\n    configs: result.configs ?? 
configs,\n  };\n}\n\nfunction isModelSemanticEdge(edge: Edge, configs: Record<string, NodeConfig>): boolean {\n  const source = configs[edge.source];\n  const target = configs[edge.target];\n  return Boolean(\n    source &&\n      target &&\n      ((source.kind === \"model_provider\" && target.kind === \"model_config\") ||\n        (source.kind === \"model_config\" && target.kind === \"llm\") ||\n        (source.kind === \"tool_config\" && target.kind === \"llm\")),\n  );\n}\n\nexport const useRecipeStudioStore = create<RecipeStudioState>((set, get) => ({\n  ...INITIAL_STATE,\n  setSheetOpen: (open) => set({ sheetOpen: open }),\n  setSheetView: (view) => set({ sheetView: view }),\n  setProcessors: (processors) =>\n    set((state) => (state.executionLocked ? state : { processors })),\n  setDialogOpen: (open) => set({ dialogOpen: open }),\n  setExecutionLocked: (locked) => set({ executionLocked: locked }),\n  resetRecipe: () => set(INITIAL_STATE),\n  selectConfig: (id) => set({ activeConfigId: id, dialogOpen: false }),\n  openConfig: (id) => set({ activeConfigId: id, dialogOpen: true }),\n  setLayoutDirection: (direction) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      return {\n        layoutDirection: direction,\n        edges: state.edges.map((edge) => {\n          if (isModelSemanticEdge(edge, state.configs)) {\n            return {\n              ...edge,\n              sourceHandle: normalizeRecipeHandleId(edge.sourceHandle),\n              targetHandle: normalizeRecipeHandleId(edge.targetHandle),\n            };\n          }\n          return {\n            ...edge,\n            ...remapRecipeEdgeHandlesForLayout(edge, direction),\n          };\n        }),\n        nodes: applyLayoutDirectionToNodes(\n          state.nodes,\n          state.configs,\n          direction,\n        ),\n      };\n    }),\n  applyLayout: () =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const isTopBottom = state.layoutDirection === \"TB\";\n\n      const displayGraph = deriveDisplayGraph({\n        nodes: state.nodes,\n        edges: state.edges,\n        configs: state.configs,\n        layoutDirection: state.layoutDirection,\n        auxNodePositions: {},\n        llmAuxVisibility: state.llmAuxVisibility,\n      });\n      const { nodes } = getLayoutedElements(displayGraph.nodes, displayGraph.edges, {\n        direction: state.layoutDirection,\n        nodesep: isTopBottom ? 120 : 80,\n        ranksep: isTopBottom ? 
140 : 80,\n        configs: state.configs,\n      });\n      const layoutedPositions = new Map(\n        nodes.map((node) => [node.id, node.position] as const),\n      );\n      const nextNodes = state.nodes.map((node) => {\n        const position = layoutedPositions.get(node.id);\n        if (!position) {\n          return node;\n        }\n        return { ...node, position };\n      });\n      const centeredNodes = centerModelInfraNodes(\n        nextNodes,\n        state.edges,\n        state.configs,\n        state.layoutDirection,\n      );\n      const optimizedEdges = optimizeModelInfraEdgeHandles(\n        state.edges,\n        centeredNodes,\n        state.configs,\n        state.layoutDirection,\n      );\n      return {\n        auxNodePositions: {},\n        edges: optimizedEdges,\n        nodes: applyLayoutDirectionToNodes(\n          centeredNodes,\n          state.configs,\n          state.layoutDirection,\n        ),\n      };\n    }),\n  setLlmAuxVisibility: (id, visible) =>\n    set((state) => {\n      if (state.llmAuxVisibility[id] === visible) {\n        return state;\n      }\n      return {\n        llmAuxVisibility: {\n          ...state.llmAuxVisibility,\n          [id]: visible,\n        },\n      };\n    }),\n  addSamplerNode: (type, position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      return buildAddedNodeState(state, \"sampler\", type, position, openDialog);\n    }),\n  addSeedNode: (type, position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const existing = Object.values(state.configs).find(\n        (config) => config.kind === \"seed\",\n      );\n      if (!existing) {\n        return buildAddedNodeState(\n          state,\n          \"seed\",\n          type,\n          position,\n          openDialog,\n        );\n      }\n      let nextSourceType: SeedSourceType = \"hf\";\n      if (type === \"seed_local\") {\n        nextSourceType = \"local\";\n      } else if (type === \"seed_unstructured\") {\n        nextSourceType = \"unstructured\";\n      }\n\n      const nextConfig = {\n        ...existing,\n        seed_source_type: nextSourceType,\n        hf_repo_id: \"\",\n        hf_subset: \"\",\n        hf_split: \"\",\n        hf_path: \"\",\n        hf_token: \"\",\n        hf_endpoint: \"https://huggingface.co\",\n        local_file_name: \"\",\n        unstructured_file_name: \"\",\n        seed_columns: [],\n        seed_drop_columns: [],\n        seed_preview_rows: [],\n        unstructured_chunk_size: \"1200\",\n        unstructured_chunk_overlap: \"200\",\n      };\n      return {\n        configs: {\n          ...state.configs,\n          [existing.id]: nextConfig,\n        },\n        nodes: updateNodeData(\n          state.nodes.map((node) => ({ ...node, selected: node.id === existing.id })),\n          existing.id,\n          nextConfig,\n          state.layoutDirection,\n        ),\n        activeConfigId: existing.id,\n        dialogOpen: openDialog,\n      };\n    }),\n  addLlmNode: (type, position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const added = buildAddedNodeState(state, \"llm\", type, position, openDialog);\n      const context = getAddedNodeContext(added);\n      if (!context) {\n        return added;\n      }\n      let { nodes, configs } = context;\n      let edges = state.edges;\n      const 
modelConfigs = Object.values(configs).filter(\n        (config) => config.kind === \"model_config\",\n      );\n      if (modelConfigs.length === 1) {\n        if (!position) {\n          nodes = placeNodeNear(\n            nodes,\n            context.newNodeId,\n            modelConfigs[0].id,\n            state.layoutDirection,\n            \"after\",\n          );\n        }\n        const next = connectSemantic(\n          edges,\n          configs,\n          modelConfigs[0].id,\n          context.newNodeId,\n          state.layoutDirection,\n        );\n        edges = next.edges;\n        configs = next.configs;\n      }\n      return { ...added, nodes, edges, configs };\n    }),\n  addModelProviderNode: (position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const added = buildAddedNodeState(\n        state,\n        \"llm\",\n        \"model_provider\",\n        position,\n        openDialog,\n      );\n      const context = getAddedNodeContext(added);\n      if (!context) {\n        return added;\n      }\n      let { nodes, configs } = context;\n      let edges = state.edges;\n      const unboundModelConfigs = Object.values(configs).filter(\n        (config) =>\n          config.kind === \"model_config\" &&\n          !config.provider.trim(),\n      );\n      if (!position && unboundModelConfigs.length > 0) {\n        nodes = placeNodeNear(\n          nodes,\n          context.newNodeId,\n          unboundModelConfigs[0].id,\n          state.layoutDirection,\n          \"before\",\n        );\n      }\n      if (unboundModelConfigs.length === 1) {\n        const next = connectSemantic(\n          edges,\n          configs,\n          context.newNodeId,\n          unboundModelConfigs[0].id,\n          state.layoutDirection,\n        );\n        edges = next.edges;\n        configs = next.configs;\n      }\n      return { ...added, nodes, edges, configs };\n    }),\n  addModelConfigNode: (position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const added = buildAddedNodeState(\n        state,\n        \"llm\",\n        \"model_config\",\n        position,\n        openDialog,\n      );\n      const context = getAddedNodeContext(added);\n      if (!context) {\n        return added;\n      }\n      let { nodes, configs } = context;\n      let edges = state.edges;\n      const providers = Object.values(configs).filter(\n        (config) => config.kind === \"model_provider\",\n      );\n      const unboundLlms = Object.values(configs).filter(\n        (config) => config.kind === \"llm\" && !config.model_alias.trim(),\n      );\n      if (!position && providers.length === 1) {\n        nodes = placeNodeNear(\n          nodes,\n          context.newNodeId,\n          providers[0].id,\n          state.layoutDirection,\n          \"after\",\n        );\n      } else if (!position && unboundLlms.length > 0) {\n        nodes = placeNodeNear(\n          nodes,\n          context.newNodeId,\n          unboundLlms[0].id,\n          state.layoutDirection,\n          \"before\",\n        );\n      }\n      if (providers.length === 1) {\n        const next = connectSemantic(\n          edges,\n          configs,\n          providers[0].id,\n          context.newNodeId,\n          state.layoutDirection,\n        );\n        edges = next.edges;\n        configs = next.configs;\n      }\n      if (unboundLlms.length === 1) {\n        const next = 
connectSemantic(\n          edges,\n          configs,\n          context.newNodeId,\n          unboundLlms[0].id,\n          state.layoutDirection,\n        );\n        edges = next.edges;\n        configs = next.configs;\n      }\n      return { ...added, nodes, edges, configs };\n    }),\n  addToolProfileNode: (position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const added = buildAddedNodeState(\n        state,\n        \"llm\",\n        \"tool_config\",\n        position,\n        openDialog,\n      );\n      const context = getAddedNodeContext(added);\n      if (!context) {\n        return added;\n      }\n      let { nodes, configs } = context;\n      let edges = state.edges;\n      const unboundLlms = Object.values(configs).filter(\n        (config) => config.kind === \"llm\" && !(config.tool_alias?.trim()),\n      );\n      if (!position && unboundLlms.length > 0) {\n        nodes = placeNodeNear(\n          nodes,\n          context.newNodeId,\n          unboundLlms[0].id,\n          state.layoutDirection,\n          \"before\",\n        );\n      }\n      if (unboundLlms.length === 1) {\n        const next = connectSemantic(\n          edges,\n          configs,\n          context.newNodeId,\n          unboundLlms[0].id,\n          state.layoutDirection,\n        );\n        edges = next.edges;\n        configs = next.configs;\n      }\n      return { ...added, nodes, edges, configs };\n    }),\n  addExpressionNode: (position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      return buildAddedNodeState(\n        state,\n        \"expression\",\n        \"expression\",\n        position,\n        openDialog,\n      );\n    }),\n  addValidatorNode: (type, position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      return buildAddedNodeState(\n        state,\n        \"validator\",\n        type,\n        position,\n        openDialog,\n      );\n    }),\n  addMarkdownNoteNode: (position, openDialog = true) =>\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      return buildAddedNodeState(\n        state,\n        \"note\",\n        \"markdown_note\",\n        position,\n        openDialog,\n      );\n    }),\n  loadRecipe: (snapshot) =>\n    set((state) => ({\n      configs: snapshot.configs,\n      nodes: applyLayoutDirectionToNodes(\n        snapshot.nodes,\n        snapshot.configs,\n        snapshot.layoutDirection,\n      ),\n      edges: snapshot.edges,\n      processors: snapshot.processors,\n      layoutDirection: snapshot.layoutDirection,\n      nextId: snapshot.nextId,\n      nextY: snapshot.nextY,\n      auxNodePositions: snapshot.auxNodePositions ?? 
{},\n      llmAuxVisibility: {},\n      activeConfigId: null,\n      dialogOpen: false,\n      sheetView: \"root\",\n      fitViewTick: state.fitViewTick + 1,\n    })),\n  setAuxNodePosition: (id, position) =>\n    set((state) => {\n      const current = state.auxNodePositions[id];\n      if (current && current.x === position.x && current.y === position.y) {\n        return state;\n      }\n      return {\n        auxNodePositions: {\n          ...state.auxNodePositions,\n          [id]: position,\n        },\n      };\n    }),\n  updateConfig: (id, patch) => {\n    const applyUpdate = (state: RecipeStudioState) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const current = state.configs[id];\n      if (!current) {\n        return state;\n      }\n      const next = { ...current, ...patch } as NodeConfig;\n      const oldName = current.name;\n      const newName = next.name;\n      const nameChanged = oldName !== newName;\n      let configs: Record<string, NodeConfig> = {\n        ...state.configs,\n        [id]: next,\n      };\n      const nodes = updateNodeData(\n        state.nodes,\n        id,\n        next,\n        state.layoutDirection,\n      );\n      const edges = syncEdgesForConfigPatch(\n        current,\n        patch,\n        configs,\n        state.edges,\n        state.layoutDirection,\n      );\n      configs = syncSubcategoryConfigsForCategoryUpdate(\n        current,\n        next,\n        configs,\n        oldName,\n        newName,\n        nameChanged,\n      );\n\n      if (nameChanged) {\n        configs = applyRenameToConfigs(configs, oldName, newName);\n      }\n\n      return { configs, nodes, edges };\n    };\n    set(applyUpdate);\n  },\n  onNodesChange: (changes) => {\n    const applyNodesChange = (state: RecipeStudioState) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const removedIds = changes\n        .filter((change) => change.type === \"remove\")\n        .map((change) => change.id);\n\n      const removed = applyNodeRemovals(\n        { edges: state.edges, configs: state.configs },\n        removedIds,\n      );\n      const nodes = applyNodeChanges<RecipeNode>(changes, state.nodes);\n      const llmAuxVisibility =\n        removedIds.length === 0\n          ? state.llmAuxVisibility\n          : Object.fromEntries(\n              Object.entries(state.llmAuxVisibility).filter(\n                ([id]) => !removedIds.includes(id),\n              ),\n            );\n      return {\n        nodes,\n        edges: removed.edges,\n        configs: removed.configs,\n        llmAuxVisibility,\n      };\n    };\n    set(applyNodesChange);\n  },\n  onEdgesChange: (changes) => {\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const removedEdges = changes\n        .filter((change) => change.type === \"remove\")\n        .map((change) => state.edges.find((edge) => edge.id === change.id))\n        .filter((edge): edge is Edge => Boolean(edge));\n\n      const configs = applyEdgeRemovals(state.configs, removedEdges);\n\n      const edges = applyEdgeChanges(changes, state.edges);\n      return configs === state.configs ? 
{ edges } : { edges, configs };\n    });\n  },\n  onConnect: (connection) => {\n    set((state) => {\n      if (state.executionLocked) {\n        return state;\n      }\n      const result = applyRecipeConnection(\n        connection,\n        state.configs,\n        state.edges,\n        state.layoutDirection,\n      );\n      return result.configs\n        ? { edges: result.edges, configs: result.configs }\n        : { edges: result.edges };\n    });\n  },\n  isValidConnection: (connection) =>\n    isValidRecipeConnection(\n      {\n        source: connection.source ?? null,\n        target: connection.target ?? null,\n        sourceHandle: connection.sourceHandle ?? null,\n        targetHandle: connection.targetHandle ?? null,\n      },\n      get().configs,\n    ),\n}));\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/types/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Node } from \"@xyflow/react\";\n\nexport type SamplerType =\n  | \"category\"\n  | \"subcategory\"\n  | \"uniform\"\n  | \"gaussian\"\n  | \"bernoulli\"\n  | \"datetime\"\n  | \"timedelta\"\n  | \"uuid\"\n  | \"person\"\n  | \"person_from_faker\";\n\nexport type LlmType = \"text\" | \"structured\" | \"code\" | \"judge\";\nexport type ValidatorCodeLang =\n  | \"javascript\"\n  | \"typescript\"\n  | \"jsx\"\n  | \"tsx\"\n  | \"python\"\n  | \"sql:sqlite\"\n  | \"sql:postgres\"\n  | \"sql:mysql\"\n  | \"sql:tsql\"\n  | \"sql:bigquery\"\n  | \"sql:ansi\";\nexport type ValidatorType = \"code\" | \"oxc\";\nexport type OxcValidationMode = \"syntax\" | \"lint\" | \"syntax+lint\";\nexport type OxcCodeShape = \"auto\" | \"module\" | \"snippet\";\n\nexport type ExpressionDtype = \"str\" | \"int\" | \"float\" | \"bool\";\n\nexport type LayoutDirection = \"LR\" | \"TB\";\n\nexport type SeedSamplingStrategy = \"ordered\" | \"shuffle\";\nexport type SeedSelectionType = \"none\" | \"index_range\" | \"partition_block\";\nexport type SeedSourceType = \"hf\" | \"local\" | \"unstructured\";\nexport const INFRA_NODE_KINDS = new Set([\n  \"model_provider\",\n  \"model_config\",\n  \"tool_config\",\n]);\n\nexport type RecipeNodeData = {\n  title: string;\n  name: string;\n  kind:\n    | \"sampler\"\n    | \"llm\"\n    | \"validator\"\n    | \"expression\"\n    | \"seed\"\n    | \"note\"\n    | \"model_provider\"\n    | \"model_config\"\n    | \"tool_config\";\n  subtype: string;\n  blockType:\n    | SamplerType\n    | LlmType\n    | \"validator_python\"\n    | \"validator_sql\"\n    | \"validator_oxc\"\n    | \"expression\"\n    | \"seed\"\n    | \"markdown_note\"\n    | \"model_provider\"\n    | \"model_config\"\n    | \"tool_config\";\n  layoutDirection?: LayoutDirection;\n  runtimeState?: \"idle\" | \"running\" | \"done\";\n  executionLocked?: boolean;\n};\n\nexport type RecipeNode = Node<RecipeNodeData, \"builder\">;\n\nexport type CategoryConditionalParams = {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  sampler_type: \"category\";\n  values: string[];\n  weights?: Array<number | null>;\n};\n\nexport type SamplerConfig = {\n  id: string;\n  kind: \"sampler\";\n  // ui-only\n  advancedOpen?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  sampler_type: SamplerType;\n  name: string;\n  drop?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  convert_to?: \"float\" | \"int\" | \"str\";\n  values?: string[];\n  weights?: Array<number | null>;\n  low?: string;\n  high?: string;\n  mean?: string;\n  std?: string;\n  p?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  datetime_start?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  datetime_end?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  datetime_unit?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  dt_min?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  dt_max?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  reference_column_name?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  timedelta_unit?: \"D\" | \"h\" | \"m\" | \"s\";\n  // biome-ignore lint/style/useNamingConvention: api schema\n  uuid_format?: string;\n  // biome-ignore 
lint/style/useNamingConvention: api schema\n  person_locale?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  person_sex?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  person_age_range?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  person_city?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  person_with_synthetic_personas?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  subcategory_parent?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  subcategory_mapping?: Record<string, string[]>;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  conditional_params?: Record<string, CategoryConditionalParams>;\n};\n\nexport type ScoreOption = {\n  value: string;\n  description: string;\n};\n\nexport type Score = {\n  name: string;\n  description: string;\n  options: ScoreOption[];\n};\n\nexport type McpProviderType = \"stdio\" | \"streamable_http\";\n\nexport type McpEnvVar = {\n  key: string;\n  value: string;\n};\n\nexport type LlmMcpProviderConfig = {\n  id: string;\n  name: string;\n  // biome-ignore lint/style/useNamingConvention: ui schema\n  provider_type: McpProviderType;\n  command?: string;\n  args?: string[];\n  env?: McpEnvVar[];\n  endpoint?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  api_key?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  api_key_env?: string;\n};\n\nexport type LlmToolConfig = {\n  id: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  tool_alias: string;\n  providers: string[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  allow_tools?: string[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  max_tool_call_turns?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  timeout_sec?: string;\n};\n\nexport type ToolProfileConfig = {\n  id: string;\n  kind: \"tool_config\";\n  name: string;\n  // biome-ignore lint/style/useNamingConvention: ui schema\n  mcp_providers: LlmMcpProviderConfig[];\n  // biome-ignore lint/style/useNamingConvention: ui schema\n  fetched_tools_by_provider?: Record<string, string[]>;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  allow_tools?: string[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  max_tool_call_turns?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  timeout_sec?: string;\n};\n\nexport type LlmImageContextConfig = {\n  enabled: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  column_name: string;\n};\n\nexport type LlmTraceType = \"none\" | \"last_message\" | \"all_messages\";\n\nexport type LlmConfig = {\n  id: string;\n  kind: \"llm\";\n  // ui-only\n  advancedOpen?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  llm_type: LlmType;\n  name: string;\n  drop?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  model_alias: string;\n  prompt: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  system_prompt: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  code_lang?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  output_format?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  tool_alias?: string;\n  scores?: Score[];\n  // ui-only, serialized into multi_modal_context for DataDesigner\n  // biome-ignore 
lint/style/useNamingConvention: ui schema\n  image_context?: LlmImageContextConfig;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  with_trace?: LlmTraceType;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  extract_reasoning_content?: boolean;\n};\n\nexport type ModelProviderConfig = {\n  id: string;\n  kind: \"model_provider\";\n  name: string;\n  endpoint: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  provider_type: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  api_key_env?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  api_key?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  extra_headers?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  extra_body?: string;\n};\n\nexport type ModelConfig = {\n  id: string;\n  kind: \"model_config\";\n  name: string;\n  model: string;\n  provider: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  inference_temperature?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  inference_top_p?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  inference_max_tokens?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  inference_timeout?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  inference_extra_body?: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  skip_health_check?: boolean;\n};\n\nexport type ExpressionConfig = {\n  id: string;\n  kind: \"expression\";\n  name: string;\n  drop?: boolean;\n  expr: string;\n  dtype: ExpressionDtype;\n};\n\nexport type ValidatorConfig = {\n  id: string;\n  kind: \"validator\";\n  // ui-only\n  advancedOpen?: boolean;\n  name: string;\n  drop?: boolean;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  target_columns: string[];\n  // ui-only\n  validator_type: ValidatorType;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  code_lang: ValidatorCodeLang;\n  // ui-only (used for OXC validators)\n  oxc_validation_mode: OxcValidationMode;\n  // ui-only (used for OXC validators)\n  oxc_code_shape: OxcCodeShape;\n  // ui ergonomics (serialized to int in payload)\n  batch_size: string;\n};\n\nexport type MarkdownNoteConfig = {\n  id: string;\n  kind: \"markdown_note\";\n  name: string;\n  markdown: string;\n  // ui-only\n  note_color?: string;\n  // ui-only (0-100 as string for slider/input ergonomics)\n  note_opacity?: string;\n};\n\nexport type SeedConfig = {\n  id: string;\n  kind: \"seed\";\n  // ui-only\n  advancedOpen?: boolean;\n  name: string;\n  drop?: boolean;\n  // ui-only: explicit per-column drop for structured seed sources (hf/local)\n  seed_drop_columns?: string[];\n  seed_source_type: SeedSourceType;\n  // ui-only (serialized in seed_config)\n  hf_repo_id: string;\n  hf_subset?: string;\n  hf_split?: string;\n  hf_path: string;\n  hf_token?: string;\n  hf_endpoint?: string;\n  local_file_name?: string;\n  unstructured_file_name?: string;\n  // ui-only\n  seed_preview_rows?: Record<string, unknown>[];\n  // ui-only (string for input ergonomics)\n  unstructured_chunk_size?: string;\n  // ui-only (string for input ergonomics)\n  unstructured_chunk_overlap?: string;\n  seed_splits?: string[];\n  // ui-only\n  // biome-ignore lint/style/useNamingConvention: ui schema\n  seed_globs_by_split?: Record<string, string>;\n  seed_columns?: string[];\n  sampling_strategy: SeedSamplingStrategy;\n  
selection_type: SeedSelectionType;\n  selection_start?: string;\n  selection_end?: string;\n  selection_index?: string;\n  selection_num_partitions?: string;\n};\n\nexport type SchemaTransformProcessorConfig = {\n  id: string;\n  // biome-ignore lint/style/useNamingConvention: api schema\n  processor_type: \"schema_transform\";\n  name: string;\n  template: string;\n};\n\nexport type RecipeProcessorConfig = SchemaTransformProcessorConfig;\n\nexport type NodeConfig =\n  | SamplerConfig\n  | LlmConfig\n  | ValidatorConfig\n  | ExpressionConfig\n  | MarkdownNoteConfig\n  | SeedConfig\n  | ModelProviderConfig\n  | ModelConfig\n  | ToolProfileConfig;\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/config-factories.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ExpressionConfig,\n  LlmConfig,\n  LlmType,\n  MarkdownNoteConfig,\n  ModelConfig,\n  ModelProviderConfig,\n  NodeConfig,\n  SeedConfig,\n  SeedSourceType,\n  SamplerConfig,\n  SamplerType,\n  ToolProfileConfig,\n  ValidatorCodeLang,\n  ValidatorType,\n  ValidatorConfig,\n} from \"../types\";\nimport { nextName } from \"./naming\";\n\nexport function makeSamplerConfig(\n  id: string,\n  samplerType: SamplerType,\n  existing: NodeConfig[],\n): SamplerConfig {\n  const namePrefix =\n    samplerType === \"subcategory\" ? \"subcategory\" : samplerType;\n  const name = nextName(existing, namePrefix);\n  if (samplerType === \"category\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"category\",\n      name,\n      drop: false,\n      values: [],\n      weights: [],\n    };\n  }\n  if (samplerType === \"subcategory\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"subcategory\",\n      name,\n      drop: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_mapping: {},\n    };\n  }\n  if (samplerType === \"uniform\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"uniform\",\n      name,\n      drop: false,\n      low: \"0\",\n      high: \"1\",\n    };\n  }\n  if (samplerType === \"gaussian\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"gaussian\",\n      name,\n      drop: false,\n      mean: \"0\",\n      std: \"1\",\n    };\n  }\n  if (samplerType === \"bernoulli\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"bernoulli\",\n      name,\n      drop: false,\n      p: \"0.5\",\n    };\n  }\n  if (samplerType === \"datetime\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"datetime\",\n      name,\n      drop: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_start: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_end: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_unit: \"day\",\n    };\n  }\n  if (samplerType === \"timedelta\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"timedelta\",\n      name,\n      drop: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_min: \"0\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_max: \"1\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      timedelta_unit: \"D\",\n    };\n  }\n  if (samplerType === \"uuid\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore 
lint/style/useNamingConvention: api schema\n      sampler_type: \"uuid\",\n      name,\n      drop: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      uuid_format: \"\",\n    };\n  }\n  if (samplerType === \"person\" || samplerType === \"person_from_faker\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"person_from_faker\",\n      name,\n      drop: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      person_locale: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      person_sex: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      person_age_range: \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      person_city: \"\",\n    };\n  }\n  return {\n    id,\n    kind: \"sampler\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    sampler_type: \"person_from_faker\",\n    name,\n    drop: false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_locale: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_sex: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_age_range: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_city: \"\",\n  };\n}\n\nexport function makeLlmConfig(\n  id: string,\n  llmType: LlmType,\n  existing: NodeConfig[],\n): LlmConfig {\n  let namePrefix = \"llm_text\";\n  if (llmType === \"structured\") {\n    namePrefix = \"llm_structured\";\n  } else if (llmType === \"code\") {\n    namePrefix = \"llm_code\";\n  } else if (llmType === \"judge\") {\n    namePrefix = \"llm_judge\";\n  }\n  const name = nextName(existing, namePrefix);\n  return {\n    id,\n    kind: \"llm\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    llm_type: llmType,\n    name,\n    drop: false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_alias: \"\",\n    prompt:\n      llmType === \"judge\"\n        ? \"Evaluate the content using the scoring criteria below.\"\n        : \"Write a response.\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    system_prompt: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    code_lang: llmType === \"code\" ? \"python\" : undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    output_format:\n      llmType === \"structured\" ? '{\\n  \"field\": \"string\"\\n}' : undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_alias: \"\",\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    image_context: {\n      enabled: false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_name: \"\",\n    },\n    // biome-ignore lint/style/useNamingConvention: api schema\n    with_trace: \"none\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extract_reasoning_content: false,\n    scores: llmType === \"judge\" ? 
[] : undefined,\n  };\n}\n\nexport function makeModelProviderConfig(\n  id: string,\n  existing: NodeConfig[],\n): ModelProviderConfig {\n  return {\n    id,\n    kind: \"model_provider\",\n    name: nextName(existing, \"provider\"),\n    endpoint: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    provider_type: \"openai\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key_env: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_headers: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_body: \"\",\n  };\n}\n\nexport function makeModelConfig(\n  id: string,\n  existing: NodeConfig[],\n): ModelConfig {\n  return {\n    id,\n    kind: \"model_config\",\n    name: nextName(existing, \"model\"),\n    model: \"\",\n    provider: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_temperature: \"0.7\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_max_tokens: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_top_p: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_timeout: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_extra_body: \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    skip_health_check: false,\n  };\n}\n\nexport function makeToolProfileConfig(\n  id: string,\n  existing: NodeConfig[],\n): ToolProfileConfig {\n  return {\n    id,\n    kind: \"tool_config\",\n    name: nextName(existing, \"tools\"),\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    mcp_providers: [],\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    fetched_tools_by_provider: {},\n    // biome-ignore lint/style/useNamingConvention: api schema\n    allow_tools: [],\n    // biome-ignore lint/style/useNamingConvention: api schema\n    max_tool_call_turns: \"5\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    timeout_sec: \"\",\n  };\n}\n\nexport function makeExpressionConfig(\n  id: string,\n  existing: NodeConfig[],\n): ExpressionConfig {\n  return {\n    id,\n    kind: \"expression\",\n    name: nextName(existing, \"expr\"),\n    drop: false,\n    expr: \"\",\n    dtype: \"str\",\n  };\n}\n\nexport function makeValidatorConfig(\n  id: string,\n  validatorType: ValidatorType,\n  codeLang: ValidatorCodeLang,\n  existing: NodeConfig[],\n): ValidatorConfig {\n  const isSql = validatorType === \"code\" && codeLang.startsWith(\"sql:\");\n  const isOxc = validatorType === \"oxc\";\n  let namePrefix = \"validator_python\";\n  if (isSql) {\n    namePrefix = \"validator_sql\";\n  } else if (isOxc) {\n    namePrefix = \"validator_oxc\";\n  }\n  return {\n    id,\n    kind: \"validator\",\n    name: nextName(existing, namePrefix),\n    drop: false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    target_columns: [],\n    validator_type: validatorType,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    code_lang: codeLang,\n    oxc_validation_mode: \"syntax\",\n    oxc_code_shape: \"auto\",\n    batch_size: \"10\",\n  };\n}\n\nexport function makeMarkdownNoteConfig(\n  id: string,\n  existing: NodeConfig[],\n): MarkdownNoteConfig {\n  return {\n    id,\n    kind: \"markdown_note\",\n    name: nextName(existing, \"note\"),\n    markdown: \"## 
Note\\n\\nAdd markdown here.\",\n    note_color: \"#FDE68A\",\n    note_opacity: \"35\",\n  };\n}\n\nexport function makeSeedConfig(\n  id: string,\n  existing: NodeConfig[],\n  seedSourceType: SeedSourceType = \"hf\",\n): SeedConfig {\n  return {\n    id,\n    kind: \"seed\",\n    name: nextName(existing, \"seed\"),\n    drop: false,\n    seed_drop_columns: [],\n    seed_source_type: seedSourceType,\n    hf_repo_id: \"\",\n    hf_subset: \"\",\n    hf_split: \"\",\n    hf_path: \"\",\n    hf_token: \"\",\n    hf_endpoint: \"https://huggingface.co\",\n    local_file_name: \"\",\n    unstructured_file_name: \"\",\n    seed_preview_rows: [],\n    unstructured_chunk_size: \"1200\",\n    unstructured_chunk_overlap: \"200\",\n    seed_splits: [],\n    seed_globs_by_split: {},\n    seed_columns: [],\n    sampling_strategy: \"ordered\",\n    selection_type: \"none\",\n    selection_start: \"0\",\n    selection_end: \"10\",\n    selection_index: \"0\",\n    selection_num_partitions: \"1\",\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/config-labels.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ExpressionDtype,\n  LlmType,\n  SamplerType,\n} from \"../types\";\n\nconst SAMPLER_LABELS: Record<SamplerType, string> = {\n  category: \"Category\",\n  subcategory: \"Subcategory\",\n  uniform: \"Random number\",\n  gaussian: \"Bell-curve number\",\n  bernoulli: \"Yes/no value\",\n  datetime: \"Date and time\",\n  timedelta: \"Time offset\",\n  uuid: \"Unique ID\",\n  person: \"Synthetic person\",\n  person_from_faker: \"Synthetic person\",\n};\n\nconst LLM_LABELS: Record<LlmType, string> = {\n  text: \"AI text\",\n  structured: \"AI structured data\",\n  code: \"AI code\",\n  judge: \"AI scorer\",\n};\n\nconst EXPRESSION_LABELS: Record<ExpressionDtype, string> = {\n  str: \"Text\",\n  int: \"Int\",\n  float: \"Float\",\n  bool: \"Bool\",\n};\n\nexport function labelForSampler(type: SamplerType): string {\n  return SAMPLER_LABELS[type] ?? \"Generated field\";\n}\n\nexport function labelForLlm(type: LlmType): string {\n  return LLM_LABELS[type] ?? \"AI\";\n}\n\nexport function labelForExpression(type: ExpressionDtype): string {\n  return EXPRESSION_LABELS[type] ?? \"Formula\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/config-type-guards.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ExpressionConfig,\n  LlmConfig,\n  NodeConfig,\n  SamplerConfig,\n  ValidatorConfig,\n} from \"../types\";\n\nexport function isSamplerConfig(\n  config: NodeConfig | null | undefined,\n): config is SamplerConfig {\n  return Boolean(config && config.kind === \"sampler\");\n}\n\nexport function isCategoryConfig(\n  config: NodeConfig | null | undefined,\n): config is SamplerConfig {\n  return Boolean(\n    config && config.kind === \"sampler\" && config.sampler_type === \"category\",\n  );\n}\n\nexport function isSubcategoryConfig(\n  config: NodeConfig | null | undefined,\n): config is SamplerConfig {\n  return Boolean(\n    config &&\n      config.kind === \"sampler\" &&\n      config.sampler_type === \"subcategory\",\n  );\n}\n\nexport function isLlmConfig(\n  config: NodeConfig | null | undefined,\n): config is LlmConfig {\n  return Boolean(config && config.kind === \"llm\");\n}\n\nexport function isExpressionConfig(\n  config: NodeConfig | null | undefined,\n): config is ExpressionConfig {\n  return Boolean(config && config.kind === \"expression\");\n}\n\nexport function isValidatorConfig(\n  config: NodeConfig | null | undefined,\n): config is ValidatorConfig {\n  return Boolean(config && config.kind === \"validator\");\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph/derive-display-graph.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge, Node, XYPosition } from \"@xyflow/react\";\nimport type { RecipeGraphAuxNodeData } from \"../../components/recipe-graph-aux-node\";\nimport { DEFAULT_NODE_HEIGHT, DEFAULT_NODE_WIDTH } from \"../../constants\";\nimport type { RecipeNode, LayoutDirection, NodeConfig } from \"../../types\";\nimport {\n  getDefaultDataSourceHandle,\n  getDefaultDataTargetHandle,\n  getDefaultSemanticSourceHandle,\n  getDefaultSemanticTargetHandle,\n  HANDLE_IDS,\n  isDataSourceHandle,\n  isDataTargetHandle,\n  isSemanticSourceHandle,\n  isSemanticTargetHandle,\n  normalizeRecipeHandleId,\n} from \"../handles\";\nimport { readNodeHeight, readNodeWidth } from \"../rf-node-dimensions\";\nimport { isSemanticRelation } from \"./relations\";\n\ntype DisplayGraphInput = {\n  nodes: RecipeNode[];\n  edges: Edge[];\n  configs: Record<string, NodeConfig>;\n  layoutDirection: LayoutDirection;\n  auxNodePositions: Record<string, XYPosition>;\n  llmAuxVisibility: Record<string, boolean>;\n  runtime?: {\n    runningNodeId: string | null;\n    doneNodeIds: Set<string>;\n    activeEdgeIds: Set<string>;\n    executionLocked: boolean;\n  };\n};\n\nexport type DisplayGraph = {\n  nodes: Array<Node<RecipeNode[\"data\"] | RecipeGraphAuxNodeData>>;\n  edges: Edge[];\n};\n\nfunction isAuxEdge(edge: Edge): boolean {\n  return edge.source.startsWith(\"aux-\") || edge.target.startsWith(\"aux-\");\n}\n\nfunction normalizeEdge(\n  edge: Edge,\n  configs: Record<string, NodeConfig>,\n  layoutDirection: LayoutDirection,\n  activeEdgeIds: Set<string>,\n  runningNodeId: string | null,\n  doneNodeIds: Set<string>,\n): Edge {\n  const isActiveByRuntimeTarget =\n    Boolean(runningNodeId) &&\n    edge.target === runningNodeId &&\n    !isAuxEdge(edge);\n  const isActiveEdge = activeEdgeIds.has(edge.id) || isActiveByRuntimeTarget;\n  const isAux = isAuxEdge(edge);\n  if (isAux) {\n    return {\n      ...edge,\n      type: \"canvas\",\n      data: { ...(edge.data ?? {}), path: \"smoothstep\", active: isActiveEdge },\n      animated: isActiveEdge,\n    };\n  }\n\n  const isActiveReversedRuntimeEdge =\n    Boolean(runningNodeId) &&\n    isActiveEdge &&\n    edge.source === runningNodeId &&\n    doneNodeIds.has(edge.target);\n  const displayEdge = isActiveReversedRuntimeEdge\n    ? {\n        ...edge,\n        source: edge.target,\n        target: edge.source,\n        sourceHandle: getDefaultDataSourceHandle(layoutDirection),\n        targetHandle: getDefaultDataTargetHandle(layoutDirection),\n      }\n    : edge;\n\n  const source = configs[displayEdge.source];\n  const target = configs[displayEdge.target];\n  const semantic =\n    displayEdge.type === \"semantic\" ||\n    (Boolean(source && target) && isSemanticRelation(source, target));\n  const sourceHandleNormalized = normalizeRecipeHandleId(displayEdge.sourceHandle);\n  const targetHandleNormalized = normalizeRecipeHandleId(displayEdge.targetHandle);\n  const semanticSourceDefault =\n    source?.kind === \"llm\"\n      ? getDefaultDataSourceHandle(layoutDirection)\n      : getDefaultSemanticSourceHandle(layoutDirection);\n  const semanticTargetDefault =\n    target?.kind === \"llm\"\n      ? 
getDefaultDataTargetHandle(layoutDirection)\n      : getDefaultSemanticTargetHandle(layoutDirection);\n  let sourceHandle = getDefaultDataSourceHandle(layoutDirection);\n  let targetHandle = getDefaultDataTargetHandle(layoutDirection);\n\n  if (semantic) {\n    sourceHandle =\n      isSemanticSourceHandle(sourceHandleNormalized) ||\n      isDataSourceHandle(sourceHandleNormalized)\n        ? sourceHandleNormalized ?? semanticSourceDefault\n        : semanticSourceDefault;\n    targetHandle =\n      isSemanticTargetHandle(targetHandleNormalized) ||\n      isDataTargetHandle(targetHandleNormalized)\n        ? targetHandleNormalized ?? semanticTargetDefault\n        : semanticTargetDefault;\n    // LLM nodes only expose data lane handles; coerce legacy semantic handles.\n    if (source?.kind === \"llm\" && isSemanticSourceHandle(sourceHandle)) {\n      sourceHandle = semanticSourceDefault;\n    }\n    if (target?.kind === \"llm\" && isSemanticTargetHandle(targetHandle)) {\n      targetHandle = semanticTargetDefault;\n    }\n  } else {\n    sourceHandle = isDataSourceHandle(sourceHandleNormalized)\n      ? sourceHandleNormalized ?? getDefaultDataSourceHandle(layoutDirection)\n      : getDefaultDataSourceHandle(layoutDirection);\n    targetHandle = isDataTargetHandle(targetHandleNormalized)\n      ? targetHandleNormalized ?? getDefaultDataTargetHandle(layoutDirection)\n      : getDefaultDataTargetHandle(layoutDirection);\n  }\n\n  return {\n    ...displayEdge,\n    type: semantic ? \"semantic\" : \"canvas\",\n    data: semantic\n      ? { ...(displayEdge.data ?? {}), active: isActiveEdge }\n      : { ...(displayEdge.data ?? {}), path: \"smoothstep\", active: isActiveEdge },\n    sourceHandle,\n    targetHandle,\n    animated: isActiveEdge,\n  };\n}\n\ntype AuxNodeItem = {\n  key: string;\n  data: RecipeGraphAuxNodeData;\n};\n\ntype Rect = {\n  x: number;\n  y: number;\n  width: number;\n  height: number;\n};\n\nfunction toRect(\n  position: XYPosition,\n  width: number,\n  height: number,\n): Rect {\n  return {\n    x: position.x,\n    y: position.y,\n    width,\n    height,\n  };\n}\n\nfunction intersects(a: Rect, b: Rect, pad = 18): boolean {\n  return !(\n    a.x + a.width + pad <= b.x ||\n    b.x + b.width + pad <= a.x ||\n    a.y + a.height + pad <= b.y ||\n    b.y + b.height + pad <= a.y\n  );\n}\n\nfunction findNonOverlappingPosition(\n  preferred: XYPosition,\n  width: number,\n  height: number,\n  occupied: Rect[],\n): XYPosition {\n  const step = 24;\n  for (let ring = 0; ring <= 10; ring += 1) {\n    for (let dx = -ring; dx <= ring; dx += 1) {\n      for (let dy = -ring; dy <= ring; dy += 1) {\n        if (ring > 0 && Math.max(Math.abs(dx), Math.abs(dy)) !== ring) {\n          continue;\n        }\n        const candidate = {\n          x: preferred.x + dx * step,\n          y: preferred.y + dy * step,\n        };\n        const rect = toRect(candidate, width, height);\n        if (!occupied.some((other) => intersects(rect, other))) {\n          return candidate;\n        }\n      }\n    }\n  }\n  return preferred;\n}\n\ntype HandleSide = \"left\" | \"right\" | \"top\" | \"bottom\";\n\nconst SIDE_TO_TARGET_HANDLE: Record<HandleSide, string> = {\n  left: HANDLE_IDS.dataIn,\n  right: HANDLE_IDS.dataInRight,\n  top: HANDLE_IDS.dataInTop,\n  bottom: HANDLE_IDS.dataInBottom,\n};\n\nfunction getTargetSide(\n  handleId: string | null | undefined,\n  direction: LayoutDirection,\n): HandleSide {\n  const normalized = normalizeRecipeHandleId(handleId);\n  if (!normalized) {\n    return 
direction === \"TB\" ? \"top\" : \"left\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataInRight ||\n    normalized === HANDLE_IDS.semanticInRight\n  ) {\n    return \"right\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataInBottom ||\n    normalized === HANDLE_IDS.semanticInBottom\n  ) {\n    return \"bottom\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataInTop ||\n    normalized === HANDLE_IDS.semanticInTop\n  ) {\n    return \"top\";\n  }\n  return \"left\";\n}\n\nfunction getSourceSide(\n  handleId: string | null | undefined,\n  direction: LayoutDirection,\n): HandleSide {\n  const normalized = normalizeRecipeHandleId(handleId);\n  if (!normalized) {\n    return direction === \"TB\" ? \"bottom\" : \"right\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataOutLeft ||\n    normalized === HANDLE_IDS.semanticOutLeft\n  ) {\n    return \"left\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataOutTop ||\n    normalized === HANDLE_IDS.semanticOutTop\n  ) {\n    return \"top\";\n  }\n  if (\n    normalized === HANDLE_IDS.dataOutBottom ||\n    normalized === HANDLE_IDS.semanticOutBottom\n  ) {\n    return \"bottom\";\n  }\n  return \"right\";\n}\n\nfunction pickAuxTargetHandle(\n  llmId: string,\n  direction: LayoutDirection,\n  edges: Edge[],\n): string {\n  const occupied = new Set<HandleSide>();\n  for (const edge of edges) {\n    if (isAuxEdge(edge)) {\n      continue;\n    }\n    if (edge.target === llmId) {\n      occupied.add(getTargetSide(edge.targetHandle, direction));\n    }\n    if (edge.source === llmId) {\n      occupied.add(getSourceSide(edge.sourceHandle, direction));\n    }\n  }\n\n  const priority: HandleSide[] =\n    direction === \"LR\"\n      ? [\"left\", \"right\", \"bottom\", \"top\"]\n      : [\"top\", \"bottom\", \"right\", \"left\"];\n  for (const side of priority) {\n    if (!occupied.has(side)) {\n      return SIDE_TO_TARGET_HANDLE[side];\n    }\n  }\n\n  const fallback: HandleSide = direction === \"LR\" ? \"bottom\" : \"right\";\n  return SIDE_TO_TARGET_HANDLE[fallback];\n}\n\nfunction getHandleSideFromTargetHandle(targetHandle: string): HandleSide {\n  if (targetHandle === HANDLE_IDS.dataInRight) {\n    return \"right\";\n  }\n  if (targetHandle === HANDLE_IDS.dataInTop) {\n    return \"top\";\n  }\n  if (targetHandle === HANDLE_IDS.dataInBottom) {\n    return \"bottom\";\n  }\n  return \"left\";\n}\n\nfunction pickAuxSourceHandle(\n  auxPosition: XYPosition,\n  auxWidth: number,\n  auxHeight: number,\n  llmPosition: XYPosition,\n  llmWidth: number,\n  llmHeight: number,\n): string {\n  const auxCenter = {\n    x: auxPosition.x + auxWidth / 2,\n    y: auxPosition.y + auxHeight / 2,\n  };\n  const llmCenter = {\n    x: llmPosition.x + llmWidth / 2,\n    y: llmPosition.y + llmHeight / 2,\n  };\n  const dx = llmCenter.x - auxCenter.x;\n  const dy = llmCenter.y - auxCenter.y;\n\n  if (Math.abs(dx) >= Math.abs(dy)) {\n    return dx >= 0 ? HANDLE_IDS.llmInputOutRight : HANDLE_IDS.llmInputOutLeft;\n  }\n  return dy >= 0 ? 
HANDLE_IDS.llmInputOutBottom : HANDLE_IDS.llmInputOutTop;\n}\n\ntype AppendAuxNodeAndEdgeInput = {\n  auxNodes: Node<RecipeGraphAuxNodeData>[];\n  auxEdges: Edge[];\n  entry: {\n    item: AuxNodeItem;\n    auxId: string;\n    width: number;\n    height: number;\n  };\n  position: XYPosition;\n  parentNode: Node<RecipeNode[\"data\"] | RecipeGraphAuxNodeData>;\n  parentWidth: number;\n  parentHeight: number;\n  auxTargetHandle: string;\n};\n\nfunction appendAuxNodeAndEdge({\n  auxNodes,\n  auxEdges,\n  entry,\n  position,\n  parentNode,\n  parentWidth,\n  parentHeight,\n  auxTargetHandle,\n}: AppendAuxNodeAndEdgeInput): void {\n  auxNodes.push({\n    id: entry.auxId,\n    type: \"aux\",\n    data: entry.item.data,\n    position,\n    width: entry.width,\n    height: entry.height,\n    style: {\n      width: entry.width,\n      height: entry.height,\n    },\n    draggable: true,\n    selectable: true,\n    focusable: true,\n    connectable: false,\n  });\n\n  auxEdges.push({\n    id: `e-${entry.auxId}-${parentNode.id}`,\n    source: entry.auxId,\n    sourceHandle: pickAuxSourceHandle(\n      position,\n      entry.width,\n      entry.height,\n      parentNode.position,\n      parentWidth,\n      parentHeight,\n    ),\n    target: parentNode.id,\n    targetHandle: auxTargetHandle,\n    type: \"canvas\",\n    data: { path: \"auto\" },\n    selectable: false,\n    focusable: false,\n  });\n}\n\nexport function deriveDisplayGraph({\n  nodes,\n  edges,\n  configs,\n  layoutDirection,\n  auxNodePositions,\n  llmAuxVisibility,\n  runtime,\n}: DisplayGraphInput): DisplayGraph {\n  const executionLocked = runtime?.executionLocked ?? false;\n  const runningNodeId = runtime?.runningNodeId ?? null;\n  const doneNodeIds = runtime?.doneNodeIds ?? new Set<string>();\n  const activeEdgeIds = runtime?.activeEdgeIds ?? new Set<string>();\n  const displayNodes = nodes.map((node) => {\n    const hasWidth =\n      typeof node.width === \"number\" ||\n      typeof node.style?.width === \"number\" ||\n      (typeof node.style?.width === \"string\" &&\n        Number.isFinite(Number.parseFloat(node.style.width)));\n    const runtimeState: \"idle\" | \"running\" | \"done\" =\n      node.id === runningNodeId\n        ? \"running\"\n        : doneNodeIds.has(node.id)\n          ? \"done\"\n          : \"idle\";\n    if (hasWidth) {\n      return {\n        ...node,\n        data: {\n          ...node.data,\n          runtimeState,\n          executionLocked,\n        },\n      };\n    }\n    return {\n      ...node,\n      data: {\n        ...node.data,\n        runtimeState,\n        executionLocked,\n      },\n      style: { ...node.style, width: DEFAULT_NODE_WIDTH },\n    };\n  });\n  const auxNodes: Node<RecipeGraphAuxNodeData>[] = [];\n  const auxEdges: Edge[] = [];\n  const occupiedRects: Rect[] = displayNodes.map((node) =>\n    toRect(\n      node.position,\n      readNodeWidth(node) ?? DEFAULT_NODE_WIDTH,\n      readNodeHeight(node) ?? DEFAULT_NODE_HEIGHT,\n    ),\n  );\n\n  for (const node of displayNodes) {\n    const config = configs[node.id];\n    if (!(config && config.kind === \"llm\")) {\n      continue;\n    }\n    if (!llmAuxVisibility[config.id]) {\n      continue;\n    }\n    const llmDirection = node.data.layoutDirection ?? 
layoutDirection;\n    const auxTargetHandle = pickAuxTargetHandle(node.id, llmDirection, edges);\n    const auxTargetSide = getHandleSideFromTargetHandle(auxTargetHandle);\n    const items: AuxNodeItem[] = [];\n\n    if (config.system_prompt.trim()) {\n      items.push({\n        key: \"system\",\n        data: {\n          kind: \"llm-prompt-input\",\n          llmId: config.id,\n          field: \"system_prompt\",\n          title: \"System Prompt\",\n          executionLocked,\n        },\n      });\n    }\n\n    if (config.prompt.trim()) {\n      items.push({\n        key: \"prompt\",\n        data: {\n          kind: \"llm-prompt-input\",\n          llmId: config.id,\n          field: \"prompt\",\n          title: \"Prompt\",\n          executionLocked,\n        },\n      });\n    }\n\n    if (config.llm_type === \"judge\") {\n      (config.scores ?? []).forEach((_score, scoreIndex) => {\n        items.push({\n          key: `score-${scoreIndex}`,\n          data: {\n            kind: \"llm-judge-score\",\n            llmId: config.id,\n            scoreIndex,\n            executionLocked,\n          },\n        });\n      });\n    }\n\n    if (items.length === 0) {\n      continue;\n    }\n\n    const parentWidth = readNodeWidth(node) ?? DEFAULT_NODE_WIDTH;\n    const parentHeight = readNodeHeight(node) ?? DEFAULT_NODE_HEIGHT;\n    const itemsWithLayout = items.map((item) => {\n      const auxId = `aux-${node.id}-${item.key}`;\n      return {\n        item,\n        auxId,\n        width: DEFAULT_NODE_WIDTH,\n        height: DEFAULT_NODE_HEIGHT,\n      };\n    });\n\n    const gap = 24;\n    const sideOffset = 48;\n    const stackHorizontal =\n      auxTargetSide === \"top\" || auxTargetSide === \"bottom\";\n\n    if (stackHorizontal) {\n      const totalWidth =\n        itemsWithLayout.reduce((sum, entry) => sum + entry.width, 0) +\n        (itemsWithLayout.length - 1) * gap;\n      const startX = node.position.x + (parentWidth - totalWidth) / 2;\n      let xCursor = startX;\n\n      for (const entry of itemsWithLayout) {\n        const preferredPosition = {\n          x: xCursor,\n          y:\n            auxTargetSide === \"top\"\n              ? node.position.y - entry.height - sideOffset\n              : node.position.y + parentHeight + sideOffset,\n        };\n        const defaultPosition = findNonOverlappingPosition(\n          preferredPosition,\n          entry.width,\n          entry.height,\n          occupiedRects,\n        );\n        const position = auxNodePositions[entry.auxId] ?? defaultPosition;\n        xCursor += entry.width + gap;\n\n        occupiedRects.push(toRect(position, entry.width, entry.height));\n        appendAuxNodeAndEdge({\n          auxNodes,\n          auxEdges,\n          entry,\n          position,\n          parentNode: node,\n          parentWidth,\n          parentHeight,\n          auxTargetHandle,\n        });\n      }\n      continue;\n    }\n\n    const totalHeight =\n      itemsWithLayout.reduce((sum, entry) => sum + entry.height, 0) +\n      (itemsWithLayout.length - 1) * gap;\n    const maxWidth = Math.max(...itemsWithLayout.map((entry) => entry.width));\n    const baseX =\n      auxTargetSide === \"right\"\n        ? 
node.position.x + parentWidth + sideOffset\n        : node.position.x - maxWidth - sideOffset;\n    let yCursor = node.position.y + (parentHeight - totalHeight) / 2;\n\n    for (const entry of itemsWithLayout) {\n      const preferredPosition = {\n        x: baseX + (maxWidth - entry.width),\n        y: yCursor,\n      };\n      const defaultPosition = findNonOverlappingPosition(\n        preferredPosition,\n        entry.width,\n        entry.height,\n        occupiedRects,\n      );\n      const position = auxNodePositions[entry.auxId] ?? defaultPosition;\n      yCursor += entry.height + gap;\n\n      occupiedRects.push(toRect(position, entry.width, entry.height));\n      appendAuxNodeAndEdge({\n        auxNodes,\n        auxEdges,\n        entry,\n        position,\n        parentNode: node,\n        parentWidth,\n        parentHeight,\n        auxTargetHandle,\n      });\n    }\n  }\n\n  return {\n    nodes: [...displayNodes, ...auxNodes],\n    edges: [...edges, ...auxEdges].map((edge) =>\n      normalizeEdge(\n        edge,\n        configs,\n        layoutDirection,\n        activeEdgeIds,\n        runningNodeId,\n        doneNodeIds,\n      ),\n    ),\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph/fit-view.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Node } from \"@xyflow/react\";\n\nfunction isMarkdownNoteNode(node: Node): boolean {\n  if (node.type !== \"builder\") {\n    return false;\n  }\n  if (!node.data || typeof node.data !== \"object\") {\n    return false;\n  }\n  return (node.data as { kind?: string }).kind === \"note\";\n}\n\nexport function getFitNodeIdsIgnoringNotes(nodes: Node[]): Array<{ id: string }> {\n  const nodesWithoutNotes = nodes.filter((node) => !isMarkdownNoteNode(node));\n  const targetNodes = nodesWithoutNotes.length > 0 ? nodesWithoutNotes : nodes;\n  return targetNodes.map((node) => ({ id: node.id }));\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph/recipe-graph-connection.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type Connection, type Edge, addEdge } from \"@xyflow/react\";\nimport type { LayoutDirection, NodeConfig, SamplerConfig } from \"../../types\";\nimport {\n  HANDLE_IDS,\n  isDataSourceHandle,\n  isDataTargetHandle,\n  isSemanticSourceHandle,\n  isSemanticTargetHandle,\n  normalizeRecipeHandleId,\n} from \"../handles\";\nimport { isSemanticRelation } from \"./relations\";\nimport {\n  isCategoryConfig,\n  isExpressionConfig,\n  isSubcategoryConfig,\n} from \"../index\";\nimport {\n  VALIDATOR_OXC_CODE_LANGS,\n  VALIDATOR_SQL_CODE_LANGS,\n} from \"../validators/code-lang\";\n\nfunction buildTemplateWithRef(template: string, ref: string): string {\n  if (template.includes(ref)) {\n    return template;\n  }\n  if (template.trim()) {\n    return `${template}\\n${ref}`;\n  }\n  return ref;\n}\n\nfunction syncSubcategoryMapping(\n  subcategory: SamplerConfig,\n  parent: NodeConfig,\n): SamplerConfig {\n  if (!isCategoryConfig(parent)) {\n    return {\n      ...subcategory,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: parent.name,\n    };\n  }\n  const nextMapping: Record<string, string[]> = {\n    ...(subcategory.subcategory_mapping ?? {}),\n  };\n  for (const value of parent.values ?? []) {\n    if (!nextMapping[value]) {\n      nextMapping[value] = [];\n    }\n  }\n  return {\n    ...subcategory,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    subcategory_parent: parent.name,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    subcategory_mapping: nextMapping,\n  };\n}\n\nfunction isModelInfraNode(config: NodeConfig): boolean {\n  return (\n    config.kind === \"model_provider\" ||\n    config.kind === \"model_config\" ||\n    config.kind === \"tool_config\"\n  );\n}\n\nfunction isSemanticLane(connection: Connection): boolean {\n  return (\n    (isSemanticSourceHandle(connection.sourceHandle) ||\n      isDataSourceHandle(connection.sourceHandle)) &&\n    (isSemanticTargetHandle(connection.targetHandle) ||\n      isDataTargetHandle(connection.targetHandle))\n  );\n}\n\nfunction isDataLane(connection: Connection): boolean {\n  return (\n    isDataSourceHandle(connection.sourceHandle) &&\n    isDataTargetHandle(connection.targetHandle)\n  );\n}\n\ntype SingleRefRelation =\n  | \"provider\"\n  | \"model_alias\"\n  | \"tool_alias\"\n  | \"reference_column_name\"\n  | \"subcategory_parent\"\n  | \"validator_target_columns\";\n\nfunction getSingleRefRelation(\n  source: NodeConfig,\n  target: NodeConfig,\n): SingleRefRelation | null {\n  if (source.kind === \"model_provider\" && target.kind === \"model_config\") {\n    return \"provider\";\n  }\n  if (source.kind === \"model_config\" && target.kind === \"llm\") {\n    return \"model_alias\";\n  }\n  if (source.kind === \"tool_config\" && target.kind === \"llm\") {\n    return \"tool_alias\";\n  }\n  if (\n    source.kind === \"sampler\" &&\n    source.sampler_type === \"datetime\" &&\n    target.kind === \"sampler\" &&\n    target.sampler_type === \"timedelta\"\n  ) {\n    return \"reference_column_name\";\n  }\n  if (isCategoryConfig(source) && isSubcategoryConfig(target)) {\n    return \"subcategory_parent\";\n  }\n  if (\n    source.kind === \"llm\" &&\n    source.llm_type === \"code\" &&\n    target.kind === \"validator\"\n  ) {\n    return \"validator_target_columns\";\n  }\n  return 
null;\n}\n\nfunction isCompetingIncomingEdge(\n  edge: Edge,\n  targetId: string,\n  relation: SingleRefRelation,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  if (edge.target !== targetId) {\n    return false;\n  }\n  const source = configs[edge.source];\n  if (!source) {\n    return false;\n  }\n  if (relation === \"provider\") {\n    return source.kind === \"model_provider\";\n  }\n  if (relation === \"model_alias\") {\n    return source.kind === \"model_config\";\n  }\n  if (relation === \"tool_alias\") {\n    return source.kind === \"tool_config\";\n  }\n  if (relation === \"subcategory_parent\") {\n    return isCategoryConfig(source);\n  }\n  if (relation === \"validator_target_columns\") {\n    return source.kind === \"llm\" && source.llm_type === \"code\";\n  }\n  return source.kind === \"sampler\" && source.sampler_type === \"datetime\";\n}\n\nfunction isModelSemanticRelation(source: NodeConfig, target: NodeConfig): boolean {\n  return (\n    (source.kind === \"model_provider\" && target.kind === \"model_config\") ||\n    (source.kind === \"model_config\" && target.kind === \"llm\") ||\n    (source.kind === \"tool_config\" && target.kind === \"llm\")\n  );\n}\n\nfunction canApplyCodeLangToValidator(\n  validator: Extract<NodeConfig, { kind: \"validator\" }>,\n  codeLang: string,\n): boolean {\n  const normalized = codeLang.trim();\n  if (!normalized) {\n    return false;\n  }\n  if (validator.validator_type === \"oxc\") {\n    return VALIDATOR_OXC_CODE_LANGS.includes(\n      normalized as typeof validator.code_lang,\n    );\n  }\n  if (normalized === \"python\") {\n    return true;\n  }\n  return VALIDATOR_SQL_CODE_LANGS.includes(normalized as typeof validator.code_lang);\n}\n\nfunction countHandleUsage(\n  edges: Edge[],\n  nodeId: string,\n  handleId: string,\n  lane: \"source\" | \"target\",\n): number {\n  return edges.reduce((count, edge) => {\n    const edgeNodeId = lane === \"source\" ? edge.source : edge.target;\n    if (edgeNodeId !== nodeId) {\n      return count;\n    }\n    const edgeHandleId =\n      lane === \"source\"\n        ? normalizeRecipeHandleId(edge.sourceHandle)\n        : normalizeRecipeHandleId(edge.targetHandle);\n    return edgeHandleId === handleId ? count + 1 : count;\n  }, 0);\n}\n\nfunction pickLeastUsedHandle(\n  candidates: string[],\n  requested: string | null,\n  usageFor: (handleId: string) => number,\n): string {\n  let bestHandle = candidates[0];\n  let bestCount = Number.POSITIVE_INFINITY;\n  const requestedNormalized = requested\n    ? normalizeRecipeHandleId(requested)\n    : null;\n\n  for (const candidate of candidates) {\n    const usage = usageFor(candidate);\n    if (usage < bestCount) {\n      bestHandle = candidate;\n      bestCount = usage;\n      continue;\n    }\n    if (usage === bestCount && requestedNormalized === candidate) {\n      bestHandle = candidate;\n    }\n  }\n\n  return bestHandle;\n}\n\nfunction chooseModelSemanticHandles(\n  connection: Connection,\n  source: NodeConfig,\n  target: NodeConfig,\n  edges: Edge[],\n  layoutDirection: LayoutDirection,\n): Connection {\n  if (!isModelSemanticRelation(source, target)) {\n    return connection;\n  }\n\n  const sourceCandidates =\n    source.kind === \"model_config\" && target.kind === \"llm\"\n      ? layoutDirection === \"TB\"\n        ? [HANDLE_IDS.semanticOut]\n        : [HANDLE_IDS.semanticOutBottom]\n      : layoutDirection === \"TB\"\n        ? 
[HANDLE_IDS.semanticOut, HANDLE_IDS.semanticOutBottom]\n        : [HANDLE_IDS.semanticOutBottom, HANDLE_IDS.semanticOut];\n  const targetCandidates =\n    target.kind === \"model_config\"\n      ? layoutDirection === \"TB\"\n        ? [HANDLE_IDS.semanticIn, HANDLE_IDS.semanticInTop]\n        : [HANDLE_IDS.semanticInTop, HANDLE_IDS.semanticIn]\n      : [\n          HANDLE_IDS.dataInTop,\n          HANDLE_IDS.dataInBottom,\n          HANDLE_IDS.dataIn,\n          HANDLE_IDS.dataInRight,\n        ];\n\n  const sourceHandle = pickLeastUsedHandle(\n    sourceCandidates,\n    connection.sourceHandle ?? null,\n    (handleId) => countHandleUsage(edges, source.id, handleId, \"source\"),\n  );\n  const targetHandle = pickLeastUsedHandle(\n    targetCandidates,\n    connection.targetHandle ?? null,\n    (handleId) => countHandleUsage(edges, target.id, handleId, \"target\"),\n  );\n\n  return {\n    ...connection,\n    sourceHandle,\n    targetHandle,\n  };\n}\n\nfunction normalizeValidatorSemanticConnection(\n  connection: Connection,\n  source: NodeConfig,\n  target: NodeConfig,\n): Connection {\n  if (\n    source.kind === \"validator\" &&\n    target.kind === \"llm\" &&\n    target.llm_type === \"code\"\n  ) {\n    return {\n      ...connection,\n      source: target.id,\n      target: source.id,\n      sourceHandle: HANDLE_IDS.dataOut,\n      targetHandle: HANDLE_IDS.dataIn,\n    };\n  }\n  return connection;\n}\n\nexport function isValidRecipeConnection(\n  connection: Connection,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  if (!(connection.source && connection.target)) {\n    return false;\n  }\n  if (connection.source === connection.target) {\n    return false;\n  }\n  const source = configs[connection.source];\n  const target = configs[connection.target];\n  if (!(source && target)) {\n    return false;\n  }\n  const semanticRelation = isSemanticRelation(source, target);\n  if (semanticRelation) {\n    return isSemanticLane(connection);\n  }\n  if (isModelInfraNode(source) || isModelInfraNode(target)) {\n    return false;\n  }\n  return isDataLane(connection);\n}\n\nexport function applyRecipeConnection(\n  connection: Connection,\n  configs: Record<string, NodeConfig>,\n  edges: Edge[],\n  layoutDirection: LayoutDirection = \"LR\",\n): { edges: Edge[]; configs?: Record<string, NodeConfig> } {\n  if (!isValidRecipeConnection(connection, configs)) {\n    return { edges };\n  }\n  const initialSource = connection.source\n    ? configs[connection.source]\n    : null;\n  const initialTarget = connection.target\n    ? configs[connection.target]\n    : null;\n  if (!(initialSource && initialTarget)) {\n    return { edges };\n  }\n  const normalizedConnection = normalizeValidatorSemanticConnection(\n    connection,\n    initialSource,\n    initialTarget,\n  );\n  const source = normalizedConnection.source\n    ? configs[normalizedConnection.source]\n    : null;\n  const target = normalizedConnection.target\n    ? configs[normalizedConnection.target]\n    : null;\n  if (!(source && target)) {\n    return { edges };\n  }\n\n  const semanticRelation = isSemanticRelation(source, target);\n  const singleRefRelation = getSingleRefRelation(source, target);\n  if (\n    singleRefRelation === \"subcategory_parent\" &&\n    isSubcategoryConfig(target)\n  ) {\n    const currentParent = target.subcategory_parent?.trim() ?? \"\";\n    if (currentParent && currentParent !== source.name) {\n      return { edges };\n    }\n  }\n  const nextBaseEdges = singleRefRelation\n    ? 
edges.filter(\n        (edge) =>\n          !isCompetingIncomingEdge(edge, target.id, singleRefRelation, configs),\n      )\n    : edges;\n  const resolvedConnection = chooseModelSemanticHandles(\n    normalizedConnection,\n    source,\n    target,\n    nextBaseEdges,\n    layoutDirection,\n  );\n  const nextEdges = addEdge(\n    { ...resolvedConnection, type: semanticRelation ? \"semantic\" : \"canvas\" },\n    nextBaseEdges,\n  );\n  if (source.kind === \"model_provider\" && target.kind === \"model_config\") {\n    const next = { ...target, provider: source.name };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (source.kind === \"model_config\" && target.kind === \"llm\") {\n    const next = { ...target, model_alias: source.name };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (source.kind === \"tool_config\" && target.kind === \"llm\") {\n    const next = { ...target, tool_alias: source.name };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (\n    source.kind === \"sampler\" &&\n    source.sampler_type === \"datetime\" &&\n    target.kind === \"sampler\" &&\n    target.sampler_type === \"timedelta\"\n  ) {\n    const next = {\n      ...target,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: source.name,\n    };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (\n    source.kind === \"llm\" &&\n    source.llm_type === \"code\" &&\n    target.kind === \"validator\"\n  ) {\n    const nextCodeLang = (source.code_lang ?? \"\").trim();\n    const canUseCodeLangForTarget = canApplyCodeLangToValidator(\n      target,\n      nextCodeLang,\n    );\n    const next = {\n      ...target,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      target_columns: [source.name],\n      // biome-ignore lint/style/useNamingConvention: api schema\n      code_lang:\n        (\n          canUseCodeLangForTarget ? nextCodeLang : target.code_lang\n        ) as typeof target.code_lang,\n    };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (\n    isExpressionConfig(target) &&\n    !semanticRelation &&\n    source.kind !== \"seed\" &&\n    source.kind !== \"model_provider\" &&\n    source.kind !== \"model_config\" &&\n    source.kind !== \"validator\"\n  ) {\n    const ref = `{{ ${source.name} }}`;\n    const next = {\n      ...target,\n      expr: buildTemplateWithRef(target.expr ?? \"\", ref),\n    };\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  if (isSubcategoryConfig(target) && isCategoryConfig(source)) {\n    const next = syncSubcategoryMapping(target, source);\n    return { edges: nextEdges, configs: { ...configs, [target.id]: next } };\n  }\n  return { edges: nextEdges };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph/relations.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig } from \"../../types\";\n\nexport function isSemanticRelation(\n  source: NodeConfig,\n  target: NodeConfig,\n): boolean {\n  if (source.kind === \"model_provider\" && target.kind === \"model_config\") {\n    return true;\n  }\n  if (source.kind === \"model_config\" && target.kind === \"llm\") {\n    return true;\n  }\n  if (source.kind === \"tool_config\" && target.kind === \"llm\") {\n    return true;\n  }\n  if (\n    source.kind === \"llm\" &&\n    source.llm_type === \"code\" &&\n    target.kind === \"validator\"\n  ) {\n    return true;\n  }\n  return (\n    source.kind === \"validator\" &&\n    target.kind === \"llm\" &&\n    target.llm_type === \"code\"\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph/runtime-visual-state.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge } from \"@xyflow/react\";\nimport type {\n  RecipeExecutionBatch,\n  RecipeExecutionRecord,\n  RecipeExecutionStatus,\n} from \"../../execution-types\";\nimport type { NodeConfig } from \"../../types\";\nimport { extractRefs } from \"../refs\";\n\nconst ACTIVE_STATUSES: ReadonlySet<RecipeExecutionStatus> = new Set([\n  \"pending\",\n  \"running\",\n  \"active\",\n  \"cancelling\",\n]);\nconst FRESH_PENDING_WINDOW_MS = 60_000;\n\nconst DONE_UPSTREAM_KINDS: ReadonlySet<NodeConfig[\"kind\"]> = new Set([\n  \"sampler\",\n  \"seed\",\n  \"expression\",\n  \"llm\",\n  \"model_config\",\n  \"model_provider\",\n  \"tool_config\",\n]);\n\nexport type GraphRuntimeVisualState = {\n  executionLocked: boolean;\n  runningNodeId: string | null;\n  doneNodeIds: Set<string>;\n  activeEdgeIds: Set<string>;\n  batch: RecipeExecutionBatch | null;\n};\n\nfunction isAuxEdge(edge: Edge): boolean {\n  return edge.source.startsWith(\"aux-\") || edge.target.startsWith(\"aux-\");\n}\n\nfunction collectTemplateRefs(config: NodeConfig | null): Set<string> {\n  if (!config) {\n    return new Set();\n  }\n  const refs = new Set<string>();\n  if (config.kind === \"llm\") {\n    for (const ref of extractRefs(config.prompt ?? \"\")) {\n      refs.add(ref.trim());\n    }\n    for (const ref of extractRefs(config.system_prompt ?? \"\")) {\n      refs.add(ref.trim());\n    }\n    if (typeof config.output_format === \"string\") {\n      for (const ref of extractRefs(config.output_format)) {\n        refs.add(ref.trim());\n      }\n    }\n    return refs;\n  }\n  if (config.kind === \"expression\") {\n    for (const ref of extractRefs(config.expr ?? \"\")) {\n      refs.add(ref.trim());\n    }\n  }\n  return refs;\n}\n\nfunction isReversedRuntimeReferenceEdge(input: {\n  edge: Edge;\n  runningNodeId: string;\n  runningTemplateRefs: Set<string>;\n  configs: Record<string, NodeConfig>;\n}): boolean {\n  const { edge, runningNodeId, runningTemplateRefs, configs } = input;\n  if (edge.source !== runningNodeId) {\n    return false;\n  }\n  const targetName = configs[edge.target]?.name?.trim() ?? \"\";\n  return Boolean(targetName && runningTemplateRefs.has(targetName));\n}\n\nfunction hasLiveExecutionSignal(execution: RecipeExecutionRecord): boolean {\n  if (execution.lastEventId !== null) {\n    return true;\n  }\n  if (execution.current_column !== null) {\n    return true;\n  }\n  if (execution.progress !== null || execution.column_progress !== null) {\n    return true;\n  }\n  return Boolean(execution.batch?.idx ?? 
execution.batch?.total);\n}\n\nexport function pickLatestActiveExecution(\n  executions: RecipeExecutionRecord[],\n): RecipeExecutionRecord | null {\n  const now = Date.now();\n  for (const execution of executions) {\n    if (!ACTIVE_STATUSES.has(execution.status)) {\n      continue;\n    }\n    if (!execution.jobId) {\n      continue;\n    }\n    if (execution.finishedAt !== null) {\n      continue;\n    }\n\n    const liveSignal = hasLiveExecutionSignal(execution);\n    if (!liveSignal && execution.status === \"pending\") {\n      const ageMs = Math.max(0, now - execution.createdAt);\n      if (ageMs > FRESH_PENDING_WINDOW_MS) {\n        continue;\n      }\n    }\n    if (!liveSignal && execution.status !== \"pending\") {\n      continue;\n    }\n\n    return execution;\n  }\n  return null;\n}\n\nexport function deriveGraphRuntimeVisualState(input: {\n  activeExecution: RecipeExecutionRecord | null;\n  configs: Record<string, NodeConfig>;\n  edges: Edge[];\n}): GraphRuntimeVisualState {\n  const { activeExecution, configs, edges } = input;\n  if (!activeExecution) {\n    return {\n      executionLocked: false,\n      runningNodeId: null,\n      doneNodeIds: new Set(),\n      activeEdgeIds: new Set(),\n      batch: null,\n    };\n  }\n\n  const nameToNodeId = new Map<string, string>();\n  for (const config of Object.values(configs)) {\n    const name = config.name.trim();\n    if (!name) {\n      continue;\n    }\n    nameToNodeId.set(name, config.id);\n  }\n\n  const doneNodeIds = new Set<string>();\n  for (const columnName of activeExecution.completed_columns) {\n    const nodeId = nameToNodeId.get(columnName.trim());\n    if (nodeId) {\n      doneNodeIds.add(nodeId);\n    }\n  }\n\n  const runningNodeId = activeExecution.current_column\n    ? nameToNodeId.get(activeExecution.current_column.trim()) ?? null\n    : null;\n  if (runningNodeId) {\n    doneNodeIds.delete(runningNodeId);\n  }\n\n  const activeEdgeIds = new Set<string>();\n  if (runningNodeId) {\n    const runningConfig = configs[runningNodeId] ?? null;\n    const runningTemplateRefs = collectTemplateRefs(runningConfig);\n    for (const ref of runningTemplateRefs) {\n      const refNodeId = nameToNodeId.get(ref);\n      if (refNodeId && refNodeId !== runningNodeId) {\n        doneNodeIds.add(refNodeId);\n      }\n    }\n    for (const upstreamNodeId of collectUpstreamDoneNodeIds({\n      rootNodeId: runningNodeId,\n      edges,\n      configs,\n    })) {\n      doneNodeIds.add(upstreamNodeId);\n    }\n    for (const edge of edges) {\n      if (isAuxEdge(edge)) {\n        continue;\n      }\n      if (edge.target === runningNodeId) {\n        activeEdgeIds.add(edge.id);\n        continue;\n      }\n      if (\n        isReversedRuntimeReferenceEdge({\n          edge,\n          runningNodeId,\n          runningTemplateRefs,\n          configs,\n        })\n      ) {\n        activeEdgeIds.add(edge.id);\n      }\n    }\n  }\n\n  const batch =\n    activeExecution.batch &&\n    typeof activeExecution.batch.total === \"number\" &&\n    activeExecution.batch.total > 1\n      ? 
activeExecution.batch\n      : null;\n\n  return {\n    executionLocked: true,\n    runningNodeId,\n    doneNodeIds,\n    activeEdgeIds,\n    batch,\n  };\n}\n\nfunction collectUpstreamDoneNodeIds(input: {\n  rootNodeId: string;\n  edges: Edge[];\n  configs: Record<string, NodeConfig>;\n}): Set<string> {\n  const { rootNodeId, edges, configs } = input;\n  const incoming = new Map<string, string[]>();\n  for (const edge of edges) {\n    if (isAuxEdge(edge)) {\n      continue;\n    }\n    const list = incoming.get(edge.target) ?? [];\n    list.push(edge.source);\n    incoming.set(edge.target, list);\n  }\n\n  const visited = new Set<string>();\n  const queue = [rootNodeId];\n  let queueIndex = 0;\n  const doneNodeIds = new Set<string>();\n  while (queueIndex < queue.length) {\n    const current = queue[queueIndex];\n    queueIndex += 1;\n    if (!current || visited.has(current)) {\n      continue;\n    }\n    visited.add(current);\n    const sources = incoming.get(current) ?? [];\n    for (const sourceId of sources) {\n      if (!visited.has(sourceId)) {\n        queue.push(sourceId);\n      }\n      const config = configs[sourceId];\n      if (config && DONE_UPSTREAM_KINDS.has(config.kind)) {\n        doneNodeIds.add(sourceId);\n      }\n    }\n  }\n\n  return doneNodeIds;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph-warnings.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge } from \"@xyflow/react\";\nimport { INFRA_NODE_KINDS, type NodeConfig } from \"../types\";\n\nexport type GraphWarning = {\n  nodeId?: string;\n  nodeName?: string;\n  global?: boolean;\n  message: string;\n  severity: \"error\" | \"warning\";\n};\n\nfunction checkDataSourceRequired(allConfigs: NodeConfig[]): GraphWarning[] {\n  const hasLlm = allConfigs.some((c) => c.kind === \"llm\");\n  const hasDataSource = allConfigs.some(\n    (c) => c.kind === \"seed\" || c.kind === \"sampler\" || c.kind === \"expression\",\n  );\n  if (hasLlm && !hasDataSource) {\n    return [\n      {\n        global: true,\n        message:\n          \"Add a data source (seed, sampler, or expression) before LLM blocks can generate data.\",\n        severity: \"warning\",\n      },\n    ];\n  }\n  return [];\n}\n\nfunction checkLlmModelAlias(allConfigs: NodeConfig[]): GraphWarning[] {\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (config.kind === \"llm\" && !config.model_alias?.trim()) {\n      warnings.push({\n        nodeId: config.id,\n        nodeName: config.name,\n        message: \"Needs a model preset.\",\n        severity: \"error\",\n      });\n    }\n  }\n  return warnings;\n}\n\nfunction checkModelConfigProvider(allConfigs: NodeConfig[]): GraphWarning[] {\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (config.kind === \"model_config\" && !config.provider?.trim()) {\n      warnings.push({\n        nodeId: config.id,\n        nodeName: config.name,\n        message: \"Needs a provider connection.\",\n        severity: \"error\",\n      });\n    }\n  }\n  return warnings;\n}\n\nfunction checkSubcategoryParent(allConfigs: NodeConfig[]): GraphWarning[] {\n  const categoryNames = new Set(\n    allConfigs\n      .filter((c) => c.kind === \"sampler\" && c.sampler_type === \"category\")\n      .map((c) => c.name),\n  );\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (config.kind !== \"sampler\" || config.sampler_type !== \"subcategory\") {\n      continue;\n    }\n    if (!config.subcategory_parent?.trim()) {\n      warnings.push({\n        nodeId: config.id,\n        nodeName: config.name,\n        message: \"Needs a parent category block.\",\n        severity: \"error\",\n      });\n    } else if (!categoryNames.has(config.subcategory_parent)) {\n      warnings.push({\n        nodeId: config.id,\n        nodeName: config.name,\n        message: `Parent category \"${config.subcategory_parent}\" not found.`,\n        severity: \"error\",\n      });\n    }\n  }\n  return warnings;\n}\n\nfunction checkValidatorTargets(allConfigs: NodeConfig[]): GraphWarning[] {\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (\n      config.kind === \"validator\" &&\n      (!config.target_columns || config.target_columns.length === 0)\n    ) {\n      warnings.push({\n        nodeId: config.id,\n        nodeName: config.name,\n        message: \"Needs at least one target column.\",\n        severity: \"warning\",\n      });\n    }\n  }\n  return warnings;\n}\n\nfunction checkDisconnectedNodes(\n  allConfigs: NodeConfig[],\n  edges: Edge[],\n): GraphWarning[] {\n  const connectedIds = new Set<string>();\n  for (const edge of edges) {\n    connectedIds.add(edge.source);\n    
connectedIds.add(edge.target);\n  }\n\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (config.kind === \"markdown_note\") {\n      continue;\n    }\n    if (connectedIds.has(config.id)) {\n      continue;\n    }\n\n    warnings.push({\n      nodeId: config.id,\n      nodeName: config.name,\n      message: \"This block has no connections.\",\n      severity: \"warning\",\n    });\n  }\n  return warnings;\n}\n\nfunction checkLlmMissingDataInput(\n  allConfigs: NodeConfig[],\n  edges: Edge[],\n): GraphWarning[] {\n  const configById = new Map(allConfigs.map((c) => [c.id, c]));\n\n  /** LLM IDs that have at least one non-infra pipeline edge. */\n  const llmWithPipelineEdge = new Set<string>();\n  for (const edge of edges) {\n    const sourceConfig = configById.get(edge.source);\n    const targetConfig = configById.get(edge.target);\n\n    if (\n      sourceConfig?.kind === \"llm\" &&\n      targetConfig &&\n      !INFRA_NODE_KINDS.has(targetConfig.kind)\n    ) {\n      llmWithPipelineEdge.add(sourceConfig.id);\n    }\n    if (\n      targetConfig?.kind === \"llm\" &&\n      sourceConfig &&\n      !INFRA_NODE_KINDS.has(sourceConfig.kind)\n    ) {\n      llmWithPipelineEdge.add(targetConfig.id);\n    }\n  }\n\n  const warnings: GraphWarning[] = [];\n  for (const config of allConfigs) {\n    if (config.kind !== \"llm\") {\n      continue;\n    }\n    if (llmWithPipelineEdge.has(config.id)) {\n      continue;\n    }\n\n    const hasAnyEdge = edges.some(\n      (e) => e.source === config.id || e.target === config.id,\n    );\n    if (!hasAnyEdge) {\n      continue; // already caught by checkDisconnectedNodes\n    }\n\n    warnings.push({\n      nodeId: config.id,\n      nodeName: config.name,\n      message: \"No data-pipeline connection — connect it to a source or downstream step.\",\n      severity: \"warning\",\n    });\n  }\n  return warnings;\n}\n\nexport function getGraphWarnings(\n  configs: Record<string, NodeConfig>,\n  edges: Edge[] = [],\n): GraphWarning[] {\n  const allConfigs = Object.values(configs);\n  return [\n    ...checkDataSourceRequired(allConfigs),\n    ...checkLlmModelAlias(allConfigs),\n    ...checkModelConfigProvider(allConfigs),\n    ...checkSubcategoryParent(allConfigs),\n    ...checkValidatorTargets(allConfigs),\n    ...checkDisconnectedNodes(allConfigs, edges),\n    ...checkLlmMissingDataInput(allConfigs, edges),\n  ];\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/graph.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport {\n  applyRecipeConnection,\n  isValidRecipeConnection,\n} from \"./graph/recipe-graph-connection\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/handle-layout.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const NODE_HANDLE_CLASS =\n  \"pointer-events-auto !size-2.5 !border-border/80 !bg-muted shadow-sm hover:!border-primary/70 hover:!bg-primary/20\";\n\nexport const AUX_HANDLE_CLASS =\n  \"!size-2 !border-border/80 !bg-muted/80 shadow-sm\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/handles.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Connection } from \"@xyflow/react\";\nimport type { LayoutDirection } from \"../types\";\n\nexport const HANDLE_IDS = {\n  // data flow lanes\n  dataIn: \"data-in\",\n  dataInTop: \"data-in-top\",\n  dataInRight: \"data-in-right\",\n  dataInBottom: \"data-in-bottom\",\n  dataOut: \"data-out\",\n  dataOutLeft: \"data-out-left\",\n  dataOutTop: \"data-out-top\",\n  dataOutBottom: \"data-out-bottom\",\n  // semantic dependency lanes\n  semanticIn: \"semantic-in\",\n  semanticInTop: \"semantic-in-top\",\n  semanticInRight: \"semantic-in-right\",\n  semanticInBottom: \"semantic-in-bottom\",\n  semanticInLeft: \"semantic-in-left\",\n  semanticOut: \"semantic-out\",\n  semanticOutLeft: \"semantic-out-left\",\n  semanticOutTop: \"semantic-out-top\",\n  semanticOutBottom: \"semantic-out-bottom\",\n  semanticOutRight: \"semantic-out-right\",\n  // llm prompt/scorer lanes\n  llmInputOutLeft: \"llm-input-out-left\",\n  llmInputOutRight: \"llm-input-out-right\",\n  llmInputOutTop: \"llm-input-out-top\",\n  llmInputOutBottom: \"llm-input-out-bottom\",\n} as const;\n\nexport type RecipeHandleId = (typeof HANDLE_IDS)[keyof typeof HANDLE_IDS];\n\nconst LEGACY_HANDLE_ALIAS_MAP: Record<string, string> = {\n  [HANDLE_IDS.semanticInLeft]: HANDLE_IDS.semanticIn,\n  [HANDLE_IDS.semanticOutRight]: HANDLE_IDS.semanticOut,\n};\n\nconst DATA_TARGET_HANDLES = new Set<string>([\n  HANDLE_IDS.dataIn,\n  HANDLE_IDS.dataInTop,\n  HANDLE_IDS.dataInRight,\n  HANDLE_IDS.dataInBottom,\n]);\n\nconst DATA_SOURCE_HANDLES = new Set<string>([\n  HANDLE_IDS.dataOut,\n  HANDLE_IDS.dataOutLeft,\n  HANDLE_IDS.dataOutTop,\n  HANDLE_IDS.dataOutBottom,\n]);\n\nconst SEMANTIC_TARGET_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticIn,\n  HANDLE_IDS.semanticInTop,\n  HANDLE_IDS.semanticInRight,\n  HANDLE_IDS.semanticInBottom,\n  HANDLE_IDS.semanticInLeft,\n]);\n\nconst SEMANTIC_SOURCE_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticOut,\n  HANDLE_IDS.semanticOutLeft,\n  HANDLE_IDS.semanticOutTop,\n  HANDLE_IDS.semanticOutBottom,\n  HANDLE_IDS.semanticOutRight,\n]);\n\nconst DATA_TARGET_HORIZONTAL_HANDLES = new Set<string>([\n  HANDLE_IDS.dataIn,\n  HANDLE_IDS.dataInRight,\n]);\n\nconst DATA_TARGET_VERTICAL_HANDLES = new Set<string>([\n  HANDLE_IDS.dataInTop,\n  HANDLE_IDS.dataInBottom,\n]);\n\nconst DATA_SOURCE_HORIZONTAL_HANDLES = new Set<string>([\n  HANDLE_IDS.dataOut,\n  HANDLE_IDS.dataOutLeft,\n]);\n\nconst DATA_SOURCE_VERTICAL_HANDLES = new Set<string>([\n  HANDLE_IDS.dataOutTop,\n  HANDLE_IDS.dataOutBottom,\n]);\n\nconst SEMANTIC_TARGET_HORIZONTAL_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticIn,\n  HANDLE_IDS.semanticInRight,\n  HANDLE_IDS.semanticInLeft,\n]);\n\nconst SEMANTIC_TARGET_VERTICAL_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticInTop,\n  HANDLE_IDS.semanticInBottom,\n]);\n\nconst SEMANTIC_SOURCE_HORIZONTAL_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticOut,\n  HANDLE_IDS.semanticOutLeft,\n  HANDLE_IDS.semanticOutRight,\n]);\n\nconst SEMANTIC_SOURCE_VERTICAL_HANDLES = new Set<string>([\n  HANDLE_IDS.semanticOutTop,\n  HANDLE_IDS.semanticOutBottom,\n]);\n\nexport function normalizeRecipeHandleId(\n  handleId: string | null | undefined,\n): string | null {\n  if (!handleId) {\n    return null;\n  }\n  return LEGACY_HANDLE_ALIAS_MAP[handleId] ?? 
handleId;\n}\n\nexport function normalizeRecipeConnectionHandles(\n  connection: Connection,\n): Connection {\n  return {\n    ...connection,\n    sourceHandle: normalizeRecipeHandleId(connection.sourceHandle),\n    targetHandle: normalizeRecipeHandleId(connection.targetHandle),\n  };\n}\n\nfunction isKnownHandle(\n  handleId: string | null | undefined,\n  handles: Set<string>,\n): boolean {\n  if (!handleId) {\n    return false;\n  }\n  return handles.has(normalizeRecipeHandleId(handleId) ?? \"\");\n}\n\nfunction remapHandleForDirection(\n  handleId: string | null | undefined,\n  direction: LayoutDirection,\n  horizontalHandles: Set<string>,\n  verticalHandles: Set<string>,\n  defaultHandle: string,\n): string {\n  const normalizedHandleId = normalizeRecipeHandleId(handleId);\n  if (!normalizedHandleId) {\n    return defaultHandle;\n  }\n  if (direction === \"LR\") {\n    if (verticalHandles.has(normalizedHandleId)) {\n      return defaultHandle;\n    }\n    return normalizedHandleId;\n  }\n  if (horizontalHandles.has(normalizedHandleId)) {\n    return defaultHandle;\n  }\n  return normalizedHandleId;\n}\n\nexport function isDataTargetHandle(\n  handleId: string | null | undefined,\n): boolean {\n  return isKnownHandle(handleId, DATA_TARGET_HANDLES);\n}\n\nexport function isDataSourceHandle(\n  handleId: string | null | undefined,\n): boolean {\n  return isKnownHandle(handleId, DATA_SOURCE_HANDLES);\n}\n\nexport function isSemanticTargetHandle(\n  handleId: string | null | undefined,\n): boolean {\n  return isKnownHandle(handleId, SEMANTIC_TARGET_HANDLES);\n}\n\nexport function isSemanticSourceHandle(\n  handleId: string | null | undefined,\n): boolean {\n  return isKnownHandle(handleId, SEMANTIC_SOURCE_HANDLES);\n}\n\nexport function getDefaultDataTargetHandle(direction: LayoutDirection): string {\n  return direction === \"TB\" ? HANDLE_IDS.dataInTop : HANDLE_IDS.dataIn;\n}\n\nexport function getDefaultDataSourceHandle(direction: LayoutDirection): string {\n  return direction === \"TB\" ? HANDLE_IDS.dataOutBottom : HANDLE_IDS.dataOut;\n}\n\nexport function getDefaultSemanticTargetHandle(\n  direction: LayoutDirection,\n): string {\n  return direction === \"TB\" ? HANDLE_IDS.semanticInTop : HANDLE_IDS.semanticIn;\n}\n\nexport function getDefaultSemanticSourceHandle(\n  direction: LayoutDirection,\n): string {\n  return direction === \"TB\" ? HANDLE_IDS.semanticOutBottom : HANDLE_IDS.semanticOut;\n}\n\ntype RecipeEdgeHandles = {\n  sourceHandle?: string | null;\n  targetHandle?: string | null;\n  type?: string | null;\n};\n\nexport function remapRecipeEdgeHandlesForLayout(\n  edge: RecipeEdgeHandles,\n  direction: LayoutDirection,\n): { sourceHandle: string; targetHandle: string } {\n  const semantic =\n    edge.type === \"semantic\" ||\n    (isSemanticSourceHandle(edge.sourceHandle) &&\n      isSemanticTargetHandle(edge.targetHandle));\n  if (semantic) {\n    const sourceIsData = isDataSourceHandle(edge.sourceHandle);\n    const targetIsData = isDataTargetHandle(edge.targetHandle);\n    return {\n      sourceHandle: remapHandleForDirection(\n        edge.sourceHandle,\n        direction,\n        sourceIsData\n          ? DATA_SOURCE_HORIZONTAL_HANDLES\n          : SEMANTIC_SOURCE_HORIZONTAL_HANDLES,\n        sourceIsData\n          ? DATA_SOURCE_VERTICAL_HANDLES\n          : SEMANTIC_SOURCE_VERTICAL_HANDLES,\n        sourceIsData\n          ? 
getDefaultDataSourceHandle(direction)\n          : getDefaultSemanticSourceHandle(direction),\n      ),\n      targetHandle: remapHandleForDirection(\n        edge.targetHandle,\n        direction,\n        targetIsData\n          ? DATA_TARGET_HORIZONTAL_HANDLES\n          : SEMANTIC_TARGET_HORIZONTAL_HANDLES,\n        targetIsData\n          ? DATA_TARGET_VERTICAL_HANDLES\n          : SEMANTIC_TARGET_VERTICAL_HANDLES,\n        targetIsData\n          ? getDefaultDataTargetHandle(direction)\n          : getDefaultSemanticTargetHandle(direction),\n      ),\n    };\n  }\n  return {\n    sourceHandle: remapHandleForDirection(\n      edge.sourceHandle,\n      direction,\n      DATA_SOURCE_HORIZONTAL_HANDLES,\n      DATA_SOURCE_VERTICAL_HANDLES,\n      getDefaultDataSourceHandle(direction),\n    ),\n    targetHandle: remapHandleForDirection(\n      edge.targetHandle,\n      direction,\n      DATA_TARGET_HORIZONTAL_HANDLES,\n      DATA_TARGET_VERTICAL_HANDLES,\n      getDefaultDataTargetHandle(direction),\n    ),\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/image-preview.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const MAX_IMAGE_PREVIEW_BYTES = 200 * 1024;\n\ntype PreviewImagePayload = {\n  type?: unknown;\n  mime?: unknown;\n  data?: unknown;\n};\n\ntype UnknownRecord = Record<string, unknown>;\n\nexport type ImagePreviewResult =\n  | { kind: \"ready\"; src: string }\n  | { kind: \"too_large\"; estimatedBytes: number };\n\nfunction normalizeBase64(value: string): string {\n  return value.replace(/\\s+/g, \"\");\n}\n\nfunction estimateBase64Bytes(base64: string): number {\n  const normalized = normalizeBase64(base64);\n  const padding = normalized.endsWith(\"==\")\n    ? 2\n    : normalized.endsWith(\"=\")\n      ? 1\n      : 0;\n  return Math.max(0, Math.floor((normalized.length * 3) / 4) - padding);\n}\n\nfunction inferMimeFromBase64(base64: string): string | null {\n  const normalized = normalizeBase64(base64);\n  if (normalized.startsWith(\"iVBORw0KGgo\")) {\n    return \"image/png\";\n  }\n  if (normalized.startsWith(\"/9j/\")) {\n    return \"image/jpeg\";\n  }\n  if (normalized.startsWith(\"R0lGOD\")) {\n    return \"image/gif\";\n  }\n  if (normalized.startsWith(\"UklGR\")) {\n    return \"image/webp\";\n  }\n  return null;\n}\n\nfunction isLikelyRawBase64Image(value: string): boolean {\n  const normalized = normalizeBase64(value);\n  if (normalized.length < 64) {\n    return false;\n  }\n  if (!/^[A-Za-z0-9+/=]+$/.test(normalized)) {\n    return false;\n  }\n  return inferMimeFromBase64(normalized) !== null;\n}\n\nfunction toDataUrlFromBase64(base64: string, mime: string): string {\n  return `data:${mime};base64,${normalizeBase64(base64)}`;\n}\n\nfunction isRecord(value: unknown): value is UnknownRecord {\n  return Boolean(value) && typeof value === \"object\" && !Array.isArray(value);\n}\n\nfunction isByteArray(value: unknown): value is number[] {\n  if (!Array.isArray(value) || value.length === 0) {\n    return false;\n  }\n  return value.every(\n    (item) => typeof item === \"number\" && Number.isInteger(item) && item >= 0 && item <= 255,\n  );\n}\n\nfunction byteArrayToBase64(bytes: number[]): string {\n  let binary = \"\";\n  const chunkSize = 0x8000;\n  for (let idx = 0; idx < bytes.length; idx += chunkSize) {\n    const chunk = bytes.slice(idx, idx + chunkSize);\n    binary += String.fromCharCode(...chunk);\n  }\n  return btoa(binary);\n}\n\nfunction resolveStringCandidate(\n  value: unknown,\n  maxBytes: number,\n): ImagePreviewResult | null {\n  if (typeof value !== \"string\") {\n    return null;\n  }\n  return resolveImagePreviewFromString(value, maxBytes);\n}\n\nfunction resolveImagePreviewFromString(\n  value: string,\n  maxBytes: number,\n): ImagePreviewResult | null {\n  const trimmed = value.trim();\n  if (!trimmed) {\n    return null;\n  }\n  if (trimmed.startsWith(\"http://\") || trimmed.startsWith(\"https://\")) {\n    return { kind: \"ready\", src: trimmed };\n  }\n  if (trimmed.startsWith(\"data:image/\")) {\n    const marker = \"base64,\";\n    const markerIdx = trimmed.indexOf(marker);\n    if (markerIdx < 0) {\n      return { kind: \"ready\", src: trimmed };\n    }\n    const encoded = trimmed.slice(markerIdx + marker.length);\n    const estimatedBytes = estimateBase64Bytes(encoded);\n    if (estimatedBytes > maxBytes) {\n      return { kind: \"too_large\", estimatedBytes };\n    }\n    return { kind: \"ready\", src: trimmed };\n  }\n  if (isLikelyRawBase64Image(trimmed)) {\n    const 
estimatedBytes = estimateBase64Bytes(trimmed);\n    if (estimatedBytes > maxBytes) {\n      return { kind: \"too_large\", estimatedBytes };\n    }\n    const mime = inferMimeFromBase64(trimmed) ?? \"image/png\";\n    return { kind: \"ready\", src: toDataUrlFromBase64(trimmed, mime) };\n  }\n  return null;\n}\n\nfunction resolveImagePayloadObject(value: unknown, maxBytes: number): ImagePreviewResult | null {\n  if (!isRecord(value)) {\n    return null;\n  }\n  const payload = value as PreviewImagePayload;\n  if (payload.type === \"image\" && typeof payload.data === \"string\") {\n    const mime = typeof payload.mime === \"string\" ? payload.mime : \"image/jpeg\";\n    const estimatedBytes = estimateBase64Bytes(payload.data);\n    if (estimatedBytes > maxBytes) {\n      return { kind: \"too_large\", estimatedBytes };\n    }\n    return {\n      kind: \"ready\",\n      src: toDataUrlFromBase64(payload.data, mime),\n    };\n  }\n\n  const imageUrl = value.image_url;\n  const directImageUrl = resolveStringCandidate(imageUrl, maxBytes);\n  if (directImageUrl !== null) {\n    return directImageUrl;\n  }\n  if (isRecord(imageUrl)) {\n    const nestedImageUrl = resolveStringCandidate(imageUrl.url, maxBytes);\n    if (nestedImageUrl !== null) {\n      return nestedImageUrl;\n    }\n  }\n\n  const scalarCandidates = [\n    value.url,\n    value.data,\n    value.bytes,\n    value.base64,\n    value.base64_image,\n    value.image,\n    value.path,\n  ];\n  for (const candidate of scalarCandidates) {\n    const resolved = resolveStringCandidate(candidate, maxBytes);\n    if (resolved !== null) {\n      return resolved;\n    }\n  }\n\n  if (isByteArray(value.bytes)) {\n    const resolved = resolveStringCandidate(byteArrayToBase64(value.bytes), maxBytes);\n    if (resolved !== null) {\n      return resolved;\n    }\n  }\n\n  if (isRecord(value.image)) {\n    return resolveImagePayloadObject(value.image, maxBytes);\n  }\n\n  return null;\n}\n\nexport function resolveImagePreview(\n  value: unknown,\n  maxBytes = MAX_IMAGE_PREVIEW_BYTES,\n): ImagePreviewResult | null {\n  const payloadPreview = resolveImagePayloadObject(value, maxBytes);\n  if (payloadPreview) {\n    return payloadPreview;\n  }\n\n  if (typeof value !== \"string\") {\n    return null;\n  }\n  return resolveImagePreviewFromString(value, maxBytes);\n}\n\nexport function isLikelyImageValue(value: unknown): boolean {\n  return resolveImagePreview(value, Number.POSITIVE_INFINITY) !== null;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/edges.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge } from \"@xyflow/react\";\nimport type { LayoutDirection, NodeConfig } from \"../../types\";\nimport {\n  getDefaultDataSourceHandle,\n  getDefaultDataTargetHandle,\n  getDefaultSemanticSourceHandle,\n  getDefaultSemanticTargetHandle,\n  isDataSourceHandle,\n  isDataTargetHandle,\n  isSemanticSourceHandle,\n  isSemanticTargetHandle,\n  normalizeRecipeHandleId,\n} from \"../handles\";\nimport { extractRefs } from \"./helpers\";\n\nfunction isSemanticConnection(source: NodeConfig, target: NodeConfig): boolean {\n  if (source.kind === \"model_provider\" && target.kind === \"model_config\") {\n    return true;\n  }\n  if (source.kind === \"model_config\" && target.kind === \"llm\") {\n    return true;\n  }\n  if (source.kind === \"tool_config\" && target.kind === \"llm\") {\n    return true;\n  }\n  if (\n    source.kind === \"llm\" &&\n    source.llm_type === \"code\" &&\n    target.kind === \"validator\"\n  ) {\n    return true;\n  }\n  return (\n    source.kind === \"validator\" &&\n    target.kind === \"llm\" &&\n    target.llm_type === \"code\"\n  );\n}\n\nexport function buildEdges(\n  configs: NodeConfig[],\n  nameToId: Map<string, string>,\n  uiEdges:\n    | Array<{\n        from: string;\n        to: string;\n        type?: string;\n        sourceHandle?: string;\n        targetHandle?: string;\n      }>\n    | null,\n  layoutDirection: LayoutDirection,\n): Edge[] {\n  const edges: Edge[] = [];\n  const seen = new Set<string>();\n  const configByName = new Map(configs.map((config) => [config.name, config]));\n  const addEdgeByName = (\n    from: string,\n    to: string,\n    sourceHandleInput?: string,\n    targetHandleInput?: string,\n  ): void => {\n    const sourceId = nameToId.get(from);\n    const targetId = nameToId.get(to);\n    if (!(sourceId && targetId)) {\n      return;\n    }\n    const key = `${sourceId}-${targetId}`;\n    if (seen.has(key)) {\n      return;\n    }\n    seen.add(key);\n    const source = configByName.get(from);\n    const target = configByName.get(to);\n    const isSemantic = Boolean(\n      source && target && isSemanticConnection(source, target),\n    );\n    const normalizedType = isSemantic ? \"semantic\" : \"canvas\";\n    const sourceHandleNormalized = normalizeRecipeHandleId(sourceHandleInput);\n    const targetHandleNormalized = normalizeRecipeHandleId(targetHandleInput);\n    const semanticSourceDefault =\n      source?.kind === \"llm\"\n        ? getDefaultDataSourceHandle(layoutDirection)\n        : getDefaultSemanticSourceHandle(layoutDirection);\n    const semanticTargetDefault =\n      target?.kind === \"llm\"\n        ? getDefaultDataTargetHandle(layoutDirection)\n        : getDefaultSemanticTargetHandle(layoutDirection);\n    let sourceHandle = getDefaultDataSourceHandle(layoutDirection);\n    let targetHandle = getDefaultDataTargetHandle(layoutDirection);\n\n    if (isSemantic) {\n      sourceHandle =\n        isSemanticSourceHandle(sourceHandleNormalized) ||\n        isDataSourceHandle(sourceHandleNormalized)\n          ? sourceHandleNormalized ?? semanticSourceDefault\n          : semanticSourceDefault;\n      targetHandle =\n        isSemanticTargetHandle(targetHandleNormalized) ||\n        isDataTargetHandle(targetHandleNormalized)\n          ? targetHandleNormalized ?? 
semanticTargetDefault\n          : semanticTargetDefault;\n    } else {\n      sourceHandle = isDataSourceHandle(sourceHandleNormalized)\n        ? sourceHandleNormalized ?? getDefaultDataSourceHandle(layoutDirection)\n        : getDefaultDataSourceHandle(layoutDirection);\n      targetHandle = isDataTargetHandle(targetHandleNormalized)\n        ? targetHandleNormalized ?? getDefaultDataTargetHandle(layoutDirection)\n        : getDefaultDataTargetHandle(layoutDirection);\n    }\n    edges.push({\n      id: `e-${key}`,\n      source: sourceId,\n      target: targetId,\n      type: normalizedType,\n      sourceHandle,\n      targetHandle,\n    });\n  };\n\n  if (uiEdges && uiEdges.length > 0) {\n    for (const edge of uiEdges) {\n      addEdgeByName(\n        edge.from,\n        edge.to,\n        edge.sourceHandle,\n        edge.targetHandle,\n      );\n    }\n    if (edges.length > 0) {\n      return edges;\n    }\n  }\n\n  for (const config of configs) {\n    if (config.kind === \"llm\") {\n      for (const ref of extractRefs(config.prompt ?? \"\")) {\n        addEdgeByName(ref, config.name);\n      }\n      for (const ref of extractRefs(config.system_prompt ?? \"\")) {\n        addEdgeByName(ref, config.name);\n      }\n    }\n    if (config.kind === \"expression\") {\n      for (const ref of extractRefs(config.expr)) {\n        addEdgeByName(ref, config.name);\n      }\n    }\n    if (\n      config.kind === \"sampler\" &&\n      config.sampler_type === \"subcategory\" &&\n      config.subcategory_parent\n    ) {\n      addEdgeByName(config.subcategory_parent, config.name);\n    }\n    if (config.kind === \"model_config\" && config.provider) {\n      addEdgeByName(config.provider, config.name);\n    }\n    if (\n      config.kind === \"sampler\" &&\n      config.sampler_type === \"timedelta\" &&\n      config.reference_column_name\n    ) {\n      addEdgeByName(config.reference_column_name, config.name);\n    }\n    if (config.kind === \"llm\" && config.model_alias) {\n      addEdgeByName(config.model_alias, config.name);\n    }\n    if (config.kind === \"llm\" && config.tool_alias) {\n      addEdgeByName(config.tool_alias, config.name);\n    }\n    if (config.kind === \"validator\") {\n      for (const targetColumn of config.target_columns ?? []) {\n        if (targetColumn.trim()) {\n          addEdgeByName(targetColumn, config.name);\n        }\n      }\n    }\n  }\n\n  return edges;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/helpers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { extractRefs as extractJinjaRefs } from \"../refs\";\n\nexport function isRecord(value: unknown): value is Record<string, unknown> {\n  return Boolean(value && typeof value === \"object\" && !Array.isArray(value));\n}\n\nexport function readString(value: unknown): string | null {\n  return typeof value === \"string\" ? value : null;\n}\n\nexport function readNumberString(value: unknown): string {\n  if (typeof value === \"number\" && Number.isFinite(value)) {\n    return String(value);\n  }\n  if (typeof value === \"string\") {\n    return value;\n  }\n  return \"\";\n}\n\nexport function parseJson(\n  input: string,\n): { data: unknown | null; error?: string } {\n  try {\n    return { data: JSON.parse(input) };\n  } catch (error) {\n    return {\n      data: null,\n      error: error instanceof Error ? error.message : \"Invalid JSON.\",\n    };\n  }\n}\n\nexport function normalizeOutputFormat(value: unknown): string {\n  if (typeof value === \"string\") {\n    return value;\n  }\n  if (isRecord(value)) {\n    return JSON.stringify(value, null, 2);\n  }\n  return \"\";\n}\n\nexport function extractRefs(template: string): string[] {\n  return extractJinjaRefs(template);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/importer.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  LlmConfig,\n  LlmMcpProviderConfig,\n  LlmToolConfig,\n  MarkdownNoteConfig,\n  NodeConfig,\n  RecipeProcessorConfig,\n  SeedConfig,\n  SamplerConfig,\n  SeedSourceType,\n  ToolProfileConfig,\n  ValidatorConfig,\n} from \"../../types\";\nimport { buildEdges } from \"./edges\";\nimport { isRecord, parseJson, readString } from \"./helpers\";\nimport {\n  parseColumn,\n  parseModelConfig,\n  parseModelProvider,\n} from \"./parsers\";\nimport { parseSeedConfig } from \"./parsers/seed-config-parser\";\nimport { buildNodes, parseUi } from \"./ui\";\nimport type { ImportResult } from \"./types\";\n\ntype RecipeInput = {\n  columns?: unknown;\n  model_configs?: unknown;\n  model_providers?: unknown;\n  mcp_providers?: unknown;\n  tool_configs?: unknown;\n  processors?: unknown;\n  seed_config?: unknown;\n};\n\ntype UiInput = {\n  nodes?: unknown;\n  edges?: unknown;\n  seed_source_type?: unknown;\n  seed_columns?: unknown;\n  seed_drop_columns?: unknown;\n  seed_preview_rows?: unknown;\n  local_file_name?: unknown;\n  unstructured_file_name?: unknown;\n  unstructured_chunk_size?: unknown;\n  unstructured_chunk_overlap?: unknown;\n  advanced_open_by_node?: unknown;\n};\n\ntype UiMarkdownNoteNode = {\n  name: string;\n  markdown: string;\n  note_color?: string;\n  note_opacity?: string;\n};\n\nfunction readStringNumber(value: unknown): string | undefined {\n  if (typeof value === \"string\") {\n    return value;\n  }\n  if (typeof value === \"number\" && Number.isFinite(value)) {\n    return String(value);\n  }\n  return undefined;\n}\n\nfunction parseProcessors(input: unknown): RecipeProcessorConfig[] {\n  if (!Array.isArray(input)) {\n    return [];\n  }\n  const processors: RecipeProcessorConfig[] = [];\n  input.forEach((item, index) => {\n    if (!isRecord(item)) {\n      return;\n    }\n    const type = readString(item.processor_type);\n    const templateRaw = item.template;\n    const isSchemaTransform =\n      type === \"schema_transform\" || isRecord(templateRaw);\n    if (!isSchemaTransform) {\n      return;\n    }\n    const name = readString(item.name) ?? `schema_transform_${index + 1}`;\n    const template =\n      typeof templateRaw === \"string\"\n        ? templateRaw\n        : isRecord(templateRaw)\n          ? JSON.stringify(templateRaw, null, 2)\n          : \"{\\n  \\\"text\\\": \\\"{{ column_name }}\\\"\\n}\";\n    processors.push({\n      id: `p${index + 1}`,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      processor_type: \"schema_transform\",\n      name,\n      template,\n    });\n  });\n  return processors;\n}\n\nfunction parseSeedDropColumns(input: unknown): string[] {\n  if (!Array.isArray(input)) {\n    return [];\n  }\n  const values = new Set<string>();\n  for (const item of input) {\n    if (!isRecord(item)) {\n      continue;\n    }\n    const type = readString(item.processor_type);\n    if (type !== \"drop_columns\") {\n      continue;\n    }\n    const name = readString(item.name);\n    if (name !== \"drop_seed_columns\") {\n      continue;\n    }\n    const columnNames = Array.isArray(item.column_names)\n      ? 
item.column_names\n      : [];\n    for (const columnName of columnNames) {\n      if (typeof columnName !== \"string\") {\n        continue;\n      }\n      const next = columnName.trim();\n      if (next) {\n        values.add(next);\n      }\n    }\n  }\n  return Array.from(values);\n}\n\nfunction parseMcpProviders(\n  input: unknown,\n): Map<string, LlmMcpProviderConfig> {\n  const providers = new Map<string, LlmMcpProviderConfig>();\n  if (!Array.isArray(input)) {\n    return providers;\n  }\n  input.forEach((item, index) => {\n    if (!isRecord(item)) {\n      return;\n    }\n    const name = readString(item.name)?.trim();\n    if (!name) {\n      return;\n    }\n    const providerTypeRaw = readString(item.provider_type);\n    const providerType =\n      providerTypeRaw === \"stdio\" ? \"stdio\" : \"streamable_http\";\n    const args = Array.isArray(item.args)\n      ? item.args.map((value) => String(value))\n      : [];\n    const envPairs =\n      isRecord(item.env)\n        ? Object.entries(item.env).map(([key, value]) => ({\n            key: String(key),\n            value: String(value),\n          }))\n        : [];\n    providers.set(name, {\n      id: `mcp-${index + 1}`,\n      name,\n      // biome-ignore lint/style/useNamingConvention: ui schema\n      provider_type: providerType,\n      command: readString(item.command) ?? \"\",\n      args,\n      env: envPairs,\n      endpoint: readString(item.endpoint) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      api_key: readString(item.api_key) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      api_key_env: readString(item.api_key_env) ?? \"\",\n    });\n  });\n  return providers;\n}\n\nfunction parseToolConfigs(input: unknown): Map<string, LlmToolConfig> {\n  const toolConfigs = new Map<string, LlmToolConfig>();\n  if (!Array.isArray(input)) {\n    return toolConfigs;\n  }\n  input.forEach((item, index) => {\n    if (!isRecord(item)) {\n      return;\n    }\n    const toolAlias = readString(item.tool_alias)?.trim();\n    if (!toolAlias) {\n      return;\n    }\n    const providers = Array.isArray(item.providers)\n      ? item.providers.map((value) => String(value).trim()).filter(Boolean)\n      : [];\n    const allowTools = Array.isArray(item.allow_tools)\n      ? item.allow_tools.map((value) => String(value).trim()).filter(Boolean)\n      : [];\n    toolConfigs.set(toolAlias, {\n      id: `tool-${index + 1}`,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      tool_alias: toolAlias,\n      providers,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      allow_tools: allowTools,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      max_tool_call_turns:\n        item.max_tool_call_turns === null || item.max_tool_call_turns === undefined\n          ? \"5\"\n          : String(item.max_tool_call_turns),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      timeout_sec:\n        item.timeout_sec === null || item.timeout_sec === undefined\n          ? \"\"\n          : String(item.timeout_sec),\n    });\n  });\n  return toolConfigs;\n}\n\nfunction cloneMcpProvider(config: LlmMcpProviderConfig): LlmMcpProviderConfig {\n  return {\n    ...config,\n    args: [...(config.args ?? [])],\n    env: [...(config.env ?? 
[])],\n  };\n}\n\nfunction parseUiMarkdownNoteNodes(input: unknown): UiMarkdownNoteNode[] {\n  if (!Array.isArray(input)) {\n    return [];\n  }\n  const noteNodes: UiMarkdownNoteNode[] = [];\n  for (const node of input) {\n    if (!isRecord(node)) {\n      continue;\n    }\n    const nodeType = readString(node.node_type) ?? readString(node.type);\n    if (nodeType !== \"markdown_note\") {\n      continue;\n    }\n    const name = readString(node.name) ?? readString(node.id);\n    if (!name?.trim()) {\n      continue;\n    }\n    noteNodes.push({\n      name: name.trim(),\n      markdown: readString(node.markdown) ?? \"\",\n      note_color: readString(node.note_color) ?? undefined,\n      note_opacity: readStringNumber(node.note_opacity) ?? undefined,\n    });\n  }\n  return noteNodes;\n}\n\nfunction parseUiToolProfileNodes(input: unknown): Map<string, Record<string, string[]>> {\n  const toolProfiles = new Map<string, Record<string, string[]>>();\n  if (!Array.isArray(input)) {\n    return toolProfiles;\n  }\n  for (const node of input) {\n    if (!isRecord(node)) {\n      continue;\n    }\n    const nodeType = readString(node.node_type) ?? readString(node.type);\n    if (nodeType !== \"tool_config\") {\n      continue;\n    }\n    const name = readString(node.name) ?? readString(node.id);\n    if (!name?.trim()) {\n      continue;\n    }\n    const rawToolsByProvider = isRecord(node.tools_by_provider)\n      ? node.tools_by_provider\n      : null;\n    if (!rawToolsByProvider) {\n      continue;\n    }\n    const toolsByProvider = Object.fromEntries(\n      Object.entries(rawToolsByProvider).flatMap(([providerName, tools]) => {\n        const trimmedName = providerName.trim();\n        if (!trimmedName || !Array.isArray(tools)) {\n          return [];\n        }\n        const values = Array.from(\n          new Set(tools.map((value) => String(value).trim()).filter(Boolean)),\n        );\n        return values.length > 0 ? [[trimmedName, values]] : [];\n      }),\n    );\n    toolProfiles.set(name.trim(), toolsByProvider);\n  }\n  return toolProfiles;\n}\n\nfunction parseAdvancedOpenByNode(input: unknown): Record<string, boolean> {\n  if (!isRecord(input)) {\n    return {};\n  }\n  const out: Record<string, boolean> = {};\n  for (const [nameRaw, value] of Object.entries(input)) {\n    const name = nameRaw.trim();\n    if (!name || typeof value !== \"boolean\") {\n      continue;\n    }\n    out[name] = value;\n  }\n  return out;\n}\n\ntype AdvancedOpenConfig = LlmConfig | SamplerConfig | SeedConfig | ValidatorConfig;\n\nfunction isAdvancedOpenConfig(config: NodeConfig): config is AdvancedOpenConfig {\n  return (\n    config.kind === \"llm\" ||\n    config.kind === \"sampler\" ||\n    config.kind === \"seed\" ||\n    config.kind === \"validator\"\n  );\n}\n\nfunction applyAdvancedOpen(\n  config: NodeConfig,\n  advancedOpenByNode: Record<string, boolean>,\n): void {\n  if (!isAdvancedOpenConfig(config)) {\n    return;\n  }\n  config.advancedOpen = advancedOpenByNode[config.name] === true;\n}\n\nfunction buildToolProfileConfig(\n  toolConfig: LlmToolConfig,\n  toolConfigsByAlias: Map<string, LlmToolConfig>,\n  mcpProvidersByName: Map<string, LlmMcpProviderConfig>,\n  fetchedToolsByProfileName: Map<string, Record<string, string[]>>,\n  id: string,\n): ToolProfileConfig {\n  const canonical = toolConfigsByAlias.get(toolConfig.tool_alias) ?? 
toolConfig;\n  return {\n    id,\n    kind: \"tool_config\",\n    name: canonical.tool_alias,\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    mcp_providers: canonical.providers\n      .map((providerName) => mcpProvidersByName.get(providerName))\n      .flatMap((provider) => (provider ? [cloneMcpProvider(provider)] : [])),\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    fetched_tools_by_provider: fetchedToolsByProfileName.get(canonical.tool_alias) ?? {},\n    // biome-ignore lint/style/useNamingConvention: api schema\n    allow_tools: [...(canonical.allow_tools ?? [])],\n    // biome-ignore lint/style/useNamingConvention: api schema\n    max_tool_call_turns: canonical.max_tool_call_turns ?? \"5\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    timeout_sec: canonical.timeout_sec ?? \"\",\n  };\n}\n\nexport function importRecipePayload(input: string): ImportResult {\n  const parsed = parseJson(input);\n  if (!parsed.data || !isRecord(parsed.data)) {\n    return {\n      errors: [parsed.error ?? \"Invalid JSON payload.\"],\n      snapshot: null,\n    };\n  }\n\n  const recipe = (isRecord(parsed.data.recipe)\n    ? parsed.data.recipe\n    : parsed.data) as RecipeInput;\n  const ui = isRecord(parsed.data.ui) ? (parsed.data.ui as UiInput) : null;\n\n  if (!Array.isArray(recipe.columns)) {\n    return { errors: [\"Recipe must include columns.\"], snapshot: null };\n  }\n\n  const errors: string[] = [];\n  const configs: NodeConfig[] = [];\n  const processors = parseProcessors(recipe.processors);\n  const mcpProvidersByName = parseMcpProviders(recipe.mcp_providers);\n  const toolConfigsByAlias = parseToolConfigs(recipe.tool_configs);\n  const nameToId = new Map<string, string>();\n\n  let nextId = 1;\n  const uiSeedSourceTypeRaw = readString(ui?.seed_source_type);\n  const uiSeedSourceType: SeedSourceType | undefined =\n    uiSeedSourceTypeRaw === \"hf\" ||\n    uiSeedSourceTypeRaw === \"local\" ||\n    uiSeedSourceTypeRaw === \"unstructured\"\n      ? uiSeedSourceTypeRaw\n      : undefined;\n  const uiSeedColumns = Array.isArray(ui?.seed_columns)\n    ? ui.seed_columns\n        .map((value) => (typeof value === \"string\" ? value.trim() : \"\"))\n        .filter(Boolean)\n    : undefined;\n  const uiSeedDropColumns = Array.isArray(ui?.seed_drop_columns)\n    ? ui.seed_drop_columns\n        .map((value) => (typeof value === \"string\" ? value.trim() : \"\"))\n        .filter(Boolean)\n    : undefined;\n  const payloadSeedDropColumns = parseSeedDropColumns(recipe.processors);\n  const uiSeedPreviewRows = Array.isArray(ui?.seed_preview_rows)\n    ? ui.seed_preview_rows\n        .filter((row): row is Record<string, unknown> => isRecord(row))\n        .map((row) => ({ ...row }))\n    : undefined;\n  const uiLocalFileName = readString(ui?.local_file_name) ?? undefined;\n  const uiUnstructuredFileName =\n    readString(ui?.unstructured_file_name) ?? 
undefined;\n  const uiUnstructuredChunkSize = readStringNumber(ui?.unstructured_chunk_size);\n  const uiUnstructuredChunkOverlap = readStringNumber(\n    ui?.unstructured_chunk_overlap,\n  );\n  const uiAdvancedOpenByNode = parseAdvancedOpenByNode(ui?.advanced_open_by_node);\n  const uiMarkdownNotes = parseUiMarkdownNoteNodes(ui?.nodes);\n  const uiToolProfilesByName = parseUiToolProfileNodes(ui?.nodes);\n\n  for (const note of uiMarkdownNotes) {\n    const id = `n${nextId}`;\n    nextId += 1;\n    const config: MarkdownNoteConfig = {\n      id,\n      kind: \"markdown_note\",\n      name: note.name,\n      markdown: note.markdown,\n      note_color: note.note_color ?? \"#FDE68A\",\n      note_opacity: note.note_opacity ?? \"35\",\n    };\n    if (nameToId.has(config.name)) {\n      errors.push(`Duplicate column name: ${config.name}.`);\n      continue;\n    }\n    nameToId.set(config.name, config.id);\n    configs.push(config);\n  }\n\n  if (recipe.seed_config) {\n    const id = `n${nextId}`;\n    nextId += 1;\n    const seedConfig = parseSeedConfig(recipe.seed_config, id, {\n      preferredSourceType: uiSeedSourceType,\n      seed_columns: uiSeedColumns,\n      seed_drop_columns:\n        uiSeedDropColumns && uiSeedDropColumns.length > 0\n          ? uiSeedDropColumns\n          : payloadSeedDropColumns,\n      seed_preview_rows: uiSeedPreviewRows,\n      local_file_name: uiLocalFileName,\n      unstructured_file_name: uiUnstructuredFileName,\n      unstructured_chunk_size: uiUnstructuredChunkSize,\n      unstructured_chunk_overlap: uiUnstructuredChunkOverlap,\n    });\n    if (seedConfig) {\n      applyAdvancedOpen(seedConfig, uiAdvancedOpenByNode);\n      if (nameToId.has(seedConfig.name)) {\n        errors.push(`Duplicate column name: ${seedConfig.name}.`);\n      } else {\n        nameToId.set(seedConfig.name, seedConfig.id);\n      }\n      configs.push(seedConfig);\n    }\n  }\n\n  if (Array.isArray(recipe.model_providers)) {\n    recipe.model_providers.forEach((provider, index) => {\n      if (!isRecord(provider)) {\n        errors.push(`Model provider ${index + 1}: invalid object.`);\n        return;\n      }\n      const name = readString(provider.name);\n      if (!name) {\n        errors.push(`Model provider ${index + 1}: missing name.`);\n        return;\n      }\n      const id = `n${nextId}`;\n      nextId += 1;\n      const config = parseModelProvider(provider, name, id);\n      if (nameToId.has(config.name)) {\n        errors.push(`Duplicate column name: ${config.name}.`);\n        return;\n      }\n      nameToId.set(config.name, config.id);\n      configs.push(config);\n    });\n  }\n\n  if (Array.isArray(recipe.model_configs)) {\n    recipe.model_configs.forEach((model, index) => {\n      if (!isRecord(model)) {\n        errors.push(`Model config ${index + 1}: invalid object.`);\n        return;\n      }\n      const name = readString(model.alias) ?? 
readString(model.name);\n      if (!name) {\n        errors.push(`Model config ${index + 1}: missing alias.`);\n        return;\n      }\n      const id = `n${nextId}`;\n      nextId += 1;\n      const config = parseModelConfig(model, name, id);\n      if (nameToId.has(config.name)) {\n        errors.push(`Duplicate column name: ${config.name}.`);\n        return;\n      }\n      nameToId.set(config.name, config.id);\n      configs.push(config);\n    });\n  }\n\n  for (const toolConfig of toolConfigsByAlias.values()) {\n    const id = `n${nextId}`;\n    nextId += 1;\n    const config = buildToolProfileConfig(\n      toolConfig,\n      toolConfigsByAlias,\n      mcpProvidersByName,\n      uiToolProfilesByName,\n      id,\n    );\n    if (nameToId.has(config.name)) {\n      errors.push(`Duplicate column name: ${config.name}.`);\n      continue;\n    }\n    nameToId.set(config.name, config.id);\n    configs.push(config);\n  }\n\n  recipe.columns.forEach((column, index) => {\n    if (!isRecord(column)) {\n      errors.push(`Column ${index + 1}: invalid object.`);\n      return;\n    }\n    const id = `n${nextId}`;\n    nextId += 1;\n    const config = parseColumn(column, id, errors);\n    if (!config) {\n      return;\n    }\n    applyAdvancedOpen(config, uiAdvancedOpenByNode);\n    if (nameToId.has(config.name)) {\n      errors.push(`Duplicate column name: ${config.name}.`);\n      return;\n    }\n    nameToId.set(config.name, config.id);\n    configs.push(config);\n  });\n\n  if (errors.length > 0) {\n    return { errors, snapshot: null };\n  }\n\n  const { layouts, auxNodes, edges: uiEdges, layoutDirection } = parseUi(ui);\n  const resolvedLayoutDirection = layoutDirection ?? \"LR\";\n  const nodes = buildNodes(configs, layouts);\n  const edges = buildEdges(\n    configs,\n    nameToId,\n    uiEdges,\n    resolvedLayoutDirection,\n  );\n  const auxNodePositions = Object.fromEntries(\n    auxNodes.flatMap((item) => {\n      const llmId = nameToId.get(item.llm);\n      if (!llmId) {\n        return [];\n      }\n      return [[`aux-${llmId}-${item.key}`, { x: item.x, y: item.y }]];\n    }),\n  );\n\n  const maxY = nodes.reduce(\n    (acc, node) => Math.max(acc, node.position.y),\n    0,\n  );\n\n  return {\n    errors: [],\n    snapshot: {\n      configs: Object.fromEntries(configs.map((config) => [config.id, config])),\n      nodes,\n      edges,\n      auxNodePositions,\n      processors,\n      layoutDirection: resolvedLayoutDirection,\n      nextId,\n      nextY: maxY + 140,\n    },\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { importRecipePayload } from \"./importer\";\nexport type { RecipeSnapshot, ImportResult } from \"./types\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/expression-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ExpressionConfig,\n  ExpressionDtype,\n} from \"../../../types\";\nimport { readString } from \"../helpers\";\n\nconst EXPRESSION_DTYPES: ExpressionDtype[] = [\"str\", \"int\", \"float\", \"bool\"];\n\nexport function parseExpression(\n  column: Record<string, unknown>,\n  name: string,\n  id: string,\n): ExpressionConfig {\n  const dtype = readString(column.dtype);\n  const normalized = EXPRESSION_DTYPES.includes(dtype as ExpressionDtype)\n    ? (dtype as ExpressionDtype)\n    : \"str\";\n  return {\n    id,\n    kind: \"expression\",\n    name,\n    drop: column.drop === true,\n    expr: readString(column.expr) ?? \"\",\n    dtype: normalized,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/llm-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  LlmConfig,\n  Score,\n  ScoreOption,\n} from \"../../../types\";\nimport {\n  isRecord,\n  normalizeOutputFormat,\n  readString,\n} from \"../helpers\";\n\nfunction parseTraceMode(value: unknown): LlmConfig[\"with_trace\"] {\n  const traceRaw = readString(value) ?? \"none\";\n  if (traceRaw === \"last_message\" || traceRaw === \"all_messages\") {\n    return traceRaw;\n  }\n  return \"none\";\n}\n\nexport function parseLlm(\n  column: Record<string, unknown>,\n  name: string,\n  id: string,\n): LlmConfig {\n  const columnType = readString(column.column_type) ?? \"llm-text\";\n  let llmType: LlmConfig[\"llm_type\"] = \"text\";\n  if (columnType === \"llm-structured\") {\n    llmType = \"structured\";\n  } else if (columnType === \"llm-code\") {\n    llmType = \"code\";\n  } else if (columnType === \"llm-judge\") {\n    llmType = \"judge\";\n  }\n\n  const scores: Score[] =\n    columnType === \"llm-judge\" && Array.isArray(column.scores)\n      ? column.scores\n          .filter((score) => isRecord(score))\n          .map((score) => {\n            const options: ScoreOption[] = [];\n            const rawOptions = isRecord(score.options) ? score.options : {};\n            for (const [key, value] of Object.entries(rawOptions)) {\n              const description =\n                typeof value === \"string\" ? value : JSON.stringify(value);\n              options.push({ value: String(key), description });\n            }\n            return {\n              name: readString(score.name) ?? \"\",\n              description: readString(score.description) ?? \"\",\n              options,\n            };\n          })\n      : [];\n\n  let imageContext: LlmConfig[\"image_context\"] = {\n    enabled: false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_name: \"\",\n  };\n  if (Array.isArray(column.multi_modal_context)) {\n    const first = column.multi_modal_context.find((entry) => isRecord(entry));\n    if (first && isRecord(first)) {\n      const modality = readString(first.modality);\n      const columnName = readString(first.column_name) ?? \"\";\n      if (modality === \"image\" && columnName) {\n        imageContext = {\n          enabled: true,\n          // biome-ignore lint/style/useNamingConvention: api schema\n          column_name: columnName,\n        };\n      }\n    }\n  }\n\n  const withTrace = parseTraceMode(column.with_trace);\n  const extractReasoningContent = column.extract_reasoning_content === true;\n\n  return {\n    id,\n    kind: \"llm\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    llm_type: llmType,\n    name,\n    drop: column.drop === true,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_alias: readString(column.model_alias) ?? \"\",\n    prompt: readString(column.prompt) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    system_prompt: readString(column.system_prompt) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    code_lang: readString(column.code_lang) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    output_format: normalizeOutputFormat(column.output_format),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_alias: readString(column.tool_alias) ?? 
\"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    with_trace: withTrace,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extract_reasoning_content: extractReasoningContent,\n    scores: llmType === \"judge\" ? scores : undefined,\n    // biome-ignore lint/style/useNamingConvention: ui schema\n    image_context: imageContext,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/model-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ModelConfig,\n  ModelProviderConfig,\n} from \"../../../types\";\nimport {\n  isRecord,\n  readNumberString,\n  readString,\n} from \"../helpers\";\n\nexport function parseModelProvider(\n  provider: Record<string, unknown>,\n  name: string,\n  id: string,\n): ModelProviderConfig {\n  return {\n    id,\n    kind: \"model_provider\",\n    name,\n    endpoint: readString(provider.endpoint) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    provider_type: readString(provider.provider_type) ?? \"openai\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key_env: readString(provider.api_key_env) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key: readString(provider.api_key) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_headers: isRecord(provider.extra_headers)\n      ? JSON.stringify(provider.extra_headers, null, 2)\n      : \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_body: isRecord(provider.extra_body)\n      ? JSON.stringify(provider.extra_body, null, 2)\n      : \"\",\n  };\n}\n\nexport function parseModelConfig(\n  model: Record<string, unknown>,\n  name: string,\n  id: string,\n): ModelConfig {\n  const inference = isRecord(model.inference_parameters)\n    ? (model.inference_parameters as Record<string, unknown>)\n    : {};\n  return {\n    id,\n    kind: \"model_config\",\n    name,\n    model: readString(model.model) ?? \"\",\n    provider: readString(model.provider) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_temperature: readNumberString(inference.temperature),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_top_p: readNumberString(inference.top_p),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_max_tokens: readNumberString(inference.max_tokens),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_timeout: readNumberString(inference.timeout),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_extra_body: isRecord(inference.extra_body)\n      ? JSON.stringify(inference.extra_body, null, 2)\n      : \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    skip_health_check:\n      typeof model.skip_health_check === \"boolean\"\n        ? model.skip_health_check\n        : false,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/sampler-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  SamplerConfig,\n  SamplerType,\n} from \"../../../types\";\nimport {\n  isRecord,\n  readNumberString,\n  readString,\n} from \"../helpers\";\n\nconst SAMPLER_TYPES: SamplerType[] = [\n  \"category\",\n  \"subcategory\",\n  \"uniform\",\n  \"gaussian\",\n  \"bernoulli\",\n  \"datetime\",\n  \"timedelta\",\n  \"uuid\",\n  \"person\",\n  \"person_from_faker\",\n];\n\nconst TIMEDELTA_UNITS = new Set([\"D\", \"h\", \"m\", \"s\"]);\n\nfunction parseCategoryConditionalParams(\n  column: Record<string, unknown>,\n): SamplerConfig[\"conditional_params\"] {\n  if (!isRecord(column.conditional_params)) {\n    return undefined;\n  }\n  const conditional: NonNullable<SamplerConfig[\"conditional_params\"]> = {};\n  for (const [condition, rawParams] of Object.entries(column.conditional_params)) {\n    if (!isRecord(rawParams)) {\n      continue;\n    }\n    if (readString(rawParams.sampler_type) !== \"category\") {\n      continue;\n    }\n    const values = Array.isArray(rawParams.values)\n      ? rawParams.values.filter((item) => typeof item === \"string\")\n      : [];\n    if (values.length === 0) {\n      continue;\n    }\n    const weights = Array.isArray(rawParams.weights)\n      ? rawParams.weights.map((item) => (typeof item === \"number\" ? item : null))\n      : undefined;\n    conditional[condition] = {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"category\",\n      values,\n      weights,\n    };\n  }\n  return Object.keys(conditional).length > 0 ? conditional : undefined;\n}\n\nexport function parseSampler(\n  column: Record<string, unknown>,\n  name: string,\n  id: string,\n  errors: string[],\n): SamplerConfig | null {\n  const drop = column.drop === true;\n  const samplerType = readString(column.sampler_type);\n  if (!samplerType || !SAMPLER_TYPES.includes(samplerType as SamplerType)) {\n    errors.push(`Sampler ${name}: unsupported sampler_type.`);\n    return null;\n  }\n  const convertTo = readString(column.convert_to);\n  const normalizedConvertTo =\n    convertTo && [\"float\", \"int\", \"str\"].includes(convertTo)\n      ? (convertTo as \"float\" | \"int\" | \"str\")\n      : undefined;\n  const params =\n    typeof column.params === \"object\" && column.params\n      ? (column.params as Record<string, unknown>)\n      : {};\n\n  if (samplerType === \"category\") {\n    const values = Array.isArray(params.values)\n      ? params.values.filter((item) => typeof item === \"string\")\n      : [];\n    const weights = Array.isArray(params.weights)\n      ? params.weights.map((item) => (typeof item === \"number\" ? 
item : null))\n      : [];\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"category\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      values,\n      weights,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      conditional_params: parseCategoryConditionalParams(column),\n    };\n  }\n\n  if (samplerType === \"subcategory\") {\n    const mapping: Record<string, string[]> = {};\n    if (params.values && typeof params.values === \"object\") {\n      for (const [key, value] of Object.entries(params.values)) {\n        if (Array.isArray(value)) {\n          mapping[key] = value.filter((item) => typeof item === \"string\");\n        }\n      }\n    }\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"subcategory\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_parent: readString(params.category) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      subcategory_mapping: mapping,\n    };\n  }\n\n  if (samplerType === \"uniform\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"uniform\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      low: readNumberString(params.low),\n      high: readNumberString(params.high),\n    };\n  }\n\n  if (samplerType === \"gaussian\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"gaussian\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      mean: readNumberString(params.mean),\n      std: readNumberString(params.std),\n    };\n  }\n\n  if (samplerType === \"bernoulli\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"bernoulli\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      p: readNumberString(params.p),\n    };\n  }\n\n  if (samplerType === \"datetime\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"datetime\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_start: readString(params.start) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_end: readString(params.end) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      datetime_unit: readString(params.unit) ?? \"\",\n    };\n  }\n\n  if (samplerType === \"timedelta\") {\n    const rawUnit = readString(params.unit);\n    const unit =\n      rawUnit && TIMEDELTA_UNITS.has(rawUnit)\n        ? 
(rawUnit as \"D\" | \"h\" | \"m\" | \"s\")\n        : \"D\";\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"timedelta\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_min: readNumberString(params.dt_min),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_max: readNumberString(params.dt_max),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: readString(params.reference_column_name) ?? \"\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      timedelta_unit: unit,\n    };\n  }\n\n  if (samplerType === \"uuid\") {\n    return {\n      id,\n      kind: \"sampler\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"uuid\",\n      name,\n      drop,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      convert_to: normalizedConvertTo,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      uuid_format: readString(params.format) ?? \"\",\n    };\n  }\n\n  const ageRange =\n    Array.isArray(params.age_range) &&\n    params.age_range.length === 2 &&\n    params.age_range.every((item) => typeof item === \"number\")\n      ? `${params.age_range[0]}-${params.age_range[1]}`\n      : readString(params.age_range) ?? \"\";\n\n  const base: SamplerConfig = {\n    id,\n    kind: \"sampler\",\n    name,\n    drop,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    sampler_type: samplerType as SamplerType,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    convert_to: normalizedConvertTo,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_locale: readString(params.locale) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_sex: readString(params.sex) ?? \"\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_age_range: ageRange,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    person_city: readString(params.city) ?? \"\",\n  };\n\n  if (samplerType === \"person\") {\n    return {\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      person_with_synthetic_personas:\n        typeof params.with_synthetic_personas === \"boolean\"\n          ? params.with_synthetic_personas\n          : false,\n    };\n  }\n\n  return base;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/seed-config-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  SeedConfig,\n  SeedSamplingStrategy,\n  SeedSelectionType,\n  SeedSourceType,\n} from \"../../../types\";\nimport { isRecord, readNumberString, readString } from \"../helpers\";\n\nfunction normalizeSampling(value: unknown): SeedSamplingStrategy {\n  const raw = readString(value);\n  if (raw === \"shuffle\") return \"shuffle\";\n  return \"ordered\";\n}\n\nfunction makeDefaultSeedConfig(id: string): SeedConfig {\n  return {\n    id,\n    kind: \"seed\",\n    name: \"seed\",\n    drop: false,\n    seed_drop_columns: [],\n    seed_source_type: \"hf\",\n    hf_repo_id: \"\",\n    hf_subset: \"\",\n    hf_split: \"\",\n    hf_path: \"\",\n    hf_token: \"\",\n    hf_endpoint: \"https://huggingface.co\",\n    local_file_name: \"\",\n    unstructured_file_name: \"\",\n    seed_preview_rows: [],\n    unstructured_chunk_size: \"1200\",\n    unstructured_chunk_overlap: \"200\",\n    seed_splits: [],\n    seed_globs_by_split: {},\n    seed_columns: [],\n    sampling_strategy: \"ordered\",\n    selection_type: \"none\",\n    selection_start: \"0\",\n    selection_end: \"10\",\n    selection_index: \"0\",\n    selection_num_partitions: \"1\",\n  };\n}\n\nfunction inferRepoIdFromSeedPath(path: string): string {\n  const trimmed = path.trim();\n  if (!trimmed) return \"\";\n  const parts = trimmed.split(\"/\").filter(Boolean);\n  if (parts.length >= 3 && parts[0] === \"datasets\") {\n    return `${parts[1]}/${parts[2]}`;\n  }\n  if (parts.length >= 2) {\n    return `${parts[0]}/${parts[1]}`;\n  }\n  return \"\";\n}\n\nfunction parseSeedSettings(seedConfigRaw: unknown): Partial<SeedConfig> {\n  if (!isRecord(seedConfigRaw)) {\n    return {};\n  }\n\n  const sampling_strategy = normalizeSampling(seedConfigRaw.sampling_strategy);\n\n  let seed_source_type: SeedSourceType = \"hf\";\n  let hf_path = \"\";\n  let hf_token = \"\";\n  let hf_endpoint = \"https://huggingface.co\";\n  let hf_repo_id = \"\";\n  let local_file_name = \"\";\n  let unstructured_file_name = \"\";\n  let unstructured_chunk_size = \"1200\";\n  let unstructured_chunk_overlap = \"200\";\n  const sourceRaw = seedConfigRaw.source;\n  if (isRecord(sourceRaw)) {\n    const seedType = readString(sourceRaw.seed_type);\n    const sourcePath = readString(sourceRaw.path) ?? \"\";\n    if (seedType === \"hf\") {\n      seed_source_type = \"hf\";\n      hf_path = sourcePath;\n      hf_token = readString(sourceRaw.token) ?? \"\";\n      hf_endpoint = readString(sourceRaw.endpoint) ?? hf_endpoint;\n      hf_repo_id = inferRepoIdFromSeedPath(hf_path);\n    } else if (seedType === \"local\") {\n      seed_source_type = \"local\";\n      hf_path = sourcePath;\n      local_file_name = sourcePath.split(\"/\").pop() ?? sourcePath;\n    } else if (seedType === \"unstructured\") {\n      seed_source_type = \"unstructured\";\n      hf_path = sourcePath;\n      unstructured_file_name = sourcePath.split(\"/\").pop() ?? 
sourcePath;\n      unstructured_chunk_size = readNumberString(sourceRaw.chunk_size) || \"1200\";\n      unstructured_chunk_overlap = readNumberString(sourceRaw.chunk_overlap) || \"200\";\n    }\n  }\n\n  let selection_type: SeedSelectionType = \"none\";\n  let selection_start = \"0\";\n  let selection_end = \"10\";\n  let selection_index = \"0\";\n  let selection_num_partitions = \"1\";\n  const selectionRaw = seedConfigRaw.selection_strategy;\n  if (isRecord(selectionRaw)) {\n    if (\n      typeof selectionRaw.start === \"number\" &&\n      typeof selectionRaw.end === \"number\"\n    ) {\n      selection_type = \"index_range\";\n      selection_start = String(selectionRaw.start);\n      selection_end = String(selectionRaw.end);\n    } else if (\n      typeof selectionRaw.index === \"number\" &&\n      typeof selectionRaw.num_partitions === \"number\"\n    ) {\n      selection_type = \"partition_block\";\n      selection_index = String(selectionRaw.index);\n      selection_num_partitions = String(selectionRaw.num_partitions);\n    }\n  }\n\n  return {\n    seed_source_type,\n    hf_repo_id,\n    hf_path,\n    hf_token,\n    hf_endpoint,\n    local_file_name,\n    unstructured_file_name,\n    unstructured_chunk_size,\n    unstructured_chunk_overlap,\n    sampling_strategy,\n    selection_type,\n    selection_start,\n    selection_end,\n    selection_index,\n    selection_num_partitions,\n  };\n}\n\nexport function parseSeedConfig(\n  seedConfigRaw: unknown,\n  id: string,\n  options?: {\n    preferredSourceType?: SeedSourceType;\n    seed_columns?: string[];\n    seed_drop_columns?: string[];\n    seed_preview_rows?: Record<string, unknown>[];\n    local_file_name?: string;\n    unstructured_file_name?: string;\n    unstructured_chunk_size?: string;\n    unstructured_chunk_overlap?: string;\n  },\n): SeedConfig | null {\n  if (!seedConfigRaw) {\n    return null;\n  }\n  const parsed = parseSeedSettings(seedConfigRaw);\n  let sourceType: SeedSourceType = \"hf\";\n  if (parsed.seed_source_type === \"hf\") {\n    sourceType = \"hf\";\n  } else if (options?.preferredSourceType) {\n    sourceType = options.preferredSourceType;\n  } else if (parsed.seed_source_type) {\n    sourceType = parsed.seed_source_type;\n  }\n  return {\n    ...makeDefaultSeedConfig(id),\n    ...parsed, // payload-only fields override ui defaults\n    seed_source_type: sourceType,\n    ...(options?.seed_columns ? { seed_columns: options.seed_columns } : {}),\n    ...(options?.seed_drop_columns\n      ? { seed_drop_columns: options.seed_drop_columns }\n      : {}),\n    ...(options?.seed_preview_rows\n      ? { seed_preview_rows: options.seed_preview_rows }\n      : {}),\n    ...(options?.local_file_name !== undefined\n      ? { local_file_name: options.local_file_name }\n      : {}),\n    ...(options?.unstructured_file_name !== undefined\n      ? { unstructured_file_name: options.unstructured_file_name }\n      : {}),\n    ...(options?.unstructured_chunk_size !== undefined\n      ? { unstructured_chunk_size: options.unstructured_chunk_size }\n      : {}),\n    ...(options?.unstructured_chunk_overlap !== undefined\n      ? { unstructured_chunk_overlap: options.unstructured_chunk_overlap }\n      : {}),\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers/validator-parser.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ValidatorConfig } from \"../../../types\";\nimport { readNumberString } from \"../helpers\";\nimport { normalizeValidatorCodeLang } from \"../../validators/code-lang\";\nimport { normalizeOxcCodeShape } from \"../../validators/oxc-code-shape\";\nimport { normalizeOxcValidationMode } from \"../../validators/oxc-mode\";\n\nconst OXC_VALIDATION_FN_MARKER = \"unsloth_oxc_validator\";\n\nfunction parseOxcValidationMarker(\n  validationFunctionRaw: string,\n): { codeLang: string; mode: string; codeShape: string } {\n  const marker = `${OXC_VALIDATION_FN_MARKER}:`;\n  if (!validationFunctionRaw.startsWith(marker)) {\n    return { codeLang: \"\", mode: \"syntax\", codeShape: \"auto\" };\n  }\n  const parts = validationFunctionRaw\n    .slice(marker.length)\n    .split(\":\")\n    .map((value) => value.trim())\n    .filter(Boolean);\n  if (parts.length < 2) {\n    return { codeLang: \"\", mode: \"syntax\", codeShape: \"auto\" };\n  }\n  return {\n    codeLang: parts[0],\n    mode: parts[1],\n    codeShape: parts[2] ?? \"auto\",\n  };\n}\n\nexport function parseValidator(\n  column: Record<string, unknown>,\n  name: string,\n  id: string,\n): ValidatorConfig {\n  const targetColumns = Array.isArray(column.target_columns)\n    ? column.target_columns\n        .filter((value): value is string => typeof value === \"string\")\n        .map((value) => value.trim())\n        .filter(Boolean)\n    : [];\n  const params =\n    column.validator_params && typeof column.validator_params === \"object\"\n      ? (column.validator_params as Record<string, unknown>)\n      : {};\n  const validationFunctionRaw =\n    typeof params.validation_function === \"string\"\n      ? params.validation_function.trim()\n      : \"\";\n  const isOxc =\n    String(column.validator_type ?? \"\").trim() === \"local_callable\" &&\n    validationFunctionRaw.startsWith(OXC_VALIDATION_FN_MARKER);\n  const marker = isOxc\n    ? parseOxcValidationMarker(validationFunctionRaw)\n    : { codeLang: \"\", mode: \"syntax\", codeShape: \"auto\" };\n  return {\n    id,\n    kind: \"validator\",\n    name,\n    drop: column.drop === true,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    target_columns: targetColumns,\n    validator_type: isOxc ? \"oxc\" : \"code\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    code_lang: normalizeValidatorCodeLang(\n      isOxc ? marker.codeLang || \"javascript\" : params.code_lang,\n    ),\n    oxc_validation_mode: isOxc\n      ? normalizeOxcValidationMode(marker.mode)\n      : \"syntax\",\n    oxc_code_shape: isOxc\n      ? normalizeOxcCodeShape(marker.codeShape)\n      : \"auto\",\n    batch_size: readNumberString(column.batch_size) || \"10\",\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/parsers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig } from \"../../types\";\nimport { readString } from \"./helpers\";\nimport { parseExpression } from \"./parsers/expression-parser\";\nimport { parseLlm } from \"./parsers/llm-parser\";\nexport { parseModelConfig, parseModelProvider } from \"./parsers/model-parser\";\nimport { parseSampler } from \"./parsers/sampler-parser\";\nimport { parseValidator } from \"./parsers/validator-parser\";\n\ntype ColumnParser = (\n  column: Record<string, unknown>,\n  name: string,\n  id: string,\n  errors: string[],\n) => NodeConfig | null;\n\nconst COLUMN_PARSERS: Record<string, ColumnParser> = {\n  sampler: (column, name, id, errors) =>\n    parseSampler(column, name, id, errors),\n  expression: (column, name, id) => parseExpression(column, name, id),\n  \"llm-text\": (column, name, id) => parseLlm(column, name, id),\n  \"llm-structured\": (column, name, id) => parseLlm(column, name, id),\n  \"llm-code\": (column, name, id) => parseLlm(column, name, id),\n  \"llm-judge\": (column, name, id) => parseLlm(column, name, id),\n  validation: (column, name, id) => parseValidator(column, name, id),\n};\n\nexport function parseColumn(\n  column: Record<string, unknown>,\n  id: string,\n  errors: string[],\n): NodeConfig | null {\n  const name = readString(column.name);\n  if (!name) {\n    errors.push(\"Column missing name.\");\n    return null;\n  }\n  const columnType = readString(column.column_type);\n  const parser = columnType ? COLUMN_PARSERS[columnType] : null;\n  if (parser) {\n    return parser(column, name, id, errors);\n  }\n  errors.push(`Column ${name}: unsupported column_type.`);\n  return null;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge, XYPosition } from \"@xyflow/react\";\nimport type {\n  LayoutDirection,\n  RecipeNode,\n  RecipeProcessorConfig,\n  NodeConfig,\n} from \"../../types\";\n\nexport type RecipeSnapshot = {\n  configs: Record<string, NodeConfig>;\n  nodes: RecipeNode[];\n  edges: Edge[];\n  auxNodePositions: Record<string, XYPosition>;\n  processors: RecipeProcessorConfig[];\n  layoutDirection: LayoutDirection;\n  nextId: number;\n  nextY: number;\n};\n\nexport type ImportResult = {\n  errors: string[];\n  snapshot: RecipeSnapshot | null;\n};\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/import/ui.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipeNode, NodeConfig } from \"../../types\";\nimport { DEFAULT_NODE_WIDTH } from \"../../constants\";\nimport { nodeDataFromConfig } from \"../index\";\nimport { normalizeRecipeHandleId } from \"../handles\";\nimport { isRecord, readString } from \"./helpers\";\n\ntype UiInput = {\n  nodes?: unknown;\n  edges?: unknown;\n  aux_nodes?: unknown;\n  layout_direction?: unknown;\n  layoutDirection?: unknown;\n};\n\ntype ParsedAuxNode = {\n  llm: string;\n  key: string;\n  x: number;\n  y: number;\n};\n\nexport function parseUi(\n  ui: UiInput | null,\n): {\n  layouts: Map<string, { x: number; y: number; width?: number }>;\n  auxNodes: ParsedAuxNode[];\n  edges: Array<{\n    from: string;\n    to: string;\n    type?: string;\n    sourceHandle?: string;\n    targetHandle?: string;\n  }> | null;\n  layoutDirection: \"LR\" | \"TB\" | null;\n} {\n  const layouts = new Map<string, { x: number; y: number; width?: number }>();\n  const auxNodes: ParsedAuxNode[] = [];\n  const edges: Array<{\n    from: string;\n    to: string;\n    type?: string;\n    sourceHandle?: string;\n    targetHandle?: string;\n  }> = [];\n  if (ui && Array.isArray(ui.nodes)) {\n    for (const node of ui.nodes) {\n      if (isRecord(node)) {\n        const id = readString(node.id);\n        const x = typeof node.x === \"number\" ? node.x : null;\n        const y = typeof node.y === \"number\" ? node.y : null;\n        const width = typeof node.width === \"number\" ? node.width : null;\n        if (id && x !== null && y !== null) {\n          layouts.set(id, {\n            x,\n            y,\n            ...(width && width > 0 ? { width } : {}),\n          });\n        }\n      }\n    }\n  }\n  if (ui && Array.isArray(ui.edges)) {\n    for (const edge of ui.edges) {\n      if (isRecord(edge)) {\n        const from = readString(edge.from);\n        const to = readString(edge.to);\n        if (from && to) {\n          const sourceHandle = normalizeRecipeHandleId(\n            readString(edge.source_handle) ?? readString(edge.sourceHandle),\n          );\n          const targetHandle = normalizeRecipeHandleId(\n            readString(edge.target_handle) ?? readString(edge.targetHandle),\n          );\n          edges.push({\n            from,\n            to,\n            type: readString(edge.type) ?? undefined,\n            sourceHandle: sourceHandle ?? undefined,\n            targetHandle: targetHandle ?? undefined,\n          });\n        }\n      }\n    }\n  }\n  if (ui && Array.isArray(ui.aux_nodes)) {\n    for (const node of ui.aux_nodes) {\n      if (!isRecord(node)) {\n        continue;\n      }\n      const llm = readString(node.llm);\n      const key = readString(node.key);\n      const x = typeof node.x === \"number\" ? node.x : null;\n      const y = typeof node.y === \"number\" ? node.y : null;\n      if (!(llm && key && x !== null && y !== null)) {\n        continue;\n      }\n      auxNodes.push({ llm, key, x, y });\n    }\n  }\n  const layoutDirectionRaw =\n    readString(ui?.layout_direction) ?? readString(ui?.layoutDirection);\n  const layoutDirection =\n    layoutDirectionRaw === \"TB\"\n      ? \"TB\"\n      : layoutDirectionRaw === \"LR\"\n        ? \"LR\"\n        : null;\n\n  return {\n    layouts,\n    auxNodes,\n    edges: edges.length > 0 ? 
edges : null,\n    layoutDirection,\n  };\n}\n\nexport function buildNodes(\n  configs: NodeConfig[],\n  layouts: Map<string, { x: number; y: number; width?: number }>,\n): RecipeNode[] {\n  return configs.map((config, index) => {\n    const fallbackLayout: { x: number; y: number; width?: number } = {\n      x: 0,\n      y: index * 140,\n    };\n    const layout =\n      layouts.get(config.name) ?? fallbackLayout;\n    return {\n      id: config.id,\n      type: \"builder\",\n      position: { x: layout.x, y: layout.y },\n      data: nodeDataFromConfig(config),\n      style: { width: layout.width ?? DEFAULT_NODE_WIDTH },\n    };\n  });\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport {\n  makeExpressionConfig,\n  makeLlmConfig,\n  makeMarkdownNoteConfig,\n  makeModelConfig,\n  makeModelProviderConfig,\n  makeSamplerConfig,\n  makeSeedConfig,\n  makeToolProfileConfig,\n  makeValidatorConfig,\n} from \"./config-factories\";\nexport {\n  labelForExpression,\n  labelForLlm,\n  labelForSampler,\n} from \"./config-labels\";\nexport {\n  isCategoryConfig,\n  isExpressionConfig,\n  isLlmConfig,\n  isSamplerConfig,\n  isSubcategoryConfig,\n  isValidatorConfig,\n} from \"./config-type-guards\";\nexport { getGraphWarnings, type GraphWarning } from \"./graph-warnings\";\nexport { nextName } from \"./naming\";\nexport { nodeDataFromConfig } from \"./node-data\";\nexport { getConfigErrors } from \"./validation\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/layout.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport dagre from \"@dagrejs/dagre\";\nimport type { Edge, Node } from \"@xyflow/react\";\nimport { DEFAULT_NODE_HEIGHT, DEFAULT_NODE_WIDTH } from \"../constants\";\nimport { INFRA_NODE_KINDS, type LayoutDirection, type NodeConfig } from \"../types\";\nimport { readNodeHeight, readNodeWidth } from \"./rf-node-dimensions\";\n\ntype LayoutOptions = {\n  direction?: LayoutDirection;\n  nodesep?: number;\n  ranksep?: number;\n  edgesep?: number;\n  nodeWidth?: number;\n  nodeHeight?: number;\n  configs?: Record<string, NodeConfig>;\n};\n\n/**\n * Pipeline rank order used to enforce a logical flow even for disconnected nodes.\n * Lower rank = earlier in the pipeline.\n */\nfunction getPipelineRank(config: NodeConfig | undefined): number {\n  if (!config) {\n    return 2;\n  }\n  switch (config.kind) {\n    case \"seed\":\n      return 0;\n    case \"sampler\":\n      return 1;\n    case \"expression\":\n      return 2;\n    case \"llm\":\n      return 3;\n    case \"validator\":\n      return 4;\n    default:\n      return 2;\n  }\n}\n\nfunction isInfraNode(\n  nodeId: string,\n  configs: Record<string, NodeConfig>,\n): boolean {\n  const config = configs[nodeId];\n  return config ? INFRA_NODE_KINDS.has(config.kind) : false;\n}\n\nfunction isAuxNode(nodeId: string): boolean {\n  return nodeId.startsWith(\"aux-\");\n}\n\nfunction getEdgeWeight(edgeType: string | undefined): number {\n  if (edgeType === \"phantom\") {\n    return 0;\n  }\n  if (edgeType === \"semantic\") {\n    return 10;\n  }\n  return 3;\n}\n\n/**\n * Build phantom edges between disconnected data-pipeline nodes so dagre\n * respects the pipeline rank order even when blocks aren't wired together.\n *\n * Groups nodes by rank, then inserts invisible edges from the last node of\n * rank N to the first node of rank N+1 when no real edge already connects them.\n */\nfunction buildPhantomEdges(\n  nodes: Node[],\n  edges: Edge[],\n  configs: Record<string, NodeConfig>,\n): Edge[] {\n  // Group nodes by rank\n  const byRank = new Map<number, string[]>();\n  for (const node of nodes) {\n    const rank = getPipelineRank(configs[node.id]);\n    const list = byRank.get(rank) ?? [];\n    list.push(node.id);\n    byRank.set(rank, list);\n  }\n\n  const ranks = Array.from(byRank.keys()).sort((a, b) => a - b);\n  const phantoms: Edge[] = [];\n\n  for (let i = 0; i < ranks.length - 1; i++) {\n    const currentIds = byRank.get(ranks[i]) ?? [];\n    const nextIds = byRank.get(ranks[i + 1]) ?? 
[];\n    if (currentIds.length === 0 || nextIds.length === 0) {\n      continue;\n    }\n\n    // Check if any real edge already connects these rank groups\n    const hasRealEdge = edges.some(\n      (e) => currentIds.includes(e.source) && nextIds.includes(e.target),\n    );\n    if (hasRealEdge) {\n      continue;\n    }\n\n    // Insert one phantom edge from last node in current rank to first in next\n    phantoms.push({\n      id: `phantom-${ranks[i]}-${ranks[i + 1]}`,\n      source: currentIds[currentIds.length - 1],\n      target: nextIds[0],\n      type: \"phantom\",\n    });\n  }\n\n  return phantoms;\n}\n\nexport function getLayoutedElements<TNode extends Node>(\n  nodes: TNode[],\n  edges: Edge[],\n  options: LayoutOptions = {},\n): { nodes: TNode[]; edges: Edge[] } {\n  const {\n    direction = \"LR\",\n    nodesep = 80,\n    ranksep = 80,\n    edgesep = 28,\n    nodeWidth = DEFAULT_NODE_WIDTH,\n    nodeHeight = DEFAULT_NODE_HEIGHT,\n    configs,\n  } = options;\n\n  // When configs are provided, filter out infra and aux nodes from dagre\n  const hasConfigs = configs && Object.keys(configs).length > 0;\n  const dataNodes = hasConfigs\n    ? nodes.filter((n) => !(isInfraNode(n.id, configs) || isAuxNode(n.id)))\n    : nodes;\n  const dataEdges = hasConfigs\n    ? edges.filter(\n        (e) =>\n          !(\n            isInfraNode(e.source, configs) ||\n            isInfraNode(e.target, configs) ||\n            isAuxNode(e.source) ||\n            isAuxNode(e.target)\n          ),\n      )\n    : edges;\n\n  // Build phantom edges to enforce pipeline rank ordering for disconnected nodes\n  const phantomEdges = hasConfigs\n    ? buildPhantomEdges(dataNodes, dataEdges, configs)\n    : [];\n\n  const graph = new dagre.graphlib.Graph();\n  graph.setDefaultEdgeLabel(() => ({}));\n  graph.setGraph({\n    rankdir: direction,\n    nodesep,\n    ranksep,\n    edgesep,\n    ranker: \"network-simplex\",\n  });\n\n  for (const node of dataNodes) {\n    const width = readNodeWidth(node) ?? nodeWidth;\n    const height = readNodeHeight(node) ?? nodeHeight;\n    graph.setNode(node.id, { width, height });\n  }\n\n  const allDagreEdges = [...dataEdges, ...phantomEdges];\n  for (const edge of allDagreEdges) {\n    const weight = getEdgeWeight(edge.type);\n    graph.setEdge(edge.source, edge.target, { minlen: 1, weight });\n  }\n\n  dagre.layout(graph);\n\n  // Build position map from dagre results (data nodes only)\n  const layoutedPositions = new Map<string, { x: number; y: number }>();\n  for (const node of dataNodes) {\n    const pos = graph.node(node.id);\n    const width = readNodeWidth(node) ?? nodeWidth;\n    const height = readNodeHeight(node) ?? nodeHeight;\n    layoutedPositions.set(node.id, {\n      x: pos.x - width / 2,\n      y: pos.y - height / 2,\n    });\n  }\n\n  // Apply positions: data nodes get dagre positions, infra/aux keep original\n  const layoutedNodes = nodes.map((node) => {\n    const position = layoutedPositions.get(node.id);\n    if (!position) {\n      return node;\n    }\n    return {\n      ...node,\n      position,\n    };\n  });\n\n  return { nodes: layoutedNodes, edges };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/naming.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig } from \"../types\";\n\nexport function nextName(existing: NodeConfig[], prefix: string): string {\n  const counts = existing\n    .map((item) => item.name)\n    .filter((name) => name.startsWith(prefix))\n    .map((name) => {\n      const suffix = name.slice(prefix.length);\n      const num = Number.parseInt(suffix.replace(\"_\", \"\"), 10);\n      return Number.isNaN(num) ? 0 : num;\n    });\n  const next = counts.length > 0 ? Math.max(...counts) + 1 : 1;\n  return `${prefix}_${next}`;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/node-data.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipeNodeData, LayoutDirection, NodeConfig } from \"../types\";\nimport {\n  labelForExpression,\n  labelForLlm,\n  labelForSampler,\n} from \"./config-labels\";\n\nexport function nodeDataFromConfig(\n  config: NodeConfig,\n  layoutDirection: LayoutDirection = \"LR\",\n): RecipeNodeData {\n  if (config.kind === \"sampler\") {\n    return {\n      title: \"Generated field\",\n      kind: \"sampler\",\n      subtype: labelForSampler(config.sampler_type),\n      blockType: config.sampler_type,\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"expression\") {\n    return {\n      title: \"Formula\",\n      kind: \"expression\",\n      subtype: labelForExpression(config.dtype),\n      blockType: \"expression\",\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"validator\") {\n    const isOxc = config.validator_type === \"oxc\";\n    const isSql = config.code_lang.startsWith(\"sql:\");\n    let subtype = \"Python\";\n    let blockType: RecipeNodeData[\"blockType\"] = \"validator_python\";\n    if (isOxc) {\n      subtype = \"OXC\";\n      blockType = \"validator_oxc\";\n    } else if (isSql) {\n      subtype = \"SQL\";\n      blockType = \"validator_sql\";\n    }\n    return {\n      title: \"Check\",\n      kind: \"validator\",\n      subtype,\n      blockType,\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"markdown_note\") {\n    return {\n      title: \"Note\",\n      kind: \"note\",\n      subtype: \"Markdown\",\n      blockType: \"markdown_note\",\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"seed\") {\n    const seedSourceType = config.seed_source_type ?? \"hf\";\n    const sourceLabel =\n      seedSourceType === \"hf\"\n        ? \"Hugging Face dataset\"\n        : seedSourceType === \"local\"\n          ? \"CSV or JSON file\"\n          : \"Document file\";\n    return {\n      title: \"Source data\",\n      kind: \"seed\",\n      subtype: sourceLabel,\n      blockType: \"seed\",\n      name: sourceLabel,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"model_provider\") {\n    return {\n      title: \"Provider connection\",\n      kind: \"model_provider\",\n      subtype: config.provider_type || \"Connection\",\n      blockType: \"model_provider\",\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"model_config\") {\n    return {\n      title: \"Model preset\",\n      kind: \"model_config\",\n      subtype: config.model || \"Model\",\n      blockType: \"model_config\",\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  if (config.kind === \"tool_config\") {\n    const providerCount = config.mcp_providers.length;\n    return {\n      title: \"Tool access\",\n      kind: \"tool_config\",\n      subtype: providerCount === 1 ? \"1 server\" : `${providerCount} servers`,\n      blockType: \"tool_config\",\n      name: config.name,\n      layoutDirection,\n    };\n  }\n  return {\n    title: \"AI step\",\n    kind: \"llm\",\n    subtype: labelForLlm(config.llm_type),\n    blockType: config.llm_type,\n    name: config.name,\n    layoutDirection,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/parse.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport function parseNumber(value?: string): number | null {\n  if (!value) {\n    return null;\n  }\n  const num = Number(value);\n  return Number.isFinite(num) ? num : null;\n}\n\nexport function parseIntNumber(value?: string): number | null {\n  const num = parseNumber(value);\n  if (num === null || !Number.isInteger(num)) {\n    return null;\n  }\n  return num;\n}\n\nexport function parseAgeRange(value?: string): [number, number] | null {\n  if (!value) {\n    return null;\n  }\n  const parts = value.split(/[^0-9.]+/).filter(Boolean);\n  if (parts.length !== 2) {\n    return null;\n  }\n  const min = Number(parts[0]);\n  const max = Number(parts[1]);\n  if (!Number.isFinite(min) || !Number.isFinite(max)) {\n    return null;\n  }\n  return [min, max];\n}\n\nexport function parseJsonObject(\n  value: string | undefined,\n  label: string,\n  errors: string[],\n): Record<string, unknown> | undefined {\n  if (!value || !value.trim()) {\n    return undefined;\n  }\n  try {\n    const parsed = JSON.parse(value);\n    if (parsed && typeof parsed === \"object\" && !Array.isArray(parsed)) {\n      return parsed as Record<string, unknown>;\n    }\n  } catch {\n    errors.push(`${label}: invalid JSON.`);\n    return undefined;\n  }\n  errors.push(`${label}: must be a JSON object.`);\n  return undefined;\n}\n\nexport function isValidSex(value?: string): value is \"Male\" | \"Female\" {\n  if (!value) {\n    return false;\n  }\n  return value === \"Male\" || value === \"Female\";\n}\n\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/build-payload.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Edge, XYPosition } from \"@xyflow/react\";\nimport type {\n  LayoutDirection,\n  ModelConfig,\n  ModelProviderConfig,\n  NodeConfig,\n  RecipeNode,\n  RecipeProcessorConfig,\n} from \"../../types\";\nimport { isSemanticRelation } from \"../graph/relations\";\nimport { getConfigErrors } from \"../index\";\nimport {\n  getDefaultDataSourceHandle,\n  getDefaultDataTargetHandle,\n  getDefaultSemanticSourceHandle,\n  getDefaultSemanticTargetHandle,\n  isDataSourceHandle,\n  isDataTargetHandle,\n  isSemanticSourceHandle,\n  isSemanticTargetHandle,\n  normalizeRecipeHandleId,\n} from \"../handles\";\nimport { readNodeWidth } from \"../rf-node-dimensions\";\nimport {\n  buildExpressionColumn,\n  buildLlmColumn,\n  buildModelConfig,\n  buildModelProvider,\n  buildProcessors,\n  buildSamplerColumn,\n  buildSeedConfig,\n  buildSeedDropProcessor,\n  buildToolProfilePayload,\n  buildValidatorColumn,\n  pickFirstSeedConfig,\n} from \"./builders\";\nimport type { RecipePayloadResult } from \"./types\";\nimport {\n  validateModelAliasLinks,\n  validateModelConfigProviders,\n  validateSubcategoryConfigs,\n  validateTimedeltaConfigs,\n  validateValidatorConfigs,\n  validateUsedProviders,\n} from \"./validate\";\nimport { isLikelyImageValue } from \"../image-preview\";\n\nfunction pushUniqueJson(\n  label: string,\n  key: string,\n  item: Record<string, unknown>,\n  seen: Map<string, string>,\n  out: Record<string, unknown>[],\n  errors: string[],\n): void {\n  const serialized = JSON.stringify(item);\n  const existing = seen.get(key);\n  if (existing && existing !== serialized) {\n    errors.push(`${label} ${key}: conflicting definitions.`);\n    return;\n  }\n  if (!existing) {\n    seen.set(key, serialized);\n    out.push(item);\n  }\n}\n\nfunction collectAdvancedOpenByNode(\n  configs: Record<string, NodeConfig>,\n): Record<string, boolean> {\n  const out: Record<string, boolean> = {};\n  for (const config of Object.values(configs)) {\n    if (\n      !(\n        config.kind === \"sampler\" ||\n        config.kind === \"llm\" ||\n        config.kind === \"validator\" ||\n        config.kind === \"seed\"\n      )\n    ) {\n      continue;\n    }\n    if (config.advancedOpen !== true) {\n      continue;\n    }\n    out[config.name] = true;\n  }\n  return out;\n}\n\n// biome-ignore lint/complexity/noExcessiveCognitiveComplexity: payload build\nexport function buildRecipePayload(\n  configs: Record<string, NodeConfig>,\n  nodes: RecipeNode[],\n  edges: Edge[],\n  processors: RecipeProcessorConfig[] = [],\n  layoutDirection: LayoutDirection = \"LR\",\n  auxNodePositions: Record<string, XYPosition> = {},\n): RecipePayloadResult {\n  const errors: string[] = [];\n  const columns: Record<string, unknown>[] = [];\n  const modelAliases = new Set<string>();\n  const modelProviderNames = new Set<string>();\n  const modelProviders: Record<string, unknown>[] = [];\n  const mcpProviders: Record<string, unknown>[] = [];\n  const modelConfigs: Record<string, unknown>[] = [];\n  const toolConfigs: Record<string, unknown>[] = [];\n  const modelProviderConfigs: ModelProviderConfig[] = [];\n  const modelConfigConfigs: ModelConfig[] = [];\n  const llmToolAliasesUsed = new Set<string>();\n  const mcpProviderJsonByName = new Map<string, string>();\n  const toolConfigJsonByAlias = new Map<string, string>();\n  const nameSet = new Set<string>();\n  const 
nameToConfig = new Map<string, NodeConfig>();\n  const allNameToConfig = new Map<string, NodeConfig>();\n  const firstSeed = pickFirstSeedConfig(configs);\n\n  for (const config of Object.values(configs)) {\n    if (config.kind === \"seed\") {\n      continue;\n    }\n    allNameToConfig.set(config.name, config);\n  }\n\n  for (const node of nodes) {\n    const config = configs[node.id];\n    if (!config) {\n      continue;\n    }\n    for (const error of getConfigErrors(config)) {\n      errors.push(`${config.name}: ${error}`);\n    }\n    if (config.kind !== \"seed\") {\n      if (nameSet.has(config.name)) {\n        errors.push(`Duplicate node name: ${config.name}.`);\n      }\n      nameSet.add(config.name);\n    }\n\n    if (config.kind === \"sampler\") {\n      nameToConfig.set(config.name, config);\n      columns.push(buildSamplerColumn(config, errors));\n      continue;\n    }\n    if (config.kind === \"llm\") {\n      if (config.image_context?.enabled) {\n        const imageContext = config.image_context;\n        const columnName = imageContext.column_name.trim();\n        if (columnName) {\n          if (firstSeed?.seed_columns && firstSeed.seed_columns.length > 0) {\n            if (!firstSeed.seed_columns.includes(columnName)) {\n              errors.push(\n                `LLM ${config.name}: image context column '${columnName}' not found in seed columns.`,\n              );\n            }\n          }\n          const previewRows = firstSeed?.seed_preview_rows ?? [];\n          if (previewRows.length > 0) {\n            const hasImageLikeValue = previewRows.some((row) =>\n              isLikelyImageValue(row[columnName]),\n            );\n            if (!hasImageLikeValue) {\n              errors.push(\n                `LLM ${config.name}: image context column '${columnName}' has no image-like values in preview rows.`,\n              );\n            }\n          }\n        }\n      }\n      columns.push(buildLlmColumn(config, errors));\n      if (config.model_alias) {\n        modelAliases.add(config.model_alias);\n      }\n      const toolAlias = config.tool_alias?.trim();\n      if (toolAlias) {\n        llmToolAliasesUsed.add(toolAlias);\n      }\n      nameToConfig.set(config.name, config);\n      continue;\n    }\n    if (config.kind === \"expression\") {\n      columns.push(buildExpressionColumn(config, errors));\n      nameToConfig.set(config.name, config);\n      continue;\n    }\n    if (config.kind === \"validator\") {\n      columns.push(buildValidatorColumn(config, errors, allNameToConfig));\n      nameToConfig.set(config.name, config);\n      continue;\n    }\n    if (config.kind === \"seed\") {\n      // SeedConfig is global config (seed_config); seed-dataset columns are added by DataDesigner.\n      continue;\n    }\n    if (config.kind === \"markdown_note\") {\n      continue;\n    }\n    if (config.kind === \"model_provider\") {\n      modelProviderNames.add(config.name);\n      modelProviders.push(buildModelProvider(config, errors));\n      modelProviderConfigs.push(config);\n      continue;\n    }\n    if (config.kind === \"tool_config\") {\n      const built = buildToolProfilePayload(config, errors);\n      for (const provider of built.mcp_providers) {\n        pushUniqueJson(\n          \"MCP provider\",\n          String(provider.name),\n          provider,\n          mcpProviderJsonByName,\n          mcpProviders,\n          errors,\n        );\n      }\n      if (built.tool_config) {\n        pushUniqueJson(\n          \"Tool config\",\n          
String(built.tool_config.tool_alias),\n          built.tool_config,\n          toolConfigJsonByAlias,\n          toolConfigs,\n          errors,\n        );\n      }\n      continue;\n    }\n    modelConfigs.push(buildModelConfig(config, errors));\n    modelConfigConfigs.push(config);\n  }\n\n  validateSubcategoryConfigs(configs, nameToConfig, errors);\n  validateTimedeltaConfigs(configs, nameToConfig, errors);\n  validateValidatorConfigs(configs, nameToConfig, errors);\n  validateModelAliasLinks(modelAliases, modelConfigConfigs, errors);\n  validateModelConfigProviders(\n    modelConfigConfigs,\n    modelAliases,\n    modelProviderNames,\n    errors,\n  );\n  validateUsedProviders(modelProviderConfigs, modelConfigConfigs, errors);\n  for (const toolAlias of llmToolAliasesUsed) {\n    if (!toolConfigJsonByAlias.has(toolAlias)) {\n      errors.push(`Tool alias ${toolAlias}: missing tool config.`);\n    }\n  }\n\n  const uiNodes = nodes.flatMap((node) => {\n    const config = configs[node.id];\n    if (!config) {\n      return [];\n    }\n    const width = readNodeWidth(node);\n    if (config.kind === \"markdown_note\") {\n      return [\n        {\n          id: config.name,\n          x: node.position.x,\n          y: node.position.y,\n          ...(width !== null ? { width } : {}),\n          node_type: \"markdown_note\" as const,\n          name: config.name,\n          markdown: config.markdown,\n          note_color: config.note_color,\n          note_opacity: config.note_opacity,\n        },\n      ];\n    }\n    if (config.kind === \"tool_config\") {\n      const toolsByProvider = Object.fromEntries(\n        Object.entries(config.fetched_tools_by_provider ?? {}).flatMap(\n          ([providerName, tools]) => {\n            const name = providerName.trim();\n            const values = Array.from(\n              new Set(tools.map((tool) => tool.trim()).filter(Boolean)),\n            );\n            return name && values.length > 0 ? [[name, values]] : [];\n          },\n        ),\n      );\n      return [\n        {\n          id: config.name,\n          x: node.position.x,\n          y: node.position.y,\n          ...(width !== null ? { width } : {}),\n          node_type: \"tool_config\" as const,\n          ...(Object.keys(toolsByProvider).length > 0 && {\n            tools_by_provider: toolsByProvider,\n          }),\n        },\n      ];\n    }\n    return [\n      {\n        id: config.name,\n        x: node.position.x,\n        y: node.position.y,\n        ...(width !== null ? { width } : {}),\n      },\n    ];\n  });\n\n  const uiEdges = edges.flatMap((edge) => {\n    const source = edge.source ? configs[edge.source] : null;\n    const target = edge.target ? configs[edge.target] : null;\n    if (!(source && target)) {\n      return [];\n    }\n    if (source.kind === \"markdown_note\" || target.kind === \"markdown_note\") {\n      return [];\n    }\n    const semantic =\n      edge.type === \"semantic\" || isSemanticRelation(source, target);\n    const sourceHandleNormalized = normalizeRecipeHandleId(edge.sourceHandle);\n    const targetHandleNormalized = normalizeRecipeHandleId(edge.targetHandle);\n    const semanticSourceDefault =\n      source.kind === \"llm\"\n        ? getDefaultDataSourceHandle(layoutDirection)\n        : getDefaultSemanticSourceHandle(layoutDirection);\n    const semanticTargetDefault =\n      target.kind === \"llm\"\n        ? 
getDefaultDataTargetHandle(layoutDirection)\n        : getDefaultSemanticTargetHandle(layoutDirection);\n    let sourceHandle = getDefaultDataSourceHandle(layoutDirection);\n    let targetHandle = getDefaultDataTargetHandle(layoutDirection);\n\n    if (semantic) {\n      sourceHandle =\n        isSemanticSourceHandle(sourceHandleNormalized) ||\n        isDataSourceHandle(sourceHandleNormalized)\n          ? sourceHandleNormalized ?? semanticSourceDefault\n          : semanticSourceDefault;\n      targetHandle =\n        isSemanticTargetHandle(targetHandleNormalized) ||\n        isDataTargetHandle(targetHandleNormalized)\n          ? targetHandleNormalized ?? semanticTargetDefault\n          : semanticTargetDefault;\n    } else {\n      sourceHandle = isDataSourceHandle(sourceHandleNormalized)\n        ? sourceHandleNormalized ?? getDefaultDataSourceHandle(layoutDirection)\n        : getDefaultDataSourceHandle(layoutDirection);\n      targetHandle = isDataTargetHandle(targetHandleNormalized)\n        ? targetHandleNormalized ?? getDefaultDataTargetHandle(layoutDirection)\n        : getDefaultDataTargetHandle(layoutDirection);\n    }\n    return [\n      {\n        from: source.name,\n        to: target.name,\n        type: semantic ? \"semantic\" : \"canvas\",\n        source_handle: sourceHandle ?? undefined,\n        target_handle: targetHandle ?? undefined,\n      },\n    ];\n  });\n  const uiAuxNodes = Object.entries(auxNodePositions).flatMap(\n    ([auxId, position]) => {\n      const match = /^aux-([^-]+)-(.+)$/.exec(auxId);\n      if (!match) {\n        return [];\n      }\n      const [, llmId, key] = match;\n      const llmConfig = configs[llmId];\n      if (!(llmConfig && llmConfig.kind === \"llm\")) {\n        return [];\n      }\n      return [\n        {\n          llm: llmConfig.name,\n          key,\n          x: position.x,\n          y: position.y,\n        },\n      ];\n    },\n  );\n  const recipeProcessors = buildProcessors(processors, errors);\n  const seedConfig = firstSeed ? buildSeedConfig(firstSeed, errors) : undefined;\n  const seedDropProcessor = firstSeed\n    ? buildSeedDropProcessor(firstSeed, errors)\n    : null;\n  if (seedDropProcessor) {\n    recipeProcessors.push(seedDropProcessor);\n  }\n  const uiAdvancedOpenByNode = collectAdvancedOpenByNode(configs);\n\n  return {\n    errors,\n    payload: {\n      recipe: {\n        // biome-ignore lint/style/useNamingConvention: api schema\n        model_providers: modelProviders,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        mcp_providers: mcpProviders,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        model_configs: modelConfigs,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        seed_config: seedConfig,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        tool_configs: toolConfigs,\n        columns,\n        processors: recipeProcessors,\n      },\n      run: {\n        rows: 5,\n        preview: true,\n        // biome-ignore lint/style/useNamingConvention: api schema\n        output_formats: [\"jsonl\"],\n      },\n      ui: {\n        nodes: uiNodes,\n        edges: uiEdges,\n        layout_direction: layoutDirection,\n        ...(uiAuxNodes.length > 0 && { aux_nodes: uiAuxNodes }),\n        ...(firstSeed && { seed_source_type: firstSeed.seed_source_type }),\n        ...(firstSeed && { seed_columns: firstSeed.seed_columns ?? 
[] }),\n        ...(firstSeed && {\n          seed_drop_columns: firstSeed.seed_drop_columns ?? [],\n        }),\n        ...(firstSeed && {\n          seed_preview_rows: firstSeed.seed_preview_rows ?? [],\n        }),\n        ...(firstSeed &&\n          firstSeed.local_file_name !== undefined && {\n            local_file_name: firstSeed.local_file_name,\n          }),\n        ...(firstSeed &&\n          firstSeed.unstructured_file_name !== undefined && {\n            unstructured_file_name: firstSeed.unstructured_file_name,\n          }),\n        ...(firstSeed &&\n          firstSeed.unstructured_chunk_size !== undefined && {\n            unstructured_chunk_size: firstSeed.unstructured_chunk_size,\n          }),\n        ...(firstSeed &&\n          firstSeed.unstructured_chunk_overlap !== undefined && {\n            unstructured_chunk_overlap: firstSeed.unstructured_chunk_overlap,\n          }),\n        ...(Object.keys(uiAdvancedOpenByNode).length > 0 && {\n          // biome-ignore lint/style/useNamingConvention: ui schema\n          advanced_open_by_node: uiAdvancedOpenByNode,\n        }),\n      },\n    },\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-llm.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  LlmConfig,\n  LlmMcpProviderConfig,\n  LlmToolConfig,\n  ToolProfileConfig,\n} from \"../../types\";\n\nfunction buildImageContext(\n  config: LlmConfig,\n  errors: string[],\n): Array<Record<string, unknown>> | undefined {\n  const imageContext = config.image_context;\n  if (!imageContext?.enabled) {\n    return undefined;\n  }\n  const columnName = imageContext.column_name.trim();\n  if (!columnName) {\n    errors.push(`LLM ${config.name}: image context column is required.`);\n    return undefined;\n  }\n  return [\n    {\n      modality: \"image\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_name: columnName,\n    },\n  ];\n}\n\nexport function buildLlmColumn(\n  config: LlmConfig,\n  errors: string[],\n): Record<string, unknown> {\n  const toolAlias = config.tool_alias?.trim();\n  const base = {\n    name: config.name,\n    drop: config.drop ?? false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_alias: config.model_alias,\n    prompt: config.prompt,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    system_prompt: config.system_prompt || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    multi_modal_context: buildImageContext(config, errors),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_alias: toolAlias || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    with_trace: config.with_trace ?? \"none\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extract_reasoning_content: config.extract_reasoning_content === true,\n  };\n\n  if (config.llm_type === \"code\") {\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_type: \"llm-code\",\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      code_lang: config.code_lang || \"python\",\n    };\n  }\n  if (config.llm_type === \"structured\") {\n    let outputFormat: unknown = config.output_format || undefined;\n    if (typeof outputFormat === \"string\" && outputFormat.trim()) {\n      try {\n        outputFormat = JSON.parse(outputFormat);\n      } catch {\n        errors.push(`LLM ${config.name}: output_format is not valid JSON.`);\n      }\n    }\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_type: \"llm-structured\",\n      ...base,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      output_format: outputFormat,\n    };\n  }\n  if (config.llm_type === \"judge\") {\n    const scores = (config.scores ?? [])\n      .map((score) => {\n        const options: Record<string, string> = {};\n        for (const option of score.options ?? 
[]) {\n          const key = option.value.trim();\n          const value = option.description.trim();\n          if (!key || !value) {\n            continue;\n          }\n          options[key] = value;\n        }\n        return {\n          name: score.name.trim(),\n          description: score.description.trim(),\n          options,\n        };\n      })\n      .filter(\n        (score) =>\n          score.name && score.description && Object.keys(score.options).length > 0,\n      );\n    if (scores.length === 0) {\n      errors.push(`LLM ${config.name}: scores required for LLM Judge.`);\n    }\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_type: \"llm-judge\",\n      ...base,\n      scores,\n    };\n  }\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_type: \"llm-text\",\n    ...base,\n  };\n}\n\nexport function buildLlmMcpProvider(\n  provider: LlmMcpProviderConfig,\n  errors: string[],\n): Record<string, unknown> | null {\n  const name = provider.name.trim();\n  if (!name) {\n    errors.push(\"MCP provider: name is required.\");\n    return null;\n  }\n  if (provider.provider_type === \"stdio\") {\n    const command = provider.command?.trim() ?? \"\";\n    if (!command) {\n      errors.push(`MCP provider ${name}: command is required for stdio.`);\n      return null;\n    }\n    const env: Record<string, string> = {};\n    for (const item of provider.env ?? []) {\n      const key = item.key.trim();\n      const value = item.value.trim();\n      if (key && value) {\n        env[key] = value;\n      }\n    }\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      provider_type: \"stdio\",\n      name,\n      command,\n      args: (provider.args ?? []).map((value) => value.trim()).filter(Boolean),\n      env,\n    };\n  }\n  const endpoint = provider.endpoint?.trim() ?? \"\";\n  if (!endpoint) {\n    errors.push(`MCP provider ${name}: endpoint is required.`);\n    return null;\n  }\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    provider_type: \"streamable_http\",\n    name,\n    endpoint,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key: provider.api_key?.trim() || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key_env: provider.api_key_env?.trim() || undefined,\n  };\n}\n\nexport function buildLlmToolConfig(\n  config: LlmToolConfig,\n  errors: string[],\n): Record<string, unknown> | null {\n  const toolAlias = config.tool_alias.trim();\n  if (!toolAlias) {\n    errors.push(\"Tool config: tool_alias is required.\");\n    return null;\n  }\n  const providers = config.providers\n    .map((value) => value.trim())\n    .filter(Boolean);\n  if (providers.length === 0) {\n    errors.push(`Tool config ${toolAlias}: at least one provider is required.`);\n    return null;\n  }\n  const allowTools = (config.allow_tools ?? [])\n    .map((value) => value.trim())\n    .filter(Boolean);\n  const maxToolCallTurnsRaw = config.max_tool_call_turns?.trim();\n  const maxToolCallTurns =\n    maxToolCallTurnsRaw && Number.isFinite(Number(maxToolCallTurnsRaw))\n      ? Number(maxToolCallTurnsRaw)\n      : 5;\n  const timeoutRaw = config.timeout_sec?.trim();\n  const timeoutSec =\n    timeoutRaw && Number.isFinite(Number(timeoutRaw))\n      ? 
Number(timeoutRaw)\n      : undefined;\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_alias: toolAlias,\n    providers,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    allow_tools: allowTools.length > 0 ? allowTools : undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    max_tool_call_turns: maxToolCallTurns,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    timeout_sec: timeoutSec,\n  };\n}\n\nexport function buildToolProfilePayload(\n  config: ToolProfileConfig,\n  errors: string[],\n): {\n  // biome-ignore lint/style/useNamingConvention: api schema\n  mcp_providers: Record<string, unknown>[];\n  // biome-ignore lint/style/useNamingConvention: api schema\n  tool_config: Record<string, unknown> | null;\n} {\n  const mcpProviders = config.mcp_providers\n    .map((provider) => buildLlmMcpProvider(provider, errors))\n    .flatMap((provider) => (provider ? [provider] : []));\n  const toolConfig = buildLlmToolConfig(\n    {\n      id: config.id,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      tool_alias: config.name,\n      providers: mcpProviders\n        .map((provider) => String(provider.name ?? \"\").trim())\n        .filter(Boolean),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      allow_tools: config.allow_tools,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      max_tool_call_turns: config.max_tool_call_turns,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      timeout_sec: config.timeout_sec,\n    },\n    errors,\n  );\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    mcp_providers: mcpProviders,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_config: toolConfig,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-model.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ModelConfig, ModelProviderConfig } from \"../../types\";\nimport { parseJsonObject } from \"./parse\";\n\nexport function buildModelProvider(\n  config: ModelProviderConfig,\n  errors: string[],\n): Record<string, unknown> {\n  const extraHeaders = parseJsonObject(\n    config.extra_headers,\n    `Provider ${config.name} extra_headers`,\n    errors,\n  );\n  const extraBody = parseJsonObject(\n    config.extra_body,\n    `Provider ${config.name} extra_body`,\n    errors,\n  );\n  return {\n    name: config.name,\n    endpoint: config.endpoint,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    provider_type: \"openai\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key_env: config.api_key_env?.trim() || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    api_key: config.api_key?.trim() || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_headers: extraHeaders ?? {},\n    // biome-ignore lint/style/useNamingConvention: api schema\n    extra_body: extraBody ?? {},\n  };\n}\n\nexport function buildModelConfig(\n  config: ModelConfig,\n  errors: string[],\n): Record<string, unknown> {\n  const inference: Record<string, unknown> = {};\n  const temp = config.inference_temperature?.trim();\n  const topP = config.inference_top_p?.trim();\n  const maxTokens = config.inference_max_tokens?.trim();\n  const timeout = config.inference_timeout?.trim();\n  const extraBody = parseJsonObject(\n    config.inference_extra_body,\n    `Model ${config.name} inference extra_body`,\n    errors,\n  );\n\n  if (temp) {\n    const parsed = Number(temp);\n    if (Number.isFinite(parsed)) {\n      inference.temperature = parsed;\n    }\n  }\n  if (topP) {\n    const parsed = Number(topP);\n    if (Number.isFinite(parsed)) {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      inference.top_p = parsed;\n    }\n  }\n  if (maxTokens) {\n    const parsed = Number(maxTokens);\n    if (Number.isFinite(parsed)) {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      inference.max_tokens = parsed;\n    }\n  }\n  if (timeout) {\n    const parsed = Number(timeout);\n    if (Number.isFinite(parsed)) {\n      inference.timeout = Math.trunc(parsed);\n    }\n  }\n  if (extraBody) {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference.extra_body = extraBody;\n  }\n\n  return {\n    alias: config.name,\n    model: config.model,\n    provider: config.provider || undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    inference_parameters:\n      Object.keys(inference).length > 0 ? inference : undefined,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    skip_health_check: config.skip_health_check || undefined,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-processors.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ExpressionConfig, RecipeProcessorConfig } from \"../../types\";\nimport { parseJsonObject } from \"./parse\";\n\nexport function buildExpressionColumn(\n  config: ExpressionConfig,\n  errors: string[],\n): Record<string, unknown> {\n  if (!config.expr.trim()) {\n    errors.push(`Expression ${config.name}: expr required.`);\n  }\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_type: \"expression\",\n    name: config.name,\n    drop: config.drop ?? false,\n    expr: config.expr,\n    dtype: config.dtype,\n  };\n}\n\nexport function buildProcessors(\n  processors: RecipeProcessorConfig[],\n  errors: string[],\n): Record<string, unknown>[] {\n  const output: Record<string, unknown>[] = [];\n  for (const processor of processors) {\n    if (processor.processor_type !== \"schema_transform\") {\n      continue;\n    }\n    const name = processor.name.trim();\n    if (!name) {\n      errors.push(\"Schema transform: name is required.\");\n      continue;\n    }\n    const template = parseJsonObject(\n      processor.template,\n      `Schema transform ${name} template`,\n      errors,\n    );\n    if (!template) {\n      continue;\n    }\n    output.push({\n      // biome-ignore lint/style/useNamingConvention: api schema\n      processor_type: \"schema_transform\",\n      name,\n      template,\n    });\n  }\n  return output;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-sampler.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { CategoryConditionalParams, SamplerConfig } from \"../../types\";\nimport { isValidSex, parseAgeRange, parseNumber } from \"./parse\";\n\nconst DATETIME_UNIT_MAP: Record<string, \"Y\" | \"M\" | \"D\" | \"h\" | \"m\" | \"s\"> = {\n  year: \"Y\",\n  month: \"M\",\n  day: \"D\",\n  hour: \"h\",\n  minute: \"m\",\n  second: \"s\",\n};\n\nfunction buildCategoryConditionalParams(\n  config: SamplerConfig,\n  errors: string[],\n): Record<string, CategoryConditionalParams> | undefined {\n  const conditional = config.conditional_params ?? {};\n  const output: Record<string, CategoryConditionalParams> = {};\n  for (const [rawCondition, params] of Object.entries(conditional)) {\n    const condition = rawCondition.trim();\n    if (!condition) {\n      errors.push(`Sampler ${config.name}: conditional rule needs condition text.`);\n      continue;\n    }\n    const values = (params.values ?? [])\n      .map((value) => value.trim())\n      .filter(Boolean);\n    if (values.length === 0) {\n      errors.push(`Sampler ${config.name}: conditional '${condition}' needs values.`);\n      continue;\n    }\n    const weights = params.weights ?? [];\n    const hasWeights = weights.some((weight) => weight !== null);\n    if (\n      hasWeights &&\n      (weights.length !== values.length || weights.some((weight) => weight === null))\n    ) {\n      errors.push(`Sampler ${config.name}: conditional '${condition}' weights invalid.`);\n      continue;\n    }\n    output[condition] = {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      sampler_type: \"category\",\n      values,\n      weights: hasWeights\n        ? weights.filter((weight): weight is number => weight !== null)\n        : undefined,\n    };\n  }\n  return Object.keys(output).length > 0 ? output : undefined;\n}\n\n// biome-ignore lint/complexity/noExcessiveCognitiveComplexity: per type logic\nfunction buildSamplerParams(\n  config: SamplerConfig,\n  errors: string[],\n): Record<string, unknown> {\n  if (config.sampler_type === \"category\") {\n    const values = config.values ?? [];\n    const params: Record<string, unknown> = { values };\n    const weights = config.weights ?? [];\n    const hasWeights = weights.some((weight) => weight !== null);\n    if (hasWeights && weights.some((weight) => weight === null)) {\n      errors.push(`Sampler ${config.name}: weights missing values.`);\n    } else if (hasWeights) {\n      params.weights = weights.filter((weight) => weight !== null);\n    }\n    return params;\n  }\n  if (config.sampler_type === \"subcategory\") {\n    const mapping = config.subcategory_mapping ?? 
{};\n    for (const [key, values] of Object.entries(mapping)) {\n      if (!values || values.length === 0) {\n        errors.push(\n          `Subcategory ${config.name}: '${key}' needs at least 1 subcategory.`,\n        );\n      }\n    }\n    return {\n      category: config.subcategory_parent,\n      values: mapping,\n    };\n  }\n  if (config.sampler_type === \"uniform\") {\n    return {\n      low: parseNumber(config.low),\n      high: parseNumber(config.high),\n    };\n  }\n  if (config.sampler_type === \"gaussian\") {\n    return {\n      mean: parseNumber(config.mean),\n      // data_designer expects `stddev`\n      stddev: parseNumber(config.std),\n    };\n  }\n  if (config.sampler_type === \"bernoulli\") {\n    return {\n      p: parseNumber(config.p),\n    };\n  }\n  if (config.sampler_type === \"datetime\") {\n    const rawUnit = config.datetime_unit?.trim();\n    let unit: string | undefined = rawUnit || undefined;\n    if (rawUnit && DATETIME_UNIT_MAP[rawUnit]) {\n      unit = DATETIME_UNIT_MAP[rawUnit];\n    }\n    if (rawUnit === \"week\") {\n      errors.push(`Datetime ${config.name}: unit 'week' not supported.`);\n      unit = undefined;\n    }\n    return {\n      start: config.datetime_start ?? undefined,\n      end: config.datetime_end ?? undefined,\n      unit,\n    };\n  }\n  if (config.sampler_type === \"timedelta\") {\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_min: parseNumber(config.dt_min),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      dt_max: parseNumber(config.dt_max),\n      // biome-ignore lint/style/useNamingConvention: api schema\n      reference_column_name: config.reference_column_name || undefined,\n      unit: config.timedelta_unit || undefined,\n    };\n  }\n  if (config.sampler_type === \"uuid\") {\n    const raw = config.uuid_format?.trim();\n    if (!raw) {\n      return {};\n    }\n    // UI historically used \"uuid4\" as a \"format\". 
data_designer uuid sampler is always uuid4.\n    if (raw.toLowerCase() === \"uuid4\") {\n      return {};\n    }\n    if (raw.toLowerCase() === \"short\" || raw.toLowerCase() === \"short_form\") {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      return { short_form: true };\n    }\n    if (raw.toLowerCase() === \"upper\" || raw.toLowerCase() === \"uppercase\") {\n      return { uppercase: true };\n    }\n    if (raw.toLowerCase().startsWith(\"prefix:\")) {\n      return { prefix: raw.slice(\"prefix:\".length).trim() || undefined };\n    }\n    return {\n      prefix: raw,\n    };\n  }\n  const params: Record<string, unknown> = {};\n  if (config.person_locale?.trim()) {\n    params.locale = config.person_locale.trim();\n  }\n  if (config.person_sex?.trim()) {\n    if (isValidSex(config.person_sex.trim())) {\n      params.sex = config.person_sex.trim();\n    } else {\n      errors.push(`Person ${config.name}: sex must be Male or Female.`);\n    }\n  }\n  if (config.person_city?.trim()) {\n    params.city = config.person_city.trim();\n  }\n  if (config.person_age_range?.trim()) {\n    const parsed = parseAgeRange(config.person_age_range);\n    if (parsed) {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      params.age_range = parsed;\n    } else {\n      errors.push(`Person ${config.name}: age range must be like 18-70.`);\n    }\n  }\n  return params;\n}\n\nexport function buildSamplerColumn(\n  config: SamplerConfig,\n  errors: string[],\n): Record<string, unknown> {\n  const samplerColumn: Record<string, unknown> = {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_type: \"sampler\",\n    name: config.name,\n    drop: config.drop ?? false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    sampler_type: config.sampler_type,\n    params: buildSamplerParams(config, errors),\n    // biome-ignore lint/style/useNamingConvention: api schema\n    convert_to: config.convert_to ?? undefined,\n  };\n  if (config.sampler_type === \"category\") {\n    const conditionalParams = buildCategoryConditionalParams(config, errors);\n    if (conditionalParams) {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      samplerColumn.conditional_params = conditionalParams;\n    }\n  }\n  return samplerColumn;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-seed.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig, SeedConfig } from \"../../types\";\n\nconst DEFAULT_CHUNK_SIZE = 1200;\nconst DEFAULT_CHUNK_OVERLAP = 200;\nconst MAX_CHUNK_SIZE = 20000;\n\nfunction parseIntStrict(value: string | undefined): number | null {\n  const trimmed = value?.trim();\n  if (!trimmed) return null;\n  const num = Number(value);\n  if (!Number.isFinite(num) || !Number.isInteger(num)) return null;\n  return num;\n}\n\nfunction resolveChunking(config: SeedConfig): { chunkSize: number; chunkOverlap: number } {\n  const rawSize = parseIntStrict(config.unstructured_chunk_size);\n  const rawOverlap = parseIntStrict(config.unstructured_chunk_overlap);\n  const chunkSize = Math.min(MAX_CHUNK_SIZE, Math.max(1, rawSize ?? DEFAULT_CHUNK_SIZE));\n  const chunkOverlap = Math.min(\n    Math.max(0, chunkSize - 1),\n    Math.max(0, rawOverlap ?? DEFAULT_CHUNK_OVERLAP),\n  );\n  return { chunkSize, chunkOverlap };\n}\n\nexport function buildSeedConfig(\n  config: SeedConfig,\n  errors: string[],\n): Record<string, unknown> | undefined {\n  const seedSourceType = config.seed_source_type ?? \"hf\";\n  const path = config.hf_path.trim();\n\n  const endpoint = config.hf_endpoint?.trim() || \"https://huggingface.co\";\n  const token = config.hf_token?.trim() || null;\n\n  let selectionStrategy: Record<string, unknown> | null = null;\n  if (config.selection_type === \"index_range\") {\n    const start = parseIntStrict(config.selection_start);\n    const end = parseIntStrict(config.selection_end);\n    if (start === null || end === null) {\n      errors.push(`Seed ${config.name}: selection index range invalid.`);\n      return undefined;\n    }\n    selectionStrategy = { start, end };\n  } else if (config.selection_type === \"partition_block\") {\n    const index = parseIntStrict(config.selection_index);\n    const numPartitions = parseIntStrict(config.selection_num_partitions);\n    if (index === null || numPartitions === null) {\n      errors.push(`Seed ${config.name}: selection partition invalid.`);\n      return undefined;\n    }\n    // biome-ignore lint/style/useNamingConvention: api schema\n    selectionStrategy = { index, num_partitions: numPartitions };\n  }\n\n  const source =\n    seedSourceType === \"hf\"\n      ? {\n          // biome-ignore lint/style/useNamingConvention: api schema\n          seed_type: \"hf\",\n          path,\n          token,\n          endpoint,\n        }\n      : seedSourceType === \"unstructured\"\n        ? 
(() => {\n            const { chunkSize, chunkOverlap } = resolveChunking(config);\n            return {\n              // biome-ignore lint/style/useNamingConvention: api schema\n              seed_type: \"unstructured\",\n              path,\n              // biome-ignore lint/style/useNamingConvention: api schema\n              chunk_size: chunkSize,\n              // biome-ignore lint/style/useNamingConvention: api schema\n              chunk_overlap: chunkOverlap,\n            };\n          })()\n        : {\n            // biome-ignore lint/style/useNamingConvention: api schema\n            seed_type: \"local\",\n            path,\n          };\n\n  return {\n    source,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    sampling_strategy: config.sampling_strategy,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    selection_strategy: selectionStrategy,\n  };\n}\n\nexport function pickFirstSeedConfig(\n  configs: Record<string, NodeConfig>,\n): SeedConfig | null {\n  for (const config of Object.values(configs)) {\n    if (config.kind === \"seed\") {\n      return config;\n    }\n  }\n  return null;\n}\n\nexport function buildSeedDropProcessor(\n  config: SeedConfig,\n  errors: string[],\n): Record<string, unknown> | null {\n  const seedSourceType = config.seed_source_type ?? \"hf\";\n  const loadedCols = (config.seed_columns ?? []).map((c) => c.trim()).filter(Boolean);\n  let cols: string[] = [];\n\n  if (seedSourceType === \"unstructured\") {\n    if (!config.drop) {\n      return null;\n    }\n    cols = loadedCols;\n  } else {\n    const selectedDropColumns = (config.seed_drop_columns ?? [])\n      .map((c) => c.trim())\n      .filter(Boolean);\n    if (selectedDropColumns.length === 0) {\n      return null;\n    }\n    const loadedSet = new Set(loadedCols);\n    cols =\n      loadedCols.length > 0\n        ? selectedDropColumns.filter((col) => loadedSet.has(col))\n        : selectedDropColumns;\n  }\n\n  if (cols.length === 0) {\n    errors.push(\n      `Seed ${config.name}: selected drop columns are unavailable.`,\n    );\n    return null;\n  }\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    processor_type: \"drop_columns\",\n    name: \"drop_seed_columns\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_names: cols,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders-validator.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig, ValidatorConfig } from \"../../types\";\nimport { isValidatorCodeLang } from \"../validators/code-lang\";\n\nconst OXC_VALIDATION_FN_MARKER = \"unsloth_oxc_validator\";\n\nfunction parseBatchSize(value: string): number {\n  const parsed = Number.parseInt(value, 10);\n  if (!Number.isFinite(parsed) || parsed < 1) {\n    return 10;\n  }\n  return parsed;\n}\n\nexport function buildValidatorColumn(\n  config: ValidatorConfig,\n  errors: string[],\n  nameToConfig?: Map<string, NodeConfig>,\n): Record<string, unknown> {\n  const targetColumns = (config.target_columns ?? [])\n    .map((value) => value.trim())\n    .filter(Boolean);\n  if (targetColumns.length === 0) {\n    errors.push(`Validator ${config.name}: target code column required.`);\n  }\n  if (config.validator_type === \"oxc\") {\n    const targetName = targetColumns[0] ?? \"\";\n    const targetConfig = targetName ? nameToConfig?.get(targetName) : null;\n    let codeLang = config.code_lang;\n    if (\n      targetConfig &&\n      targetConfig.kind === \"llm\" &&\n      targetConfig.llm_type === \"code\"\n    ) {\n      const targetLang = (targetConfig.code_lang ?? \"\").trim();\n      if (isValidatorCodeLang(targetLang)) {\n        codeLang = targetLang;\n      }\n    }\n    return {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      column_type: \"validation\",\n      name: config.name,\n      drop: config.drop ?? false,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      target_columns: targetColumns,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      validator_type: \"local_callable\",\n      // biome-ignore lint/style/useNamingConvention: api schema\n      validator_params: {\n        // backend resolves this marker to a real callable.\n        // biome-ignore lint/style/useNamingConvention: api schema\n        validation_function: `${OXC_VALIDATION_FN_MARKER}:${codeLang}:${config.oxc_validation_mode}:${config.oxc_code_shape ?? \"auto\"}`,\n      },\n      // biome-ignore lint/style/useNamingConvention: api schema\n      batch_size: parseBatchSize(config.batch_size),\n    };\n  }\n\n  return {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    column_type: \"validation\",\n    name: config.name,\n    drop: config.drop ?? false,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    target_columns: targetColumns,\n    // biome-ignore lint/style/useNamingConvention: api schema\n    validator_type: \"code\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    validator_params: {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      code_lang: config.code_lang,\n    },\n    // biome-ignore lint/style/useNamingConvention: api schema\n    batch_size: parseBatchSize(config.batch_size),\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/builders.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport {\n  buildLlmColumn,\n  buildLlmMcpProvider,\n  buildLlmToolConfig,\n  buildToolProfilePayload,\n} from \"./builders-llm\";\nexport { buildModelConfig, buildModelProvider } from \"./builders-model\";\nexport { buildExpressionColumn, buildProcessors } from \"./builders-processors\";\nexport { buildSamplerColumn } from \"./builders-sampler\";\nexport { buildValidatorColumn } from \"./builders-validator\";\nexport {\n  buildSeedConfig,\n  buildSeedDropProcessor,\n  pickFirstSeedConfig,\n} from \"./builders-seed\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/empty.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipePayload } from \"./types\";\n\nexport function createEmptyRecipePayload(): RecipePayload {\n  return {\n    recipe: {\n      // biome-ignore lint/style/useNamingConvention: api schema\n      model_providers: [],\n      // biome-ignore lint/style/useNamingConvention: api schema\n      mcp_providers: [],\n      // biome-ignore lint/style/useNamingConvention: api schema\n      model_configs: [],\n      // biome-ignore lint/style/useNamingConvention: api schema\n      tool_configs: [],\n      columns: [],\n      processors: [],\n    },\n    run: {\n      rows: 5,\n      preview: true,\n      // biome-ignore lint/style/useNamingConvention: api schema\n      output_formats: [\"jsonl\"],\n    },\n    ui: {\n      nodes: [],\n      edges: [],\n      layout_direction: \"LR\",\n    },\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { buildRecipePayload } from \"./build-payload\";\nexport { createEmptyRecipePayload } from \"./empty\";\nexport type { RecipePayload, RecipePayloadResult } from \"./types\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/parse.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport {\n  isValidSex,\n  parseAgeRange,\n  parseJsonObject,\n  parseNumber,\n} from \"../parse\";\n\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type RecipePayload = {\n  recipe: {\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_providers: Record<string, unknown>[];\n    // biome-ignore lint/style/useNamingConvention: api schema\n    mcp_providers: Record<string, unknown>[];\n    // biome-ignore lint/style/useNamingConvention: api schema\n    model_configs: Record<string, unknown>[];\n    // biome-ignore lint/style/useNamingConvention: api schema\n    seed_config?: Record<string, unknown>;\n    // biome-ignore lint/style/useNamingConvention: api schema\n    tool_configs: Record<string, unknown>[];\n    columns: Record<string, unknown>[];\n    processors: Record<string, unknown>[];\n  };\n  run: {\n    rows: number;\n    preview: boolean;\n    // biome-ignore lint/style/useNamingConvention: api schema\n    output_formats: string[];\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    execution_type?: \"preview\" | \"full\";\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    run_config?: Record<string, unknown>;\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    dataset_name?: string;\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    artifact_path?: string;\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    merge_batches?: boolean;\n    // biome-ignore lint/style/useNamingConvention: backend schema\n    run_name?: string | null;\n  };\n  ui: {\n    nodes: Array<{\n      id: string;\n      x: number;\n      y: number;\n      width?: number;\n      node_type?: \"markdown_note\" | \"tool_config\";\n      name?: string;\n      markdown?: string;\n      note_color?: string;\n      note_opacity?: string;\n      tools_by_provider?: Record<string, string[]>;\n    }>;\n    edges: {\n      from: string;\n      to: string;\n      type?: string;\n      source_handle?: string;\n      target_handle?: string;\n    }[];\n    // ui-only: graph orientation\n    layout_direction?: \"LR\" | \"TB\";\n    // ui-only, used to preserve seed block mode across imports/refresh\n    seed_source_type?: \"hf\" | \"local\" | \"unstructured\";\n    // ui-only, persisted aux node positions by llm name + aux key\n    aux_nodes?: Array<{\n      llm: string;\n      key: string;\n      x: number;\n      y: number;\n    }>;\n    // ui-only, seed metadata cached for refresh/import UX\n    seed_columns?: string[];\n    seed_drop_columns?: string[];\n    seed_preview_rows?: Record<string, unknown>[];\n    local_file_name?: string;\n    unstructured_file_name?: string;\n    unstructured_chunk_size?: string;\n    unstructured_chunk_overlap?: string;\n    // ui-only: per-node advanced accordion state\n    advanced_open_by_node?: Record<string, boolean>;\n  };\n};\n\nexport type RecipePayloadResult = {\n  errors: string[];\n  payload: RecipePayload;\n};\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/payload/validate.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  ModelConfig,\n  ModelProviderConfig,\n  NodeConfig,\n  ValidatorCodeLang,\n  ValidatorConfig,\n} from \"../../types\";\nimport { VALIDATOR_OXC_CODE_LANGS } from \"../validators/code-lang\";\nimport { isOxcCodeShape } from \"../validators/oxc-code-shape\";\nimport { isOxcValidationMode } from \"../validators/oxc-mode\";\n\nexport function validateSubcategoryConfigs(\n  configs: Record<string, NodeConfig>,\n  nameToConfig: Map<string, NodeConfig>,\n  errors: string[],\n): void {\n  for (const config of Object.values(configs)) {\n    if (config.kind !== \"sampler\" || config.sampler_type !== \"subcategory\") {\n      continue;\n    }\n    const parentName = config.subcategory_parent;\n    if (!parentName) {\n      errors.push(`Subcategory ${config.name}: parent category required.`);\n      continue;\n    }\n    const parent = nameToConfig.get(parentName);\n    const parentValues =\n      parent && parent.kind === \"sampler\" && parent.sampler_type === \"category\"\n        ? (parent.values ?? [])\n        : [];\n    const mapping = config.subcategory_mapping ?? {};\n    for (const value of parentValues) {\n      const list = mapping[value];\n      if (!list || list.length === 0) {\n        errors.push(\n          `Subcategory ${config.name}: '${value}' needs at least 1 subcategory.`,\n        );\n      }\n    }\n  }\n}\n\nexport function validateTimedeltaConfigs(\n  configs: Record<string, NodeConfig>,\n  nameToConfig: Map<string, NodeConfig>,\n  errors: string[],\n): void {\n  for (const config of Object.values(configs)) {\n    if (config.kind !== \"sampler\" || config.sampler_type !== \"timedelta\") {\n      continue;\n    }\n    const reference = config.reference_column_name?.trim() ?? 
\"\";\n    if (!reference) {\n      errors.push(`Timedelta ${config.name}: reference datetime column required.`);\n      continue;\n    }\n    const parent = nameToConfig.get(reference);\n    if (\n      !parent ||\n      parent.kind !== \"sampler\" ||\n      parent.sampler_type !== \"datetime\"\n    ) {\n      errors.push(`Timedelta ${config.name}: reference '${reference}' must be datetime.`);\n    }\n  }\n}\n\nexport function validateModelAliasLinks(\n  modelAliases: Set<string>,\n  modelConfigConfigs: ModelConfig[],\n  errors: string[],\n): void {\n  for (const alias of modelAliases) {\n    if (!modelConfigConfigs.some((config) => config.name === alias)) {\n      errors.push(`LLM model_alias ${alias}: missing model config.`);\n    }\n  }\n}\n\nexport function validateModelConfigProviders(\n  modelConfigConfigs: ModelConfig[],\n  modelAliases: Set<string>,\n  modelProviderNames: Set<string>,\n  errors: string[],\n): void {\n  for (const config of modelConfigConfigs) {\n    const provider = config.provider.trim();\n    const alias = config.name;\n    if (modelAliases.has(alias) && !config.model.trim()) {\n      errors.push(`Model config ${alias}: model is required.`);\n    }\n    if (provider && !modelProviderNames.has(provider)) {\n      errors.push(`Model config ${alias}: provider ${provider} not found.`);\n    }\n  }\n}\n\nexport function validateUsedProviders(\n  modelProviderConfigs: ModelProviderConfig[],\n  modelConfigConfigs: ModelConfig[],\n  errors: string[],\n): void {\n  const usedProviders = new Set(\n    modelConfigConfigs.map((config) => config.provider.trim()).filter(Boolean),\n  );\n  for (const provider of modelProviderConfigs) {\n    if (!usedProviders.has(provider.name)) {\n      continue;\n    }\n    if (!provider.endpoint.trim()) {\n      errors.push(`Model provider ${provider.name}: endpoint is required.`);\n    }\n    if (!provider.provider_type.trim()) {\n      errors.push(`Model provider ${provider.name}: provider_type is required.`);\n    }\n  }\n}\n\nexport function validateValidatorConfigs(\n  configs: Record<string, NodeConfig>,\n  nameToConfig: Map<string, NodeConfig>,\n  errors: string[],\n): void {\n  for (const config of Object.values(configs)) {\n    if (config.kind !== \"validator\") {\n      continue;\n    }\n    const target = (config as ValidatorConfig).target_columns[0]?.trim();\n    if (!target) {\n      continue;\n    }\n    const targetConfig = nameToConfig.get(target);\n    if (!targetConfig) {\n      errors.push(`Validator ${config.name}: target '${target}' not found.`);\n      continue;\n    }\n    if (targetConfig.kind !== \"llm\" || targetConfig.llm_type !== \"code\") {\n      errors.push(`Validator ${config.name}: target '${target}' must be LLM Code.`);\n      continue;\n    }\n    if (\n      config.validator_type === \"oxc\" &&\n      !VALIDATOR_OXC_CODE_LANGS.includes(\n        (targetConfig.code_lang ?? 
\"\").trim() as ValidatorCodeLang,\n      )\n    ) {\n      errors.push(\n        `Validator ${config.name}: target '${target}' must use javascript/typescript/jsx/tsx.`,\n      );\n      continue;\n    }\n    if (\n      config.validator_type === \"oxc\" &&\n      !isOxcValidationMode(config.oxc_validation_mode)\n    ) {\n      errors.push(\n        `Validator ${config.name}: oxc_validation_mode '${config.oxc_validation_mode}' is invalid.`,\n      );\n      continue;\n    }\n    if (\n      config.validator_type === \"oxc\" &&\n      !isOxcCodeShape(config.oxc_code_shape)\n    ) {\n      errors.push(\n        `Validator ${config.name}: oxc_code_shape '${config.oxc_code_shape}' is invalid.`,\n      );\n      continue;\n    }\n    if (\n      config.validator_type !== \"oxc\" &&\n      (targetConfig.code_lang ?? \"\").trim() !== config.code_lang.trim()\n    ) {\n      errors.push(\n        `Validator ${config.name}: code_lang '${config.code_lang}' must match target '${target}' (${targetConfig.code_lang ?? \"unknown\"}).`,\n      );\n    }\n  }\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/processors.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { RecipeProcessorConfig } from \"../types\";\n\nexport function buildDefaultSchemaTransform(): RecipeProcessorConfig {\n  return {\n    id: \"schema-transform-1\",\n    // biome-ignore lint/style/useNamingConvention: api schema\n    processor_type: \"schema_transform\",\n    name: \"schema_transform\",\n    template: '{\\n  \"text\": \"{{ column_name }}\"\\n}',\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/reactflow-changes.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  Edge,\n  EdgeChange,\n  Node,\n  NodeChange,\n  XYPosition,\n} from \"@xyflow/react\";\n\nexport function applyAuxNodeChanges<T extends Node>(\n  changes: NodeChange<T>[],\n  actions: {\n    setAuxNodePosition: (id: string, position: XYPosition) => void;\n  },\n): void {\n  for (const change of changes) {\n    if (!(\"id\" in change) || !change.id.startsWith(\"aux-\")) {\n      continue;\n    }\n    if (change.type !== \"position\") {\n      continue;\n    }\n    const nextPosition = change.position ?? change.positionAbsolute;\n    if (!nextPosition) {\n      continue;\n    }\n    actions.setAuxNodePosition(change.id, nextPosition);\n  }\n}\n\nexport function filterNodeChangesByIds<T extends Node>(\n  changes: NodeChange<T>[],\n  ids: Set<string>,\n): NodeChange<T>[] {\n  return changes.filter(\n    (change): change is NodeChange<T> => \"id\" in change && ids.has(change.id),\n  );\n}\n\nexport function filterEdgeChangesByIds(\n  changes: EdgeChange<Edge>[],\n  ids: Set<string>,\n): EdgeChange<Edge>[] {\n  return changes.filter(\n    (change): change is EdgeChange<Edge> => \"id\" in change && ids.has(change.id),\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/recipe-studio-view.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig, SamplerConfig } from \"../types\";\n\nexport type DialogOptions = {\n  categoryOptions: SamplerConfig[];\n  modelConfigAliases: string[];\n  modelProviderOptions: string[];\n  toolProfileAliases: string[];\n  datetimeOptions: string[];\n};\n\nexport function buildDialogOptions(configList: NodeConfig[]): DialogOptions {\n  const categoryOptions: SamplerConfig[] = [];\n  const modelConfigAliases: string[] = [];\n  const modelProviderOptions: string[] = [];\n  const toolProfileAliases: string[] = [];\n  const datetimeOptions: string[] = [];\n\n  for (const config of configList) {\n    if (config.kind === \"sampler\") {\n      if (config.sampler_type === \"category\") {\n        categoryOptions.push(config);\n      }\n      if (config.sampler_type === \"datetime\") {\n        datetimeOptions.push(config.name);\n      }\n      continue;\n    }\n    if (config.kind === \"model_config\") {\n      modelConfigAliases.push(config.name);\n      continue;\n    }\n    if (config.kind === \"model_provider\") {\n      modelProviderOptions.push(config.name);\n      continue;\n    }\n    if (config.kind === \"tool_config\") {\n      toolProfileAliases.push(config.name);\n    }\n  }\n\n  return {\n    categoryOptions,\n    modelConfigAliases,\n    modelProviderOptions,\n    toolProfileAliases,\n    datetimeOptions,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/refs.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nconst JINJA_REF_RE = /{{\\s*([a-zA-Z_][a-zA-Z0-9_]*)\\s*}}/g;\nconst JINJA_EXPR_RE = /{{\\s*([^{}]+?)\\s*}}/g;\nconst SIMPLE_JINJA_EXPR_RE = /^[a-zA-Z_][a-zA-Z0-9_.]*$/;\nconst PLAIN_JINJA_EXPR_RE = /^[a-zA-Z0-9_.\\s-]+$/;\nconst NESTED_REFERENCE_ROOTS = new Set([\"user\"]);\n\nfunction isValidNestedReference(expr: string, validSet: Set<string>): boolean {\n  if (!expr.includes(\".\")) {\n    return false;\n  }\n  const parts = expr.split(\".\").map((part) => part.trim()).filter(Boolean);\n  if (parts.length < 2) {\n    return false;\n  }\n  const root = parts[0];\n  return validSet.has(root) && NESTED_REFERENCE_ROOTS.has(root);\n}\n\nfunction escapeRegExp(value: string): string {\n  return value.replace(/[.*+?^${}()|[\\]\\\\]/g, \"\\\\$&\");\n}\n\nexport function extractRefs(template: string): string[] {\n  if (!template) {\n    return [];\n  }\n  const refs = new Set<string>();\n  for (const match of template.matchAll(JINJA_REF_RE)) {\n    if (match[1]) {\n      refs.add(match[1]);\n    }\n  }\n  return Array.from(refs);\n}\n\nexport function findInvalidJinjaReferences(\n  template: string,\n  validReferences: string[],\n): string[] {\n  if (!template) {\n    return [];\n  }\n  const validSet = new Set(\n    validReferences.map((name) => name.trim()).filter(Boolean),\n  );\n  const invalid = new Set<string>();\n\n  for (const match of template.matchAll(JINJA_EXPR_RE)) {\n    const expr = (match[1] ?? \"\").trim();\n    if (!expr) {\n      continue;\n    }\n    if (SIMPLE_JINJA_EXPR_RE.test(expr)) {\n      if (!validSet.has(expr) && !isValidNestedReference(expr, validSet)) {\n        invalid.add(expr);\n      }\n      continue;\n    }\n    if (PLAIN_JINJA_EXPR_RE.test(expr)) {\n      invalid.add(expr);\n    }\n  }\n\n  return Array.from(invalid);\n}\n\nexport function replaceRef(\n  template: string,\n  from: string,\n  to: string,\n): string {\n  if (!template || from === to) {\n    return template;\n  }\n  const pattern = new RegExp(`{{\\\\s*${escapeRegExp(from)}\\\\s*}}`, \"g\");\n  return template.replace(pattern, `{{ ${to} }}`);\n}\n\nexport function removeRef(template: string, ref: string): string {\n  if (!template) {\n    return template;\n  }\n  const escaped = escapeRegExp(ref);\n  const fullLine = new RegExp(`^\\\\s*{{\\\\s*${escaped}\\\\s*}}\\\\s*$`);\n  const inline = new RegExp(`{{\\\\s*${escaped}\\\\s*}}`, \"g\");\n  const next = template\n    .split(\"\\n\")\n    .flatMap((line) => {\n      if (fullLine.test(line)) {\n        return [];\n      }\n      return [line.replace(inline, \"\").replace(/\\s+$/g, \"\")];\n    })\n    .join(\"\\n\");\n  return next;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/rf-node-dimensions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { Node } from \"@xyflow/react\";\n\nfunction parseDim(value: unknown): number | null {\n  if (typeof value === \"number\" && Number.isFinite(value)) {\n    return value;\n  }\n  if (typeof value === \"string\") {\n    const parsed = Number.parseFloat(value);\n    return Number.isFinite(parsed) ? parsed : null;\n  }\n  return null;\n}\n\nexport function readNodeWidth(node: Node): number | null {\n  return (\n    parseDim(node.width) ??\n    parseDim(node.style?.width) ??\n    parseDim(node.measured?.width) ??\n    null\n  );\n}\n\nexport function readNodeHeight(node: Node): number | null {\n  return (\n    parseDim(node.height) ??\n    parseDim(node.style?.height) ??\n    parseDim(node.measured?.height) ??\n    null\n  );\n}\n\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/ui-tones.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport const RECIPE_STUDIO_NODE_TONES = {\n  sampler:\n    \"bg-emerald-50 text-emerald-700 border-emerald-100 dark:bg-emerald-950/30 dark:text-emerald-300 dark:border-emerald-900/60\",\n  llm:\n    \"bg-sky-50 text-sky-700 border-sky-100 dark:bg-sky-950/30 dark:text-sky-300 dark:border-sky-900/60\",\n  validator:\n    \"bg-rose-50 text-rose-700 border-rose-100 dark:bg-rose-950/30 dark:text-rose-300 dark:border-rose-900/60\",\n  expression:\n    \"bg-indigo-50 text-indigo-700 border-indigo-100 dark:bg-indigo-950/30 dark:text-indigo-300 dark:border-indigo-900/60\",\n  note:\n    \"bg-violet-50 text-violet-700 border-violet-100 dark:bg-violet-950/30 dark:text-violet-300 dark:border-violet-900/60\",\n  seed:\n    \"bg-lime-50 text-lime-700 border-lime-100 dark:bg-lime-950/30 dark:text-lime-300 dark:border-lime-900/60\",\n  model_provider:\n    \"bg-amber-50 text-amber-700 border-amber-100 dark:bg-amber-950/30 dark:text-amber-300 dark:border-amber-900/60\",\n  model_config:\n    \"bg-orange-50 text-orange-700 border-orange-100 dark:bg-orange-950/30 dark:text-orange-300 dark:border-orange-900/60\",\n  tool_config:\n    \"bg-cyan-50 text-cyan-700 border-cyan-100 dark:bg-cyan-950/30 dark:text-cyan-300 dark:border-cyan-900/60\",\n} as const;\n\nexport const RECIPE_STUDIO_USER_NODE_TONE =\n  \"bg-amber-50 text-amber-700 border-amber-100 dark:bg-amber-950/30 dark:text-amber-300 dark:border-amber-900/60\";\n\nexport const RECIPE_STUDIO_REFERENCE_BADGE_TONES = {\n  user:\n    \"corner-squircle border-amber-500/25 bg-amber-500/10 font-mono text-[11px] text-amber-700 dark:text-amber-300\",\n  seed:\n    \"corner-squircle border-blue-500/25 bg-blue-500/10 font-mono text-[11px] text-blue-700 dark:text-blue-300\",\n  default: \"corner-squircle font-mono text-[11px]\",\n} as const;\n\nexport const RECIPE_STUDIO_WARNING_BADGE_TONE =\n  \"border-amber-500/40 bg-amber-500/10 text-amber-700 hover:bg-amber-500/20 dark:text-amber-300\";\n\nexport const RECIPE_STUDIO_WARNING_ICON_TONE =\n  \"text-amber-600 dark:text-amber-400\";\n\nexport const RECIPE_STUDIO_ONBOARDING_SURFACE_TONE =\n  \"border-primary/20 bg-primary/[0.045]\";\n\nexport const RECIPE_STUDIO_ONBOARDING_ICON_TONE =\n  \"bg-primary/10 text-primary\";\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/validation.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig } from \"../types\";\nimport { isValidSex, parseAgeRange, parseIntNumber, parseNumber } from \"./parse\";\nimport { VALIDATOR_OXC_CODE_LANGS, VALIDATOR_SQL_CODE_LANGS } from \"./validators/code-lang\";\nimport { isOxcCodeShape } from \"./validators/oxc-code-shape\";\nimport { isOxcValidationMode } from \"./validators/oxc-mode\";\n\nconst TRACE_MODES = new Set([\"none\", \"last_message\", \"all_messages\"]);\n\n// biome-ignore lint/complexity/noExcessiveCognitiveComplexity: validation rules\nexport function getConfigErrors(config: NodeConfig | null): string[] {\n  if (!config) {\n    return [];\n  }\n  const errors: string[] = [];\n  if (!config.name.trim()) {\n    errors.push(\"Name is required.\");\n  }\n  if (config.kind === \"sampler\") {\n    if (config.sampler_type === \"category\") {\n      const values = config.values ?? [];\n      if (values.length < 2) {\n        errors.push(\"Category needs at least 2 values.\");\n      }\n      const weights = config.weights ?? [];\n      const hasWeights = weights.some((weight) => weight !== null);\n      if (hasWeights && weights.some((weight) => weight === null)) {\n        errors.push(\"Weights must be set for all values.\");\n      }\n      for (const [condition, params] of Object.entries(\n        config.conditional_params ?? {},\n      )) {\n        if (!condition.trim()) {\n          errors.push(\"Category conditional rule needs condition text.\");\n          continue;\n        }\n        const conditionalValues = (params.values ?? [])\n          .map((value) => value.trim())\n          .filter(Boolean);\n        if (conditionalValues.length === 0) {\n          errors.push(`Category conditional '${condition}' needs values.`);\n          continue;\n        }\n        const conditionalWeights = params.weights ?? 
[];\n        const hasConditionalWeights = conditionalWeights.some(\n          (weight) => weight !== null,\n        );\n        if (\n          hasConditionalWeights &&\n          (conditionalWeights.length !== conditionalValues.length ||\n            conditionalWeights.some((weight) => weight === null))\n        ) {\n          errors.push(\n            `Category conditional '${condition}' weights must be set for all values.`,\n          );\n        }\n      }\n    }\n    if (config.sampler_type === \"uniform\") {\n      const low = parseNumber(config.low);\n      const high = parseNumber(config.high);\n      if (low === null || high === null) {\n        errors.push(\"Uniform low/high must be numbers.\");\n      } else if (low >= high) {\n        errors.push(\"Uniform low must be < high.\");\n      }\n    }\n    if (config.sampler_type === \"gaussian\") {\n      const mean = parseNumber(config.mean);\n      const std = parseNumber(config.std);\n      if (mean === null || std === null) {\n        errors.push(\"Gaussian mean/std must be numbers.\");\n      } else if (std <= 0) {\n        errors.push(\"Gaussian std must be > 0.\");\n      }\n    }\n    if (config.sampler_type === \"bernoulli\") {\n      const p = parseNumber(config.p);\n      if (p === null) {\n        errors.push(\"Bernoulli p must be a number.\");\n      } else if (p < 0 || p > 1) {\n        errors.push(\"Bernoulli p must be between 0 and 1.\");\n      }\n    }\n    if (config.sampler_type === \"datetime\") {\n      if (!config.datetime_unit) {\n        errors.push(\"Datetime unit required.\");\n      }\n      if (config.datetime_start && config.datetime_end) {\n        const start = new Date(config.datetime_start).getTime();\n        const end = new Date(config.datetime_end).getTime();\n        if (!(Number.isFinite(start) && Number.isFinite(end))) {\n          errors.push(\"Datetime start/end must be valid.\");\n        } else if (start >= end) {\n          errors.push(\"Datetime start must be before end.\");\n        }\n      }\n    }\n    if (config.sampler_type === \"timedelta\") {\n      const min = parseNumber(config.dt_min);\n      const max = parseNumber(config.dt_max);\n      if (min === null || max === null) {\n        errors.push(\"Timedelta dt_min/dt_max must be numbers.\");\n      } else if (min >= max) {\n        errors.push(\"Timedelta dt_min must be < dt_max.\");\n      }\n      if (!config.reference_column_name?.trim()) {\n        errors.push(\"Timedelta reference datetime column required.\");\n      }\n      if (!config.timedelta_unit) {\n        errors.push(\"Timedelta unit required.\");\n      }\n    }\n    if (config.sampler_type === \"subcategory\" && !config.subcategory_parent) {\n      errors.push(\"Subcategory needs a parent category column.\");\n    }\n    if (\n      config.sampler_type === \"person\" ||\n      config.sampler_type === \"person_from_faker\"\n    ) {\n      if (config.person_sex?.trim()) {\n        const normalized = config.person_sex.trim();\n        if (!isValidSex(normalized)) {\n          errors.push(\"Person sex must be Male or Female.\");\n        }\n      }\n      if (config.person_age_range?.trim()) {\n        const parsed = parseAgeRange(config.person_age_range);\n        if (!parsed) {\n          errors.push(\"Person age range must be like 18-70.\");\n        }\n      }\n    }\n  }\n  if (config.kind === \"llm\") {\n    if (!config.model_alias.trim()) {\n      errors.push(\"Choose a saved model.\");\n    }\n    if (!config.prompt.trim()) {\n      errors.push(\"Prompt is 
required.\");\n    }\n    if (config.llm_type === \"code\" && !config.code_lang) {\n      errors.push(\"Code language is required.\");\n    }\n    if (config.llm_type === \"structured\") {\n      if (!config.output_format?.trim()) {\n        errors.push(\"Output format is required.\");\n      } else {\n        try {\n          JSON.parse(config.output_format);\n        } catch {\n          errors.push(\"Output format must be valid JSON.\");\n        }\n      }\n    }\n    if (config.llm_type === \"judge\") {\n      const scores = config.scores ?? [];\n      if (scores.length === 0) {\n        errors.push(\"Add at least one scoring rule.\");\n      }\n      for (const score of scores) {\n        if (!score.name.trim()) {\n          errors.push(\"Each scoring rule needs a name.\");\n        }\n        if (!score.description.trim()) {\n          errors.push(\"Each scoring rule needs a description.\");\n        }\n        const options = score.options ?? [];\n        if (options.length === 0) {\n          errors.push(`Scoring rule ${score.name || \"Untitled\"} needs options.`);\n        }\n        for (const option of options) {\n          if (!option.value.trim() || !option.description.trim()) {\n            errors.push(\n              `Scoring rule ${score.name || \"Untitled\"} needs both a value and a description for each option.`,\n            );\n            break;\n          }\n        }\n      }\n    }\n    if (config.image_context?.enabled) {\n      if (!config.image_context.column_name.trim()) {\n        errors.push(\"Image context column is required.\");\n      }\n    }\n    if (\n      config.with_trace &&\n      !TRACE_MODES.has(config.with_trace)\n    ) {\n      errors.push(\"Trace mode must be none, last_message, or all_messages.\");\n    }\n  }\n  if (config.kind === \"expression\") {\n    if (!config.expr.trim()) {\n      errors.push(\"Expression is required.\");\n    }\n  }\n  if (config.kind === \"tool_config\") {\n    if (config.mcp_providers.length === 0) {\n      errors.push(\"Add at least one tool server.\");\n    }\n    const serverNames = new Set<string>();\n    for (const provider of config.mcp_providers) {\n      const name = provider.name.trim();\n      if (!name) {\n        errors.push(\"Each tool server needs a name.\");\n        continue;\n      }\n      if (serverNames.has(name)) {\n        errors.push(`Tool server names must be unique: ${name}.`);\n      }\n      serverNames.add(name);\n      if (provider.provider_type === \"stdio\") {\n        if (!provider.command?.trim()) {\n          errors.push(`Tool server ${name}: add a command.`);\n        }\n      } else if (!provider.endpoint?.trim()) {\n        errors.push(`Tool server ${name}: add an endpoint.`);\n      }\n    }\n    const maxTurnsRaw = config.max_tool_call_turns?.trim();\n    if (\n      maxTurnsRaw &&\n      (!Number.isFinite(Number(maxTurnsRaw)) || Number(maxTurnsRaw) < 1)\n    ) {\n      errors.push(\"Max tool-use turns must be 1 or more.\");\n    }\n    const timeoutRaw = config.timeout_sec?.trim();\n    if (\n      timeoutRaw &&\n      (!Number.isFinite(Number(timeoutRaw)) || Number(timeoutRaw) <= 0)\n    ) {\n      errors.push(\"Timeout must be > 0.\");\n    }\n  }\n  if (config.kind === \"validator\") {\n    const targets = (config.target_columns ?? 
[])\n      .map((value) => value.trim())\n      .filter(Boolean);\n    if (targets.length === 0) {\n      errors.push(\"Choose the code step to check.\");\n    }\n    const batch = parseIntNumber(config.batch_size);\n    if (batch === null || batch < 1) {\n      errors.push(\"Batch size must be an integer >= 1.\");\n    }\n    if (!config.code_lang.trim()) {\n      errors.push(\"Choose a code language for this check.\");\n    } else if (config.validator_type === \"oxc\") {\n      if (!VALIDATOR_OXC_CODE_LANGS.includes(config.code_lang)) {\n        errors.push(\"This JS/TS check only supports JavaScript or TypeScript.\");\n      }\n      if (!isOxcValidationMode(config.oxc_validation_mode)) {\n        errors.push(\"Choose whether to check syntax, lint rules, or both.\");\n      }\n      if (!isOxcCodeShape(config.oxc_code_shape)) {\n        errors.push(\"Choose whether this code is a full file or a snippet.\");\n      }\n    } else if (\n      config.code_lang !== \"python\" &&\n      !VALIDATOR_SQL_CODE_LANGS.includes(config.code_lang)\n    ) {\n      errors.push(\"This check supports Python or SQL.\");\n    }\n  }\n  if (config.kind === \"seed\") {\n    const seedSourceType = config.seed_source_type ?? \"hf\";\n    if (seedSourceType === \"hf\" && !config.hf_repo_id.trim()) {\n      errors.push(\"Choose a Hugging Face dataset.\");\n    }\n    if (!config.hf_path.trim()) {\n      errors.push(\"Load the source-data preview first.\");\n    }\n    if (\n      seedSourceType === \"hf\" &&\n      config.hf_endpoint?.trim() &&\n      !config.hf_endpoint.trim().startsWith(\"http\")\n    ) {\n      errors.push(\"HF endpoint must start with http.\");\n    }\n    if (seedSourceType === \"unstructured\") {\n      if (config.drop && (config.seed_columns?.length ?? 0) === 0) {\n        errors.push(\"Load the available fields before hiding any from the final dataset.\");\n      }\n      const chunkSizeRaw = Number(config.unstructured_chunk_size);\n      const chunkOverlapRaw = Number(config.unstructured_chunk_overlap);\n      if (!Number.isFinite(chunkSizeRaw) || Math.floor(chunkSizeRaw) < 1) {\n        errors.push(\"Chunk size must be an integer >= 1.\");\n      }\n      if (!Number.isFinite(chunkOverlapRaw) || Math.floor(chunkOverlapRaw) < 0) {\n        errors.push(\"Chunk overlap must be an integer >= 0.\");\n      }\n      if (\n        Number.isFinite(chunkSizeRaw) &&\n        Number.isFinite(chunkOverlapRaw) &&\n        Math.floor(chunkOverlapRaw) >= Math.floor(chunkSizeRaw)\n      ) {\n        errors.push(\"Chunk overlap must be less than chunk size.\");\n      }\n    } else {\n      const selectedDropColumns = (config.seed_drop_columns ?? [])\n        .map((value) => value.trim())\n        .filter(Boolean);\n      if (selectedDropColumns.length > 0 && (config.seed_columns?.length ?? 
0) === 0) {\n        errors.push(\"Load the available fields before hiding any from the final dataset.\");\n      }\n    }\n\n    if (config.selection_type === \"index_range\") {\n      const start = parseIntNumber(config.selection_start);\n      const end = parseIntNumber(config.selection_end);\n      if (start === null || end === null) {\n        errors.push(\"Index range start/end must be integers.\");\n      } else {\n        if (start < 0 || end < 0) {\n          errors.push(\"Index range start/end must be >= 0.\");\n        }\n        if (end < start) {\n          errors.push(\"Index range end must be >= start.\");\n        }\n      }\n    }\n    if (config.selection_type === \"partition_block\") {\n      const index = parseIntNumber(config.selection_index);\n      const parts = parseIntNumber(config.selection_num_partitions);\n      if (index === null || parts === null) {\n        errors.push(\"Partition index/num_partitions must be integers.\");\n      } else {\n        if (index < 0) errors.push(\"Partition index must be >= 0.\");\n        if (parts < 1) errors.push(\"Partition num_partitions must be >= 1.\");\n        if (parts >= 1 && index >= parts) {\n          errors.push(\"Partition index must be < num_partitions.\");\n        }\n      }\n    }\n  }\n  return errors;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/validators/code-lang.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ValidatorCodeLang } from \"../../types\";\n\nexport const VALIDATOR_OXC_CODE_LANGS: ValidatorCodeLang[] = [\n  \"javascript\",\n  \"typescript\",\n  \"jsx\",\n  \"tsx\",\n];\n\nexport const VALIDATOR_SQL_CODE_LANGS: ValidatorCodeLang[] = [\n  \"sql:sqlite\",\n  \"sql:postgres\",\n  \"sql:mysql\",\n  \"sql:tsql\",\n  \"sql:bigquery\",\n  \"sql:ansi\",\n];\n\nconst VALIDATOR_CODE_LANG_SET = new Set<ValidatorCodeLang>([\n  ...VALIDATOR_OXC_CODE_LANGS,\n  \"python\",\n  ...VALIDATOR_SQL_CODE_LANGS,\n]);\n\nexport function isValidatorCodeLang(value: string): value is ValidatorCodeLang {\n  return VALIDATOR_CODE_LANG_SET.has(value as ValidatorCodeLang);\n}\n\nexport function normalizeValidatorCodeLang(\n  value: unknown,\n): ValidatorCodeLang {\n  const raw = typeof value === \"string\" ? value.trim() : \"\";\n  if (!raw) {\n    return \"python\";\n  }\n  if (VALIDATOR_OXC_CODE_LANGS.includes(raw as ValidatorCodeLang)) {\n    return raw as ValidatorCodeLang;\n  }\n  if (raw === \"python\") {\n    return \"python\";\n  }\n  if (raw.startsWith(\"sql:\")) {\n    if (VALIDATOR_SQL_CODE_LANGS.includes(raw as ValidatorCodeLang)) {\n      return raw as ValidatorCodeLang;\n    }\n    return \"sql:sqlite\";\n  }\n  return \"python\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/validators/oxc-code-shape.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { OxcCodeShape } from \"../../types\";\n\nexport const OXC_CODE_SHAPES: OxcCodeShape[] = [\n  \"auto\",\n  \"module\",\n  \"snippet\",\n];\n\nexport function isOxcCodeShape(value: string): value is OxcCodeShape {\n  return OXC_CODE_SHAPES.includes(value as OxcCodeShape);\n}\n\nexport function normalizeOxcCodeShape(value: unknown): OxcCodeShape {\n  if (typeof value !== \"string\") {\n    return \"auto\";\n  }\n  const normalized = value.trim().toLowerCase();\n  if (isOxcCodeShape(normalized)) {\n    return normalized;\n  }\n  return \"auto\";\n}\n\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/validators/oxc-mode.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { OxcValidationMode } from \"../../types\";\n\nexport const OXC_VALIDATION_MODES: OxcValidationMode[] = [\n  \"syntax\",\n  \"lint\",\n  \"syntax+lint\",\n];\n\nexport function isOxcValidationMode(value: string): value is OxcValidationMode {\n  return OXC_VALIDATION_MODES.includes(value as OxcValidationMode);\n}\n\nexport function normalizeOxcValidationMode(value: unknown): OxcValidationMode {\n  if (typeof value !== \"string\") {\n    return \"syntax\";\n  }\n  const normalized = value.trim().toLowerCase();\n  if (isOxcValidationMode(normalized)) {\n    return normalized;\n  }\n  return \"syntax\";\n}\n"
  },
  {
    "path": "studio/frontend/src/features/recipe-studio/utils/variables.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { NodeConfig } from \"../types\";\n\nexport type AvailableVariableSource = \"column\" | \"seed\";\n\nexport type AvailableVariableEntry = {\n  name: string;\n  source: AvailableVariableSource;\n};\n\nfunction getStructuredRefs(llmName: string, outputFormat: string): string[] {\n  try {\n    const schema = JSON.parse(outputFormat);\n    if (!(schema?.properties && typeof schema.properties === \"object\")) {\n      return [];\n    }\n    return Object.keys(schema.properties).map((key) => `${llmName}.${key}`);\n  } catch {\n    return [];\n  }\n}\n\nexport function getAvailableVariableEntries(\n  configs: Record<string, NodeConfig>,\n  currentId: string,\n): AvailableVariableEntry[] {\n  const vars: AvailableVariableEntry[] = [];\n\n  for (const config of Object.values(configs)) {\n    if (config.id === currentId) {\n      continue;\n    }\n    if (\n      config.kind === \"model_provider\" ||\n      config.kind === \"model_config\" ||\n      config.kind === \"tool_config\"\n    ) {\n      continue;\n    }\n\n    if (config.kind === \"sampler\") {\n      vars.push({ name: config.name, source: \"column\" });\n      continue;\n    }\n\n    if (config.kind === \"expression\") {\n      vars.push({ name: config.name, source: \"column\" });\n      continue;\n    }\n\n    if (config.kind === \"validator\") {\n      vars.push({ name: config.name, source: \"column\" });\n      continue;\n    }\n\n    if (config.kind === \"seed\") {\n      for (const col of config.seed_columns ?? []) {\n        const name = col.trim();\n        if (!name) continue;\n        vars.push({ name, source: \"seed\" });\n      }\n      continue;\n    }\n\n    if (config.kind !== \"llm\") {\n      continue;\n    }\n\n    vars.push({ name: config.name, source: \"column\" });\n    if (config.llm_type !== \"structured\" || !config.output_format) {\n      continue;\n    }\n    vars.push(\n      ...getStructuredRefs(config.name, config.output_format).map((name) => ({\n        name,\n        source: \"column\" as const,\n      })),\n    );\n  }\n\n  return vars;\n}\n\nexport function getAvailableVariables(\n  configs: Record<string, NodeConfig>,\n  currentId: string,\n): string[] {\n  return getAvailableVariableEntries(configs, currentId).map((entry) => entry.name);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { StudioPage } from \"./studio-page\";\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/chart-preferences-store.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\nimport type { OutlierMode, ScaleMode } from \"./types\";\nimport { DEFAULT_VISIBLE_POINTS, clamp } from \"./utils\";\n\nconst DEFAULT_WINDOW_SIZE = Math.max(\n  24,\n  Math.floor(DEFAULT_VISIBLE_POINTS / 2),\n);\n\ntype ChartPreferencesState = {\n  availableSteps: number;\n  windowSize: number | null;\n  smoothing: number;\n  showRaw: boolean;\n  showSmoothed: boolean;\n  showAvgLine: boolean;\n  lossScale: ScaleMode;\n  lrScale: ScaleMode;\n  gradScale: ScaleMode;\n  lossOutlierMode: OutlierMode;\n  gradOutlierMode: OutlierMode;\n  lrOutlierMode: OutlierMode;\n  setAvailableSteps: (value: number) => void;\n  setWindowSize: (value: number | null) => void;\n  setSmoothing: (value: number) => void;\n  setShowRaw: (value: boolean) => void;\n  setShowSmoothed: (value: boolean) => void;\n  setShowAvgLine: (value: boolean) => void;\n  setLossScale: (value: ScaleMode) => void;\n  setLrScale: (value: ScaleMode) => void;\n  setGradScale: (value: ScaleMode) => void;\n  setLossOutlierMode: (value: OutlierMode) => void;\n  setGradOutlierMode: (value: OutlierMode) => void;\n  setLrOutlierMode: (value: OutlierMode) => void;\n  resetPreferences: () => void;\n};\n\nconst defaultPreferences = {\n  windowSize: DEFAULT_WINDOW_SIZE as number | null,\n  smoothing: 0.6,\n  showRaw: true,\n  showSmoothed: true,\n  showAvgLine: true,\n  lossScale: \"linear\" as ScaleMode,\n  lrScale: \"linear\" as ScaleMode,\n  gradScale: \"linear\" as ScaleMode,\n  lossOutlierMode: \"none\" as OutlierMode,\n  gradOutlierMode: \"none\" as OutlierMode,\n  lrOutlierMode: \"none\" as OutlierMode,\n};\n\nexport const useChartPreferencesStore = create<ChartPreferencesState>(\n  (set) => ({\n    availableSteps: 0,\n    ...defaultPreferences,\n    setAvailableSteps: (value) =>\n      set((state) => {\n        const availableSteps = Math.max(0, Math.round(value));\n        if (state.windowSize == null || availableSteps <= 0) {\n          return { availableSteps };\n        }\n\n        if (state.windowSize >= availableSteps) {\n          return { availableSteps, windowSize: null };\n        }\n\n        return {\n          availableSteps,\n          windowSize: clamp(Math.round(state.windowSize), 1, availableSteps),\n        };\n      }),\n    setWindowSize: (value) =>\n      set((state) => {\n        if (value == null || state.availableSteps <= 0) {\n          return { windowSize: null };\n        }\n\n        const next = clamp(Math.round(value), 1, state.availableSteps);\n        return { windowSize: next >= state.availableSteps ? null : next };\n      }),\n    setSmoothing: (value) => set({ smoothing: clamp(value, 0, 0.9) }),\n    setShowRaw: (value) => set({ showRaw: value }),\n    setShowSmoothed: (value) => set({ showSmoothed: value }),\n    setShowAvgLine: (value) => set({ showAvgLine: value }),\n    setLossScale: (value) => set({ lossScale: value }),\n    setLrScale: (value) => set({ lrScale: value }),\n    setGradScale: (value) => set({ gradScale: value }),\n    setLossOutlierMode: (value) => set({ lossOutlierMode: value }),\n    setGradOutlierMode: (value) => set({ gradOutlierMode: value }),\n    setLrOutlierMode: (value) => set({ lrOutlierMode: value }),\n    resetPreferences: () => set({ ...defaultPreferences }),\n  }),\n);\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/chart-settings-sheet.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { Label } from \"@/components/ui/label\";\nimport { Separator } from \"@/components/ui/separator\";\nimport {\n  Sheet,\n  SheetContent,\n  SheetDescription,\n  SheetFooter,\n  SheetHeader,\n  SheetTitle,\n} from \"@/components/ui/sheet\";\nimport { Slider } from \"@/components/ui/slider\";\nimport { Switch } from \"@/components/ui/switch\";\nimport { Settings02Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { useChartPreferencesStore } from \"./chart-preferences-store\";\nimport type { OutlierMode, ScaleMode } from \"./types\";\n\nfunction ChoiceButtons<T extends string>({\n  options,\n  value,\n  onChange,\n}: {\n  options: { label: string; value: T }[];\n  value: T;\n  onChange: (value: T) => void;\n}): ReactElement {\n  return (\n    <div className=\"flex flex-wrap gap-2\">\n      {options.map((option) => (\n        <Button\n          key={option.value}\n          type=\"button\"\n          size=\"xs\"\n          variant={value === option.value ? \"secondary\" : \"outline\"}\n          onClick={() => onChange(option.value)}\n        >\n          {option.label}\n        </Button>\n      ))}\n    </div>\n  );\n}\n\nfunction SettingRow({\n  label,\n  description,\n  control,\n}: {\n  label: string;\n  description?: string;\n  control: ReactElement;\n}): ReactElement {\n  return (\n    <div className=\"flex items-start justify-between gap-4\">\n      <div className=\"min-w-0\">\n        <Label className=\"text-sm\">{label}</Label>\n        {description ? 
(\n          <p className=\"mt-1 text-xs text-muted-foreground\">{description}</p>\n        ) : null}\n      </div>\n      <div className=\"shrink-0\">{control}</div>\n    </div>\n  );\n}\n\nfunction ScaleSection({\n  title,\n  scale,\n  setScale,\n  outlierMode,\n  setOutlierMode,\n}: {\n  title: string;\n  scale: ScaleMode;\n  setScale: (value: ScaleMode) => void;\n  outlierMode: OutlierMode;\n  setOutlierMode: (value: OutlierMode) => void;\n}): ReactElement {\n  return (\n    <div className=\"space-y-3\">\n      <div>\n        <p className=\"text-sm font-medium\">{title}</p>\n        <p className=\"text-xs text-muted-foreground\">Scale and cleanup</p>\n      </div>\n      <ChoiceButtons\n        options={[\n          { label: \"Linear\", value: \"linear\" },\n          { label: \"Log\", value: \"log\" },\n        ]}\n        value={scale}\n        onChange={setScale}\n      />\n      <ChoiceButtons\n        options={[\n          { label: \"No clip\", value: \"none\" },\n          { label: \"Clip p99\", value: \"p99\" },\n          { label: \"Clip p95\", value: \"p95\" },\n        ]}\n        value={outlierMode}\n        onChange={setOutlierMode}\n      />\n    </div>\n  );\n}\n\nexport function ChartSettingsSheet(): ReactElement {\n  const [open, setOpen] = useState(false);\n  const {\n    availableSteps,\n    windowSize,\n    smoothing,\n    showRaw,\n    showSmoothed,\n    showAvgLine,\n    lossScale,\n    lrScale,\n    gradScale,\n    lossOutlierMode,\n    gradOutlierMode,\n    lrOutlierMode,\n    setWindowSize,\n    setSmoothing,\n    setShowRaw,\n    setShowSmoothed,\n    setShowAvgLine,\n    setLossScale,\n    setLrScale,\n    setGradScale,\n    setLossOutlierMode,\n    setGradOutlierMode,\n    setLrOutlierMode,\n    resetPreferences,\n  } = useChartPreferencesStore(\n    useShallow((state) => ({\n      availableSteps: state.availableSteps,\n      windowSize: state.windowSize,\n      smoothing: state.smoothing,\n      showRaw: state.showRaw,\n      showSmoothed: state.showSmoothed,\n      showAvgLine: state.showAvgLine,\n      lossScale: state.lossScale,\n      lrScale: state.lrScale,\n      gradScale: state.gradScale,\n      lossOutlierMode: state.lossOutlierMode,\n      gradOutlierMode: state.gradOutlierMode,\n      lrOutlierMode: state.lrOutlierMode,\n      setWindowSize: state.setWindowSize,\n      setSmoothing: state.setSmoothing,\n      setShowRaw: state.setShowRaw,\n      setShowSmoothed: state.setShowSmoothed,\n      setShowAvgLine: state.setShowAvgLine,\n      setLossScale: state.setLossScale,\n      setLrScale: state.setLrScale,\n      setGradScale: state.setGradScale,\n      setLossOutlierMode: state.setLossOutlierMode,\n      setGradOutlierMode: state.setGradOutlierMode,\n      setLrOutlierMode: state.setLrOutlierMode,\n      resetPreferences: state.resetPreferences,\n    })),\n  );\n\n  const minWindow = Math.min(10, Math.max(1, availableSteps));\n  const effectiveWindowSize =\n    windowSize == null ? 
Math.max(availableSteps, 1) : windowSize;\n  const showingAll =\n    availableSteps > 0 &&\n    (windowSize == null || effectiveWindowSize >= availableSteps);\n  const sliderMax = Math.max(minWindow, availableSteps || 1);\n\n  return (\n    <>\n      <Button\n        type=\"button\"\n        variant=\"ghost\"\n        size=\"icon-sm\"\n        className=\"rounded-full text-muted-foreground hover:bg-muted hover:text-foreground\"\n        onClick={() => setOpen(true)}\n        aria-label=\"Open chart settings\"\n      >\n        <HugeiconsIcon icon={Settings02Icon} className=\"size-4\" />\n      </Button>\n      <Sheet open={open} onOpenChange={setOpen}>\n        <SheetContent\n          className=\"w-full sm:max-w-md\"\n          overlayClassName=\"bg-transparent backdrop-blur-0\"\n        >\n          <SheetHeader className=\"pb-4\">\n            <SheetTitle>Chart Settings</SheetTitle>\n            <SheetDescription>\n              Tune chart presentation while training keeps running.\n            </SheetDescription>\n          </SheetHeader>\n          <div className=\"flex-1 space-y-6 overflow-y-auto px-6 pb-6\">\n            <div className=\"space-y-3\">\n              <div>\n                <p className=\"text-sm font-medium\">View window</p>\n                <p className=\"text-xs text-muted-foreground\">\n                  Show latest steps only or the full history.\n                </p>\n              </div>\n              <div className=\"space-y-2\">\n                <div className=\"flex items-center justify-between text-xs text-muted-foreground\">\n                  <span>Window</span>\n                  <span className=\"tabular-nums\">\n                    {showingAll ? \"All\" : effectiveWindowSize}\n                  </span>\n                </div>\n                <Slider\n                  value={[effectiveWindowSize]}\n                  onValueChange={([value]) => setWindowSize(value)}\n                  min={minWindow}\n                  max={sliderMax}\n                  step={1}\n                  disabled={availableSteps <= 1}\n                />\n              </div>\n            </div>\n            <Separator />\n            <div className=\"space-y-4\">\n              <div>\n                <p className=\"text-sm font-medium\">Training loss</p>\n                <p className=\"text-xs text-muted-foreground\">\n                  Control overlays and EMA smoothing.\n                </p>\n              </div>\n              <div className=\"space-y-2\">\n                <div className=\"flex items-center justify-between text-xs text-muted-foreground\">\n                  <span>Smoothing</span>\n                  <span className=\"tabular-nums\">{smoothing.toFixed(2)}</span>\n                </div>\n                <Slider\n                  value={[smoothing]}\n                  onValueChange={([value]) => setSmoothing(value)}\n                  min={0}\n                  max={0.9}\n                  step={0.01}\n                />\n                <p className=\"text-[11px] text-muted-foreground\">\n                  Move right for more smoothing. 
`0` = raw.\n                </p>\n              </div>\n              <SettingRow\n                label=\"Show raw loss\"\n                control={\n                  <Switch checked={showRaw} onCheckedChange={setShowRaw} />\n                }\n              />\n              <SettingRow\n                label=\"Show smoothed loss\"\n                control={\n                  <Switch\n                    checked={showSmoothed}\n                    onCheckedChange={setShowSmoothed}\n                  />\n                }\n              />\n              <SettingRow\n                label=\"Show average line\"\n                control={\n                  <Switch\n                    checked={showAvgLine}\n                    onCheckedChange={setShowAvgLine}\n                  />\n                }\n              />\n            </div>\n            <Separator />\n            <ScaleSection\n              title=\"Loss axis\"\n              scale={lossScale}\n              setScale={setLossScale}\n              outlierMode={lossOutlierMode}\n              setOutlierMode={setLossOutlierMode}\n            />\n            <Separator />\n            <ScaleSection\n              title=\"Gradient norm axis\"\n              scale={gradScale}\n              setScale={setGradScale}\n              outlierMode={gradOutlierMode}\n              setOutlierMode={setGradOutlierMode}\n            />\n            <Separator />\n            <ScaleSection\n              title=\"Learning rate axis\"\n              scale={lrScale}\n              setScale={setLrScale}\n              outlierMode={lrOutlierMode}\n              setOutlierMode={setLrOutlierMode}\n            />\n          </div>\n          <SheetFooter className=\"mt-0 border-t border-border/60 bg-background/70 sm:flex-row sm:justify-between\">\n            <Button\n              type=\"button\"\n              variant=\"outline\"\n              size=\"sm\"\n              onClick={resetPreferences}\n            >\n              Reset defaults\n            </Button>\n            <Button type=\"button\" size=\"sm\" onClick={() => setOpen(false)}>\n              Done\n            </Button>\n          </SheetFooter>\n        </SheetContent>\n      </Sheet>\n    </>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/eval-loss-chart-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Card, CardContent, CardHeader, CardTitle } from \"@/components/ui/card\";\nimport {\n  ChartContainer,\n  ChartLegend,\n  ChartLegendContent,\n  ChartTooltip,\n  ChartTooltipContent,\n} from \"@/components/ui/chart\";\nimport type { ChartConfig } from \"@/components/ui/chart\";\nimport { ChartAverageIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\nimport { CartesianGrid, Line, LineChart, XAxis, YAxis } from \"recharts\";\nimport {\n  CHART_CONTAINER_CLASS,\n  DEFAULT_CHART_MARGIN,\n  DEFAULT_Y_AXIS_WIDTH,\n  formatAxisMetric,\n  formatMetric,\n  formatStepTick,\n  placeholderEvalData,\n} from \"./utils\";\n\nconst evalLossConfig = {\n  loss: { label: \"Eval Loss\", color: \"#ef4444\" },\n} satisfies ChartConfig;\n\nexport function EvalLossChartCard({\n  data,\n  domain,\n  ticks,\n  isTraining,\n  evalEnabled,\n}: {\n  data: { step: number; loss: number }[];\n  domain: [number, number];\n  ticks?: number[];\n  isTraining: boolean;\n  evalEnabled: boolean;\n}): ReactElement {\n  return (\n    <Card data-tour=\"studio-eval-loss\" size=\"sm\">\n      <CardHeader>\n        <CardTitle className={`text-sm${data.length > 0 ? \"\" : \" text-muted-foreground\"}`}>\n          Eval Loss\n        </CardTitle>\n      </CardHeader>\n      <CardContent>\n        {data.length > 0 ? (\n          <ChartContainer config={evalLossConfig} className={CHART_CONTAINER_CLASS}>\n            <LineChart\n              data={data}\n              accessibilityLayer={true}\n              margin={DEFAULT_CHART_MARGIN}\n            >\n              <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n              <XAxis\n                dataKey=\"step\"\n                type=\"number\"\n                domain={[\"dataMin\", \"dataMax\"]}\n                ticks={ticks}\n                allowDataOverflow={true}\n                allowDecimals={false}\n                minTickGap={28}\n                tickLine={false}\n                axisLine={false}\n                tickMargin={8}\n                fontSize={10}\n                tickFormatter={(value) => formatStepTick(Number(value))}\n                interval=\"preserveStartEnd\"\n              />\n              <YAxis\n                domain={domain}\n                allowDataOverflow={true}\n                tickLine={false}\n                axisLine={false}\n                tickMargin={8}\n                tickCount={5}\n                fontSize={10}\n                width={DEFAULT_Y_AXIS_WIDTH}\n                tickFormatter={(value) => formatAxisMetric(Number(value))}\n              />\n              <ChartTooltip\n                content={\n                  <ChartTooltipContent\n                    labelFormatter={(_value, payload) =>\n                      `Step ${payload?.[0]?.payload?.step ?? 
\"\"}`\n                    }\n                    formatter={(_value, _name, item) => [\n                      formatMetric(Number(item?.payload?.loss)),\n                      \"Eval Loss\",\n                    ]}\n                  />\n                }\n              />\n              <Line\n                type=\"monotone\"\n                dataKey=\"loss\"\n                stroke=\"var(--color-loss)\"\n                strokeWidth={2}\n                dot={{ r: 3, strokeWidth: 0, fill: \"#ef4444\" }}\n                activeDot={{ r: 4, strokeWidth: 0 }}\n                connectNulls={true}\n                isAnimationActive={false}\n              />\n              <ChartLegend content={<ChartLegendContent />} />\n            </LineChart>\n          </ChartContainer>\n        ) : (\n          <div className=\"relative\">\n            <ChartContainer\n              config={evalLossConfig}\n              className={`${CHART_CONTAINER_CLASS} blur`}\n            >\n              <LineChart\n                data={placeholderEvalData}\n                accessibilityLayer={true}\n                margin={DEFAULT_CHART_MARGIN}\n              >\n                <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n                <XAxis\n                  dataKey=\"step\"\n                  type=\"number\"\n                  domain={[\"dataMin\", \"dataMax\"]}\n                  tickLine={false}\n                  axisLine={false}\n                  tickMargin={8}\n                  fontSize={10}\n                  interval=\"preserveStartEnd\"\n                />\n                <YAxis\n                  tickLine={false}\n                  axisLine={false}\n                  tickMargin={8}\n                  tickCount={5}\n                  fontSize={10}\n                  width={DEFAULT_Y_AXIS_WIDTH}\n                />\n                <Line\n                  type=\"monotone\"\n                  dataKey=\"loss\"\n                  stroke=\"var(--color-loss)\"\n                  strokeWidth={2}\n                  dot={false}\n                  isAnimationActive={false}\n                />\n              </LineChart>\n            </ChartContainer>\n            <div className=\"absolute inset-0 flex flex-col items-center justify-center gap-1\">\n              <HugeiconsIcon\n                icon={ChartAverageIcon}\n                className=\"size-5 text-muted-foreground/50\"\n              />\n              <p className=\"text-sm font-medium text-muted-foreground\">\n                {isTraining && evalEnabled\n                  ? \"Waiting for first evaluation step…\"\n                  : \"Evaluation not configured\"}\n              </p>\n              <p className=\"text-xs text-muted-foreground/60\">\n                {isTraining && evalEnabled\n                  ? \"Chart will appear once eval_steps is reached\"\n                  : \"Set eval dataset & eval_steps to track eval loss\"}\n              </p>\n            </div>\n          </div>\n        )}\n      </CardContent>\n    </Card>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/grad-norm-chart-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Card, CardContent, CardHeader, CardTitle } from \"@/components/ui/card\";\nimport {\n  ChartContainer,\n  ChartLegend,\n  ChartLegendContent,\n  ChartTooltip,\n  ChartTooltipContent,\n} from \"@/components/ui/chart\";\nimport type { ChartConfig } from \"@/components/ui/chart\";\nimport type { ReactElement } from \"react\";\nimport { CartesianGrid, Line, LineChart, XAxis, YAxis } from \"recharts\";\nimport type { ScaleMode } from \"./types\";\nimport {\n  CHART_SYNC_ID,\n  CHART_CONTAINER_CLASS,\n  DEFAULT_CHART_MARGIN,\n  DEFAULT_Y_AXIS_WIDTH,\n  formatAxisMetric,\n  formatMetric,\n  formatStepTick,\n  fromLog1p,\n} from \"./utils\";\n\nconst gradNormConfig = {\n  displayGradNorm: { label: \"Grad Norm\", color: \"#f97316\" },\n} satisfies ChartConfig;\n\ninterface GradNormPoint {\n  step: number;\n  gradNorm: number;\n  displayGradNorm: number;\n}\n\nexport function GradNormChartCard({\n  data,\n  domain,\n  visibleStepDomain,\n  xAxisTicks,\n  scale,\n}: {\n  data: GradNormPoint[];\n  domain: [number, number];\n  visibleStepDomain: [number, number];\n  xAxisTicks: number[];\n  scale: ScaleMode;\n}): ReactElement {\n  const showPoint = data.length <= 1 ? { r: 3, strokeWidth: 0 } : false;\n\n  return (\n    <Card size=\"sm\">\n      <CardHeader>\n        <CardTitle className=\"text-sm\">Gradient Norm</CardTitle>\n      </CardHeader>\n      <CardContent>\n        <ChartContainer config={gradNormConfig} className={CHART_CONTAINER_CLASS}>\n          <LineChart\n            data={data}\n            syncId={CHART_SYNC_ID}\n            syncMethod=\"value\"\n            accessibilityLayer={true}\n            margin={DEFAULT_CHART_MARGIN}\n          >\n            <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n            <XAxis\n              dataKey=\"step\"\n              type=\"number\"\n              domain={visibleStepDomain}\n              ticks={xAxisTicks}\n              allowDataOverflow={true}\n              allowDecimals={false}\n              minTickGap={28}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              fontSize={10}\n              tickFormatter={(value) => formatStepTick(Number(value))}\n              interval=\"preserveStartEnd\"\n            />\n            <YAxis\n              domain={domain}\n              allowDataOverflow={true}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              tickCount={5}\n              fontSize={10}\n              width={DEFAULT_Y_AXIS_WIDTH}\n              tickFormatter={(value) => {\n                const num = Number(value);\n                if (!Number.isFinite(num)) {\n                  return \"0\";\n                }\n                const shown = scale === \"log\" ? fromLog1p(num) : num;\n                return formatAxisMetric(shown);\n              }}\n            />\n            <ChartTooltip\n              content={\n                <ChartTooltipContent\n                  labelFormatter={(_value, payload) =>\n                    `Step ${payload?.[0]?.payload?.step ?? 
\"\"}`\n                  }\n                  formatter={(_value, _name, item) => {\n                    const raw = Number(item?.payload?.gradNorm);\n                    return [formatMetric(raw), \"Grad Norm\"];\n                  }}\n                />\n              }\n            />\n            <Line\n              type=\"linear\"\n              dataKey=\"displayGradNorm\"\n              stroke=\"var(--color-displayGradNorm)\"\n              strokeWidth={2}\n              dot={showPoint}\n              activeDot={{ r: 3, strokeWidth: 0 }}\n              connectNulls={true}\n              strokeLinecap=\"round\"\n              strokeLinejoin=\"round\"\n              isAnimationActive={false}\n            />\n            <ChartLegend content={<ChartLegendContent />} />\n          </LineChart>\n        </ChartContainer>\n      </CardContent>\n    </Card>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/learning-rate-chart-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Card, CardContent, CardHeader, CardTitle } from \"@/components/ui/card\";\nimport {\n  ChartContainer,\n  ChartLegend,\n  ChartLegendContent,\n  ChartTooltip,\n  ChartTooltipContent,\n} from \"@/components/ui/chart\";\nimport type { ChartConfig } from \"@/components/ui/chart\";\nimport type { ReactElement } from \"react\";\nimport { CartesianGrid, Line, LineChart, XAxis, YAxis } from \"recharts\";\nimport type { ScaleMode } from \"./types\";\nimport {\n  CHART_CONTAINER_CLASS,\n  CHART_SYNC_ID,\n  DEFAULT_CHART_MARGIN,\n  DEFAULT_Y_AXIS_WIDTH,\n  formatStepTick,\n  fromLog1p,\n} from \"./utils\";\n\nconst lrConfig = {\n  displayLr: { label: \"LR\", color: \"#8b5cf6\" },\n} satisfies ChartConfig;\n\ninterface LearningRatePoint {\n  step: number;\n  lr: number;\n  displayLr: number;\n}\n\nexport function LearningRateChartCard({\n  data,\n  domain,\n  visibleStepDomain,\n  xAxisTicks,\n  scale,\n}: {\n  data: LearningRatePoint[];\n  domain: [number, number];\n  visibleStepDomain: [number, number];\n  xAxisTicks: number[];\n  scale: ScaleMode;\n}): ReactElement {\n  const showPoint = data.length <= 1 ? { r: 3, strokeWidth: 0 } : false;\n\n  return (\n    <Card size=\"sm\">\n      <CardHeader>\n        <CardTitle className=\"text-sm\">Learning Rate</CardTitle>\n      </CardHeader>\n      <CardContent>\n        <ChartContainer config={lrConfig} className={CHART_CONTAINER_CLASS}>\n          <LineChart\n            data={data}\n            syncId={CHART_SYNC_ID}\n            syncMethod=\"value\"\n            accessibilityLayer={true}\n            margin={DEFAULT_CHART_MARGIN}\n          >\n            <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n            <XAxis\n              dataKey=\"step\"\n              type=\"number\"\n              domain={visibleStepDomain}\n              ticks={xAxisTicks}\n              allowDataOverflow={true}\n              allowDecimals={false}\n              minTickGap={28}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              fontSize={10}\n              tickFormatter={(value) => formatStepTick(Number(value))}\n              interval=\"preserveStartEnd\"\n            />\n            <YAxis\n              domain={domain}\n              allowDataOverflow={true}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              tickCount={5}\n              fontSize={10}\n              width={DEFAULT_Y_AXIS_WIDTH}\n              tickFormatter={(value) => {\n                const num = Number(value);\n                if (!Number.isFinite(num)) {\n                  return \"0e+0\";\n                }\n                const shown = scale === \"log\" ? fromLog1p(num) : num;\n                return shown.toExponential(0);\n              }}\n            />\n            <ChartTooltip\n              content={\n                <ChartTooltipContent\n                  labelFormatter={(_value, payload) =>\n                    `Step ${payload?.[0]?.payload?.step ?? \"\"}`\n                  }\n                  formatter={(_value, _name, item) => {\n                    const raw = Number(item?.payload?.lr);\n                    return [\n                      Number.isFinite(raw) ? 
raw.toExponential(3) : \"0e+0\",\n                      \"LR\",\n                    ];\n                  }}\n                />\n              }\n            />\n            <Line\n              type=\"linear\"\n              dataKey=\"displayLr\"\n              stroke=\"var(--color-displayLr)\"\n              strokeWidth={2}\n              dot={showPoint}\n              activeDot={{ r: 3, strokeWidth: 0 }}\n              connectNulls={true}\n              strokeLinecap=\"round\"\n              strokeLinejoin=\"round\"\n              isAnimationActive={false}\n            />\n            <ChartLegend content={<ChartLegendContent />} />\n          </LineChart>\n        </ChartContainer>\n      </CardContent>\n    </Card>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/training-loss-chart-card.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Card, CardContent, CardHeader, CardTitle } from \"@/components/ui/card\";\nimport {\n  ChartContainer,\n  ChartLegend,\n  ChartLegendContent,\n  ChartTooltip,\n  ChartTooltipContent,\n} from \"@/components/ui/chart\";\nimport type { ChartConfig } from \"@/components/ui/chart\";\nimport type { ReactElement } from \"react\";\nimport {\n  CartesianGrid,\n  Line,\n  LineChart,\n  ReferenceLine,\n  XAxis,\n  YAxis,\n} from \"recharts\";\nimport type { ScaleMode } from \"./types\";\nimport {\n  CHART_SYNC_ID,\n  CHART_CONTAINER_CLASS,\n  DEFAULT_CHART_MARGIN,\n  DEFAULT_Y_AXIS_WIDTH,\n  formatAxisMetric,\n  formatMetric,\n  formatStepTick,\n  fromLog1p,\n} from \"./utils\";\n\nconst lossConfig = {\n  displayLoss: { label: \"Loss\", color: \"#3b82f6\" },\n  displaySmoothed: { label: \"Smoothed\", color: \"#f59e0b\" },\n} satisfies ChartConfig;\n\ninterface LossChartPoint {\n  step: number;\n  loss: number;\n  smoothed: number;\n  displayLoss: number;\n  displaySmoothed: number;\n}\n\nexport function TrainingLossChartCard({\n  data,\n  domain,\n  visibleStepDomain,\n  xAxisTicks,\n  avgRaw,\n  avgDisplay,\n  showRaw,\n  showSmoothed,\n  showAvgLine,\n  scale,\n}: {\n  data: LossChartPoint[];\n  domain: [number, number];\n  visibleStepDomain: [number, number];\n  xAxisTicks: number[];\n  avgRaw: number;\n  avgDisplay: number;\n  showRaw: boolean;\n  showSmoothed: boolean;\n  showAvgLine: boolean;\n  scale: ScaleMode;\n}): ReactElement {\n  const showPoint = data.length <= 1 ? { r: 3, strokeWidth: 0 } : false;\n\n  return (\n    <Card data-tour=\"studio-training-loss\" size=\"sm\">\n      <CardHeader>\n        <CardTitle className=\"text-sm\">Training Loss</CardTitle>\n      </CardHeader>\n      <CardContent>\n        <ChartContainer config={lossConfig} className={CHART_CONTAINER_CLASS}>\n          <LineChart\n            data={data}\n            syncId={CHART_SYNC_ID}\n            syncMethod=\"value\"\n            accessibilityLayer={true}\n            margin={DEFAULT_CHART_MARGIN}\n          >\n            <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n            <XAxis\n              dataKey=\"step\"\n              type=\"number\"\n              domain={visibleStepDomain}\n              ticks={xAxisTicks}\n              allowDataOverflow={true}\n              allowDecimals={false}\n              minTickGap={28}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              fontSize={10}\n              tickFormatter={(value) => formatStepTick(Number(value))}\n              interval=\"preserveStartEnd\"\n            />\n            <YAxis\n              domain={domain}\n              allowDataOverflow={true}\n              tickLine={false}\n              axisLine={false}\n              tickMargin={8}\n              tickCount={5}\n              fontSize={10}\n              width={DEFAULT_Y_AXIS_WIDTH}\n              tickFormatter={(value) => {\n                const num = Number(value);\n                if (!Number.isFinite(num)) {\n                  return \"0\";\n                }\n                const shown = scale === \"log\" ? 
fromLog1p(num) : num;\n                return formatAxisMetric(shown);\n              }}\n            />\n            <ChartTooltip\n              content={\n                <ChartTooltipContent\n                  labelFormatter={(_value, payload) =>\n                    `Step ${payload?.[0]?.payload?.step ?? \"\"}`\n                  }\n                  formatter={(_value, name, item) => {\n                    if (name === \"displaySmoothed\") {\n                      return [\n                        formatMetric(Number(item?.payload?.smoothed)),\n                        \"Smoothed\",\n                      ];\n                    }\n                    return [formatMetric(Number(item?.payload?.loss)), \"Loss\"];\n                  }}\n                />\n              }\n            />\n            {showAvgLine && (\n              <ReferenceLine\n                y={avgDisplay}\n                stroke=\"#3b82f6\"\n                strokeDasharray=\"4 4\"\n                strokeOpacity={0.5}\n                label={{\n                  value: `avg ${formatMetric(avgRaw)}`,\n                  position: \"insideTopRight\",\n                  fontSize: 10,\n                  fill: \"#3b82f6\",\n                }}\n              />\n            )}\n            {showRaw && (\n              <Line\n                type=\"linear\"\n                dataKey=\"displayLoss\"\n                stroke=\"var(--color-displayLoss)\"\n                strokeWidth={1.2}\n                strokeOpacity={showSmoothed ? 0.35 : 1}\n                dot={showPoint}\n                activeDot={{ r: 3, strokeWidth: 0 }}\n                connectNulls={true}\n                strokeLinecap=\"round\"\n                strokeLinejoin=\"round\"\n                isAnimationActive={false}\n              />\n            )}\n            {showSmoothed && (\n              <Line\n                type=\"linear\"\n                dataKey=\"displaySmoothed\"\n                stroke=\"var(--color-displaySmoothed)\"\n                strokeWidth={2.2}\n                dot={showPoint}\n                activeDot={{ r: 3, strokeWidth: 0 }}\n                connectNulls={true}\n                strokeLinecap=\"round\"\n                strokeLinejoin=\"round\"\n                isAnimationActive={false}\n              />\n            )}\n            <ChartLegend content={<ChartLegendContent />} />\n          </LineChart>\n        </ChartContainer>\n      </CardContent>\n    </Card>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type ScaleMode = \"linear\" | \"log\";\nexport type OutlierMode = \"none\" | \"p99\" | \"p95\";\n\nexport type LossHistoryItem = { step: number; loss: number };\nexport type SmoothedLossItem = LossHistoryItem & { smoothed: number };\n\nexport interface TrainingChartSeries {\n  lossHistory: LossHistoryItem[];\n  lrHistory: { step: number; lr: number }[];\n  gradNormHistory: { step: number; gradNorm: number }[];\n  evalLossHistory: { step: number; loss: number }[];\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts/utils.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { LossHistoryItem, OutlierMode, SmoothedLossItem } from \"./types\";\n\nexport const CHART_SYNC_ID = \"train-metrics-sync\";\nexport const MAX_RENDER_POINTS = 800;\nexport const DEFAULT_VISIBLE_POINTS = 160;\nexport const CHART_CONTAINER_CLASS = \"h-[220px] w-full\";\nexport const DEFAULT_CHART_MARGIN = { top: 4, right: 8, bottom: 0, left: 4 };\nexport const DEFAULT_Y_AXIS_WIDTH = 45;\nconst TRAILING_ZEROES_RE = /\\.?0+$/;\nconst NEGATIVE_ZERO_RE = /^-0$/;\n\nexport const placeholderEvalData = [\n  { step: 0, loss: 2.8 },\n  { step: 50, loss: 2.4 },\n  { step: 100, loss: 2.0 },\n  { step: 150, loss: 1.7 },\n  { step: 200, loss: 1.5 },\n];\n\nexport function toLog1p(value: number): number {\n  const safe = Number.isFinite(value) ? Math.max(value, 0) : 0;\n  return Math.log10(safe + 1);\n}\n\nexport function fromLog1p(value: number): number {\n  return Math.max(0, 10 ** value - 1);\n}\n\nexport function formatMetric(value: number): string {\n  if (!Number.isFinite(value)) {\n    return \"0\";\n  }\n  const abs = Math.abs(value);\n  let decimals = 6;\n\n  if (abs >= 1000) {\n    decimals = 0;\n  } else if (abs >= 100) {\n    decimals = 2;\n  } else if (abs >= 1) {\n    decimals = 4;\n  } else if (abs >= 0.01) {\n    decimals = 5;\n  } else if (abs >= 0.0001) {\n    decimals = 6;\n  } else {\n    decimals = 8;\n  }\n\n  return value\n    .toFixed(decimals)\n    .replace(TRAILING_ZEROES_RE, \"\")\n    .replace(NEGATIVE_ZERO_RE, \"0\");\n}\n\nexport function formatAxisMetric(value: number): string {\n  if (!Number.isFinite(value)) {\n    return \"0\";\n  }\n\n  const abs = Math.abs(value);\n  let decimals = 4;\n\n  if (abs >= 1000) {\n    decimals = 0;\n  } else if (abs >= 100) {\n    decimals = 1;\n  } else if (abs >= 1) {\n    decimals = 3;\n  } else if (abs >= 0.01) {\n    decimals = 4;\n  } else {\n    decimals = 5;\n  }\n\n  return value\n    .toFixed(decimals)\n    .replace(TRAILING_ZEROES_RE, \"\")\n    .replace(NEGATIVE_ZERO_RE, \"0\");\n}\n\nexport function formatStepTick(value: number): string {\n  if (value >= 1_000_000) {\n    return `${(value / 1_000_000).toFixed(1)}M`;\n  }\n  if (value >= 1_000) {\n    return `${(value / 1_000).toFixed(1)}k`;\n  }\n  return String(Math.round(value));\n}\n\nexport function compressSeries<T>(data: T[], maxPoints: number): T[] {\n  if (data.length <= maxPoints) {\n    return data;\n  }\n\n  const stride = Math.ceil(data.length / maxPoints);\n  return data.filter(\n    (_item, index) => index % stride === 0 || index === data.length - 1,\n  );\n}\n\nexport function clamp(value: number, min: number, max: number): number {\n  return Math.min(max, Math.max(min, value));\n}\n\nexport function buildStepTicks(\n  min: number,\n  max: number,\n  targetCount = 6,\n): number[] {\n  if (!(Number.isFinite(min) && Number.isFinite(max))) {\n    return [0, 1];\n  }\n  if (max <= min) {\n    return [min, max];\n  }\n\n  const stepSize = Math.max(1, Math.ceil((max - min) / (targetCount - 1)));\n  const ticks: number[] = [];\n  let current = min;\n\n  while (current < max) {\n    ticks.push(current);\n    current += stepSize;\n  }\n\n  ticks.push(max);\n  return Array.from(new Set(ticks));\n}\n\nexport function buildYDomain(values: number[]): [number, number] {\n  const finiteValues = values.filter((value) => Number.isFinite(value));\n  if (finiteValues.length === 0) {\n    return [0, 1];\n  
}\n\n  const min = Math.min(...finiteValues);\n  const max = Math.max(...finiteValues);\n\n  if (min === max) {\n    const base = Math.abs(min);\n    const pad = base > 0 ? base * 0.08 : 0.1;\n    return [min - pad, max + pad];\n  }\n\n  const pad = (max - min) * 0.12;\n  return [min - pad, max + pad];\n}\n\nfunction getUpperPercentile(\n  values: number[],\n  mode: OutlierMode,\n): number | null {\n  if (mode === \"none\") {\n    return null;\n  }\n  const finiteValues = values.filter((value) => Number.isFinite(value));\n  if (finiteValues.length < 3) {\n    return null;\n  }\n\n  const sorted = [...finiteValues].sort((a, b) => a - b);\n  const q = mode === \"p99\" ? 0.99 : 0.95;\n  const index = Math.max(\n    0,\n    Math.min(sorted.length - 1, Math.floor((sorted.length - 1) * q)),\n  );\n  return sorted[index] ?? null;\n}\n\nexport function applyOutlierCap(values: number[], mode: OutlierMode): number[] {\n  const cap = getUpperPercentile(values, mode);\n  if (cap == null) {\n    return values;\n  }\n  return values.map((value) => Math.min(value, cap));\n}\n\nexport function ema(\n  data: LossHistoryItem[],\n  alpha: number,\n): SmoothedLossItem[] {\n  if (data.length === 0) {\n    return [];\n  }\n\n  const values = data.map((point) => point.loss);\n  const isConstant = values.every((value) => value === values[0]);\n\n  let last = 0;\n  let count = 0;\n\n  return data.map((point) => {\n    const next = point.loss;\n    if (!Number.isFinite(next) || isConstant) {\n      return { ...point, smoothed: next };\n    }\n\n    last = last * alpha + (1 - alpha) * next;\n    count += 1;\n\n    const debias = alpha === 1 ? 1 : 1 - alpha ** count;\n    return { ...point, smoothed: last / debias };\n  });\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts-content.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ReactElement, useEffect, useMemo } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { useChartPreferencesStore } from \"./charts/chart-preferences-store\";\nimport { EvalLossChartCard } from \"./charts/eval-loss-chart-card\";\nimport { GradNormChartCard } from \"./charts/grad-norm-chart-card\";\nimport { LearningRateChartCard } from \"./charts/learning-rate-chart-card\";\nimport { TrainingLossChartCard } from \"./charts/training-loss-chart-card\";\nimport type { TrainingChartSeries } from \"./charts/types\";\nimport {\n  MAX_RENDER_POINTS,\n  applyOutlierCap,\n  buildStepTicks,\n  buildYDomain,\n  clamp,\n  compressSeries,\n  ema,\n  toLog1p,\n} from \"./charts/utils\";\n\ntype LossDisplayPoint = {\n  step: number;\n  displayLoss: number;\n  displaySmoothed: number;\n};\n\nfunction isStepVisible(step: number, domain: [number, number]): boolean {\n  return step >= domain[0] && step <= domain[1];\n}\n\nfunction collectLossValues(\n  data: LossDisplayPoint[],\n  domain: [number, number],\n  options: { includeRaw: boolean; includeSmoothed: boolean },\n): number[] {\n  const values: number[] = [];\n\n  for (const point of data) {\n    if (!isStepVisible(point.step, domain)) {\n      continue;\n    }\n\n    if (options.includeRaw && Number.isFinite(point.displayLoss)) {\n      values.push(point.displayLoss);\n    }\n\n    if (options.includeSmoothed && Number.isFinite(point.displaySmoothed)) {\n      values.push(point.displaySmoothed);\n    }\n  }\n\n  return values;\n}\n\nexport function ChartsContent({\n  metrics,\n  isTraining,\n  evalEnabled,\n}: {\n  metrics: TrainingChartSeries;\n  isTraining: boolean;\n  evalEnabled: boolean;\n}): ReactElement {\n  const {\n    windowSize,\n    smoothing,\n    showRaw,\n    showSmoothed,\n    showAvgLine,\n    lossScale,\n    lrScale,\n    gradScale,\n    lossOutlierMode,\n    gradOutlierMode,\n    lrOutlierMode,\n    setAvailableSteps,\n  } = useChartPreferencesStore(\n    useShallow((state) => ({\n      windowSize: state.windowSize,\n      smoothing: state.smoothing,\n      showRaw: state.showRaw,\n      showSmoothed: state.showSmoothed,\n      showAvgLine: state.showAvgLine,\n      lossScale: state.lossScale,\n      lrScale: state.lrScale,\n      gradScale: state.gradScale,\n      lossOutlierMode: state.lossOutlierMode,\n      gradOutlierMode: state.gradOutlierMode,\n      lrOutlierMode: state.lrOutlierMode,\n      setAvailableSteps: state.setAvailableSteps,\n    })),\n  );\n\n  const smoothedData = useMemo(\n    () =>\n      metrics.lossHistory.length > 0 ? 
ema(metrics.lossHistory, smoothing) : [],\n    [metrics.lossHistory, smoothing],\n  );\n\n  const reducedLossData = useMemo(\n    () => compressSeries(smoothedData, MAX_RENDER_POINTS),\n    [smoothedData],\n  );\n  const reducedGradNormData = useMemo(\n    () => compressSeries(metrics.gradNormHistory, MAX_RENDER_POINTS),\n    [metrics.gradNormHistory],\n  );\n  const reducedLrData = useMemo(\n    () => compressSeries(metrics.lrHistory, MAX_RENDER_POINTS),\n    [metrics.lrHistory],\n  );\n  const reducedEvalLossData = useMemo(\n    () => compressSeries(metrics.evalLossHistory, MAX_RENDER_POINTS),\n    [metrics.evalLossHistory],\n  );\n\n  const allSteps = useMemo(() => {\n    const set = new Set<number>();\n    for (const point of metrics.lossHistory) {\n      set.add(point.step);\n    }\n    for (const point of metrics.gradNormHistory) {\n      set.add(point.step);\n    }\n    for (const point of metrics.lrHistory) {\n      set.add(point.step);\n    }\n    return Array.from(set).sort((a, b) => a - b);\n  }, [metrics.gradNormHistory, metrics.lossHistory, metrics.lrHistory]);\n\n  useEffect(() => {\n    setAvailableSteps(allSteps.length);\n  }, [allSteps.length, setAvailableSteps]);\n\n  const stepCount = Math.max(1, allSteps.length);\n  const effectiveWindowSize =\n    windowSize == null\n      ? stepCount\n      : clamp(Math.round(windowSize), 1, stepCount);\n\n  const visibleStepDomain = useMemo<[number, number]>(() => {\n    if (allSteps.length === 0) {\n      return [0, 1];\n    }\n\n    const endIndex = allSteps.length - 1;\n    const startIndex = Math.max(0, endIndex - effectiveWindowSize + 1);\n    const minStep = allSteps[0] ?? 0;\n    const startStep = allSteps[startIndex] ?? minStep;\n    const endStep = allSteps[endIndex] ?? startStep;\n\n    if (startStep === endStep) {\n      return [startStep, startStep + 4];\n    }\n    if (endStep - startStep < 6) {\n      return [Math.max(minStep, endStep - 6), endStep];\n    }\n    return [startStep, endStep];\n  }, [allSteps, effectiveWindowSize]);\n\n  const xAxisTicks = useMemo(\n    () => buildStepTicks(visibleStepDomain[0], visibleStepDomain[1]),\n    [visibleStepDomain],\n  );\n\n  const displayLossData = useMemo(\n    () =>\n      reducedLossData.map((point) => ({\n        ...point,\n        displayLoss: lossScale === \"log\" ? toLog1p(point.loss) : point.loss,\n        displaySmoothed:\n          lossScale === \"log\" ? toLog1p(point.smoothed) : point.smoothed,\n      })),\n    [lossScale, reducedLossData],\n  );\n\n  const displayGradData = useMemo(\n    () =>\n      reducedGradNormData.map((point) => ({\n        ...point,\n        displayGradNorm:\n          gradScale === \"log\" ? toLog1p(point.gradNorm) : point.gradNorm,\n      })),\n    [gradScale, reducedGradNormData],\n  );\n\n  const displayLrData = useMemo(\n    () =>\n      reducedLrData.map((point) => ({\n        ...point,\n        displayLr: lrScale === \"log\" ? 
toLog1p(point.lr) : point.lr,\n      })),\n    [lrScale, reducedLrData],\n  );\n\n  const visibleLossDisplayValues = useMemo(() => {\n    const visibleValues = collectLossValues(\n      displayLossData,\n      visibleStepDomain,\n      {\n        includeRaw: showRaw,\n        includeSmoothed: showSmoothed,\n      },\n    );\n\n    if (visibleValues.length > 0) {\n      return visibleValues;\n    }\n\n    return collectLossValues(displayLossData, visibleStepDomain, {\n      includeRaw: true,\n      includeSmoothed: true,\n    });\n  }, [displayLossData, showRaw, showSmoothed, visibleStepDomain]);\n\n  const visibleGradDisplayValues = useMemo(\n    () =>\n      displayGradData\n        .filter(\n          (point) =>\n            point.step >= visibleStepDomain[0] &&\n            point.step <= visibleStepDomain[1],\n        )\n        .map((point) => point.displayGradNorm)\n        .filter((value) => Number.isFinite(value)),\n    [displayGradData, visibleStepDomain],\n  );\n\n  const visibleLrDisplayValues = useMemo(\n    () =>\n      displayLrData\n        .filter(\n          (point) =>\n            point.step >= visibleStepDomain[0] &&\n            point.step <= visibleStepDomain[1],\n        )\n        .map((point) => point.displayLr)\n        .filter((value) => Number.isFinite(value)),\n    [displayLrData, visibleStepDomain],\n  );\n\n  const lossDomain = useMemo(\n    () =>\n      buildYDomain(applyOutlierCap(visibleLossDisplayValues, lossOutlierMode)),\n    [lossOutlierMode, visibleLossDisplayValues],\n  );\n  const gradDomain = useMemo(\n    () =>\n      buildYDomain(applyOutlierCap(visibleGradDisplayValues, gradOutlierMode)),\n    [gradOutlierMode, visibleGradDisplayValues],\n  );\n  const lrDomain = useMemo(\n    () => buildYDomain(applyOutlierCap(visibleLrDisplayValues, lrOutlierMode)),\n    [lrOutlierMode, visibleLrDisplayValues],\n  );\n\n  const evalLossDomain = useMemo(() => {\n    const vals = reducedEvalLossData.map((point) => point.loss);\n    return buildYDomain(vals);\n  }, [reducedEvalLossData]);\n\n  const evalLossStepTicks = useMemo(() => {\n    if (reducedEvalLossData.length < 2) {\n      return undefined;\n    }\n    const min = reducedEvalLossData[0].step;\n    const max = reducedEvalLossData[reducedEvalLossData.length - 1].step;\n    return buildStepTicks(min, max);\n  }, [reducedEvalLossData]);\n\n  const avgRaw =\n    metrics.lossHistory.length > 0\n      ? +(\n          metrics.lossHistory.reduce((sum, point) => sum + point.loss, 0) /\n          metrics.lossHistory.length\n        ).toFixed(4)\n      : 0;\n  const avgDisplay = lossScale === \"log\" ? 
toLog1p(avgRaw) : avgRaw;\n\n  return (\n    <div className=\"grid grid-cols-1 gap-6 lg:grid-cols-2\">\n      <TrainingLossChartCard\n        data={displayLossData}\n        domain={lossDomain}\n        visibleStepDomain={visibleStepDomain}\n        xAxisTicks={xAxisTicks}\n        avgRaw={avgRaw}\n        avgDisplay={avgDisplay}\n        showRaw={showRaw}\n        showSmoothed={showSmoothed}\n        showAvgLine={showAvgLine}\n        scale={lossScale}\n      />\n      <GradNormChartCard\n        data={displayGradData}\n        domain={gradDomain}\n        visibleStepDomain={visibleStepDomain}\n        xAxisTicks={xAxisTicks}\n        scale={gradScale}\n      />\n      <LearningRateChartCard\n        data={displayLrData}\n        domain={lrDomain}\n        visibleStepDomain={visibleStepDomain}\n        xAxisTicks={xAxisTicks}\n        scale={lrScale}\n      />\n      <EvalLossChartCard\n        data={reducedEvalLossData}\n        domain={evalLossDomain}\n        ticks={evalLossStepTicks}\n        isTraining={isTraining}\n        evalEnabled={evalEnabled}\n      />\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/charts-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useTrainingRuntimeStore } from \"@/features/training\";\nimport { type ReactElement, Suspense, lazy, useMemo } from \"react\";\n\nconst ChartsContent = lazy(() =>\n  import(\"./charts-content\").then((module) => ({\n    default: module.ChartsContent,\n  })),\n);\nconst SKELETON_KEYS = [\n  \"chart-skeleton-1\",\n  \"chart-skeleton-2\",\n  \"chart-skeleton-3\",\n  \"chart-skeleton-4\",\n];\n\nexport function ChartsSection(): ReactElement | null {\n  const currentStep = useTrainingRuntimeStore((state) => state.currentStep);\n  const totalSteps = useTrainingRuntimeStore((state) => state.totalSteps);\n  const isTraining = useTrainingRuntimeStore((state) => state.isTrainingRunning);\n  const evalEnabled = useTrainingRuntimeStore((state) => state.evalEnabled);\n  const lossHistoryRaw = useTrainingRuntimeStore((state) => state.lossHistory);\n  const lrHistoryRaw = useTrainingRuntimeStore((state) => state.lrHistory);\n  const gradNormHistoryRaw = useTrainingRuntimeStore(\n    (state) => state.gradNormHistory,\n  );\n  const evalLossHistoryRaw = useTrainingRuntimeStore(\n    (state) => state.evalLossHistory,\n  );\n\n  const series = useMemo(\n    () => ({\n      currentStep,\n      totalSteps,\n      lossHistory: lossHistoryRaw.map((point) => ({\n        step: point.step,\n        loss: point.value,\n      })),\n      lrHistory: lrHistoryRaw.map((point) => ({\n        step: point.step,\n        lr: point.value,\n      })),\n      gradNormHistory: gradNormHistoryRaw.map((point) => ({\n        step: point.step,\n        gradNorm: point.value,\n      })),\n      evalLossHistory: evalLossHistoryRaw.map((point) => ({\n        step: point.step,\n        loss: point.value,\n      })),\n    }),\n    [currentStep, evalLossHistoryRaw, gradNormHistoryRaw, lossHistoryRaw, lrHistoryRaw, totalSteps],\n  );\n\n  if (\n    series.lossHistory.length === 0 &&\n    series.lrHistory.length === 0 &&\n    series.gradNormHistory.length === 0\n  ) {\n    return null;\n  }\n\n  return (\n    <Suspense\n      fallback={\n        <div className=\"grid grid-cols-1 lg:grid-cols-2 gap-6\">\n          {SKELETON_KEYS.map((key) => (\n            <div\n              key={key}\n              className=\"h-[280px] rounded-xl border bg-muted/30 animate-pulse\"\n            />\n          ))}\n        </div>\n      }\n    >\n      <ChartsContent metrics={series} isTraining={isTraining} evalEnabled={evalEnabled} />\n    </Suspense>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/dataset-preview-dialog-mapping.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Badge } from \"@/components/ui/badge\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport type { CheckFormatResponse } from \"@/features/training/types/datasets\";\nimport { cn } from \"@/lib/utils\";\nimport { AlertCircleIcon, CheckmarkCircle02Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { Loader2, Sparkles } from \"lucide-react\";\n\nconst CHATML_ROLES = [\"system\", \"user\", \"assistant\"] as const;\nconst ALPACA_ROLES = [\"instruction\", \"input\", \"output\"] as const;\nconst SHAREGPT_ROLES = [\"system\", \"human\", \"gpt\"] as const;\nconst VLM_ROLES = [\"image\", \"text\"] as const;\nconst AUDIO_ROLES = [\"audio\", \"text\", \"speaker_id\"] as const;\n\nconst ROLE_LABELS: Record<string, string> = {\n  system: \"System\",\n  user: \"User\",\n  assistant: \"Assistant\",\n  human: \"Human\",\n  gpt: \"GPT\",\n  instruction: \"Instruction\",\n  input: \"Input\",\n  output: \"Output\",\n  image: \"Image\",\n  text: \"Text\",\n  audio: \"Audio\",\n  speaker_id: \"Speaker ID\",\n};\n\nexport function getAvailableRoles(isVlm: boolean, format?: string, isAudio?: boolean): readonly string[] {\n  if (isAudio) return AUDIO_ROLES;\n  if (isVlm) return VLM_ROLES;\n  if (format === \"alpaca\") return ALPACA_ROLES;\n  if (format === \"sharegpt\") return SHAREGPT_ROLES;\n  return CHATML_ROLES;\n}\n\nexport function isMappingComplete(\n  mapping: Record<string, string>,\n  isVlm: boolean,\n  format?: string,\n  isAudio?: boolean,\n): boolean {\n  const roles = new Set(Object.values(mapping));\n  if (isAudio) return roles.has(\"audio\") && roles.has(\"text\");\n  if (isVlm) return roles.has(\"image\") && roles.has(\"text\");\n  if (format === \"alpaca\") return roles.has(\"instruction\") && roles.has(\"output\");\n  if (format === \"sharegpt\") return roles.has(\"human\") && roles.has(\"gpt\");\n  return roles.has(\"user\") && roles.has(\"assistant\");\n}\n\nexport function HeaderRolePicker({\n  currentRole,\n  onRoleChange,\n  availableRoles,\n}: {\n  currentRole: string | undefined;\n  onRoleChange: (role: string | undefined) => void;\n  availableRoles: readonly string[];\n}) {\n  return (\n    <Select\n      value={currentRole ?? \"_none\"}\n      onValueChange={(v) => onRoleChange(v === \"_none\" ? undefined : v)}\n    >\n      <SelectTrigger className=\"h-6 w-[90px] text-[10px] px-2 py-0 border-dashed cursor-pointer\">\n        <SelectValue placeholder=\"Role...\" />\n      </SelectTrigger>\n      <SelectContent>\n        <SelectItem value=\"_none\" className=\"text-[11px]\">\n          None\n        </SelectItem>\n        {availableRoles.map((role) => (\n          <SelectItem key={role} value={role} className=\"text-[11px]\">\n            {ROLE_LABELS[role] ?? 
role}\n          </SelectItem>\n        ))}\n      </SelectContent>\n    </Select>\n  );\n}\n\nexport function DatasetMappingCard({\n  mapping,\n  mappingOk,\n  autoDetected = false,\n  isVlm = false,\n  isAudio = false,\n  format,\n  onAiAssist,\n  isAiLoading = false,\n  aiError,\n  advisorNotification,\n  advisorSystemPrompt,\n}: {\n  mapping: Record<string, string>;\n  mappingOk: boolean;\n  autoDetected?: boolean;\n  isVlm?: boolean;\n  isAudio?: boolean;\n  format?: string;\n  onAiAssist?: () => void;\n  isAiLoading?: boolean;\n  aiError?: string | null;\n  advisorNotification?: string | null;\n  advisorSystemPrompt?: string;\n}) {\n  const entries = Object.entries(mapping);\n  const requiredLabel = isAudio\n    ? \"audio and text\"\n    : isVlm\n      ? \"image and text\"\n      : format === \"alpaca\"\n        ? \"instruction and output\"\n        : format === \"sharegpt\"\n          ? \"human and gpt\"\n          : \"user and assistant\";\n\n  return (\n    <div\n      className={cn(\n        \"rounded-xl corner-squircle ring-1 px-5 py-4 mb-4\",\n        mappingOk\n          ? \"ring-emerald-200/70 bg-emerald-50/70 text-emerald-950 dark:ring-emerald-900/50 dark:bg-emerald-950/30 dark:text-emerald-50\"\n          : \"ring-amber-200/70 bg-amber-50/70 text-amber-950 dark:ring-amber-900/50 dark:bg-amber-950/30 dark:text-amber-50\",\n      )}\n    >\n      <div className=\"flex items-start gap-3\">\n        <div\n          className={cn(\n            \"rounded-xl corner-squircle p-2 shrink-0\",\n            mappingOk ? \"bg-emerald-500/15\" : \"bg-amber-500/15\",\n          )}\n        >\n          <HugeiconsIcon\n            icon={mappingOk ? CheckmarkCircle02Icon : AlertCircleIcon}\n            className={cn(\n              \"size-4\",\n              mappingOk\n                ? \"text-emerald-700 dark:text-emerald-300\"\n                : \"text-amber-700 dark:text-amber-300\",\n            )}\n          />\n        </div>\n        <div className=\"min-w-0\">\n          <p className=\"text-sm font-semibold tracking-tight\">\n            {mappingOk\n              ? autoDetected ? \"Heuristic-detected mapping\" : \"Mapping ready\"\n              : \"Map dataset columns\"}\n          </p>\n          <p\n            className={cn(\n              \"text-xs mt-0.5\",\n              mappingOk\n                ? \"text-emerald-800/80 dark:text-emerald-200/80\"\n                : \"text-amber-800/80 dark:text-amber-200/80\",\n            )}\n          >\n            {mappingOk\n              ? autoDetected\n                ? \"We auto-detected the column mapping below using heuristics. Please review and adjust using the dropdowns in the column headers, or use AI Assist for a smarter mapping.\"\n                : \"Looks good. We'll convert this dataset automatically.\"\n              : `Assign roles to columns using the dropdowns in the headers. At minimum, assign ${requiredLabel}.`}\n          </p>\n          {entries.length > 0 && (\n            <div className=\"mt-3 flex flex-wrap items-center gap-2\">\n              {entries.map(([col, role]) => (\n                <Badge\n                  key={col}\n                  variant=\"outline\"\n                  className=\"h-6 text-[11px] bg-white/60 dark:bg-transparent\"\n                >\n                  <span className=\"font-mono\">{col}</span>\n                  <span className=\"mx-1 text-muted-foreground/60\">&rarr;</span>\n                  <span>{ROLE_LABELS[role] ?? 
role}</span>\n                </Badge>\n              ))}\n            </div>\n          )}\n          {!mappingOk && entries.length === 0 && (\n            <p className=\"mt-2 text-xs text-amber-800/80 dark:text-amber-200/80\">\n              Use the dropdowns in the column headers to assign roles.\n            </p>\n          )}\n          {onAiAssist && (\n            <div className=\"mt-3 flex items-center gap-2\">\n              <Button\n                variant=\"outline\"\n                size=\"sm\"\n                onClick={onAiAssist}\n                disabled={isAiLoading}\n                className=\"cursor-pointer bg-white/60 dark:bg-transparent\"\n              >\n                {isAiLoading ? (\n                  <>\n                    <Loader2 className=\"mr-1.5 h-3.5 w-3.5 animate-spin\" />\n                    Analyzing dataset...\n                  </>\n                ) : (\n                  <>\n                    <Sparkles className=\"mr-1.5 h-3.5 w-3.5\" />\n                    AI Assist\n                    <Badge variant=\"outline\" className=\"ml-1.5 text-[9px] px-1 py-0 h-4 font-medium\">Beta</Badge>\n                  </>\n                )}\n              </Button>\n              {aiError && (\n                <p className=\"text-xs text-amber-700 dark:text-amber-300\">{aiError}</p>\n              )}\n            </div>\n          )}\n          {advisorNotification && (\n            <div className=\"mt-3 rounded-lg border border-indigo-200 bg-indigo-50 px-3 py-2.5 text-xs text-indigo-700 dark:border-indigo-800 dark:bg-indigo-950 dark:text-indigo-300 space-y-2\">\n              <div className=\"flex items-start gap-2\">\n                <Sparkles className=\"size-3.5 shrink-0 mt-0.5\" />\n                <span>{advisorNotification}</span>\n              </div>\n              {advisorSystemPrompt && (\n                <div className=\"pl-5.5 text-[11px] font-mono text-indigo-600/80 dark:text-indigo-400/80\">\n                  <span className=\"font-sans font-medium text-indigo-500 dark:text-indigo-400\">System:</span>{\" \"}\n                  <span className=\"break-words\">{advisorSystemPrompt}</span>\n                </div>\n              )}\n            </div>\n          )}\n        </div>\n      </div>\n    </div>\n  );\n}\n\nexport function DatasetMappingFooter({\n  mappingOk,\n  isStarting,\n  startError,\n  onCancel,\n  onStartTraining,\n}: {\n  mappingOk: boolean;\n  isStarting: boolean;\n  startError: string | null;\n  onCancel: () => void;\n  onStartTraining: () => Promise<void>;\n}) {\n  return (\n    <div className=\"mt-3 flex flex-col gap-2\">\n      <div className=\"flex items-center justify-between gap-3\">\n        <p className=\"text-[11px] text-muted-foreground/70 leading-relaxed\">\n          Tip: use the role dropdowns in the column headers to assign roles.\n        </p>\n        <div className=\"flex items-center gap-2\">\n          <Button\n            variant=\"outline\"\n            size=\"sm\"\n            className=\"cursor-pointer\"\n            onClick={onCancel}\n          >\n            Cancel\n          </Button>\n          <Button\n            size=\"sm\"\n            className=\"cursor-pointer\"\n            disabled={!mappingOk || isStarting}\n            onClick={() => void onStartTraining()}\n          >\n            {isStarting ? 
\"Starting...\" : \"Continue\"}\n          </Button>\n        </div>\n      </div>\n\n      {startError && (\n        <p className=\"text-xs text-red-500 leading-relaxed text-center\">\n          {startError}\n        </p>\n      )}\n    </div>\n  );\n}\n\n/** Canonical chatml role for any format-specific role name. */\nconst TO_CANONICAL: Record<string, string> = {\n  user: \"user\", assistant: \"assistant\", system: \"system\",\n  instruction: \"user\", input: \"system\", output: \"assistant\",\n  human: \"user\", gpt: \"assistant\",\n  image: \"image\", text: \"text\",\n  audio: \"audio\", speaker_id: \"speaker_id\",\n};\n\n/** Chatml → format-specific role names (only for formats that differ). */\nconst FROM_CANONICAL: Record<string, Record<string, string>> = {\n  alpaca: { user: \"instruction\", system: \"input\", assistant: \"output\" },\n  sharegpt: { user: \"human\", assistant: \"gpt\", system: \"system\" },\n};\n\n/**\n * Remap a column→role mapping between formats.\n * Normalises every role to canonical chatml first, then maps to the target format.\n */\nexport function remapRolesForFormat(\n  mapping: Record<string, string>,\n  format?: string,\n): Record<string, string> {\n  const table = format ? FROM_CANONICAL[format] : undefined;\n  const out: Record<string, string> = {};\n  for (const [col, role] of Object.entries(mapping)) {\n    const canonical = TO_CANONICAL[role] ?? role;\n    out[col] = table ? (table[canonical] ?? canonical) : canonical;\n  }\n  return out;\n}\n\nexport function deriveDefaultMapping(\n  data: CheckFormatResponse,\n  isVlm: boolean,\n  format?: string,\n  isAudio?: boolean,\n): Record<string, string> {\n  if (data.suggested_mapping) {\n    return remapRolesForFormat({ ...data.suggested_mapping }, format);\n  }\n  if (isAudio) {\n    const result: Record<string, string> = {};\n    if (data.detected_audio_column) result[data.detected_audio_column] = \"audio\";\n    if (data.detected_text_column) result[data.detected_text_column] = \"text\";\n    if (data.detected_speaker_column) result[data.detected_speaker_column] = \"speaker_id\";\n    return result;\n  }\n  if (isVlm) {\n    const result: Record<string, string> = {};\n    if (data.detected_image_column) result[data.detected_image_column] = \"image\";\n    if (data.detected_text_column) result[data.detected_text_column] = \"text\";\n    return result;\n  }\n  return {};\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/dataset-preview-dialog-utils.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\ntype PreviewImagePayload = {\n  type: \"image\";\n  mime?: string;\n  width?: number;\n  height?: number;\n  data?: string;\n};\n\nexport function formatCell(value: unknown): string {\n  if (value == null) return \"\";\n  if (typeof value === \"string\") return value;\n  if (typeof value === \"number\" || typeof value === \"boolean\") return String(value);\n  if (Array.isArray(value) || typeof value === \"object\")\n    return JSON.stringify(value).slice(0, 500);\n  return String(value);\n}\n\nfunction isPreviewImagePayload(value: unknown): value is PreviewImagePayload {\n  if (!value || typeof value !== \"object\") return false;\n  const record = value as Record<string, unknown>;\n  return (\n    record.type === \"image\" &&\n    typeof record.data === \"string\" &&\n    record.data.length > 0\n  );\n}\n\nexport function collectPreviewImages(value: unknown): PreviewImagePayload[] {\n  const images: PreviewImagePayload[] = [];\n  const stack: unknown[] = [value];\n  let steps = 0;\n\n  while (stack.length > 0 && steps < 200) {\n    steps += 1;\n    const current = stack.pop();\n    if (isPreviewImagePayload(current)) {\n      images.push(current);\n      continue;\n    }\n\n    if (Array.isArray(current)) {\n      for (const item of current) stack.push(item);\n      continue;\n    }\n\n    if (current && typeof current === \"object\") {\n      for (const nested of Object.values(current as Record<string, unknown>)) {\n        stack.push(nested);\n      }\n    }\n  }\n\n  return images;\n}\n\n"
  },
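  {
    "path": "studio/frontend/src/features/studio/sections/dataset-preview-dialog-utils.example.ts",
    "content": "// Illustrative usage sketch (editor-added, not part of the app build).\n// Assumption: this file name/path and the sample row below are hypothetical;\n// formatCell and collectPreviewImages come from dataset-preview-dialog-utils.ts.\n\nimport { collectPreviewImages, formatCell } from \"./dataset-preview-dialog-utils\";\n\n// formatCell renders primitives as-is and JSON-stringifies (and truncates)\n// arrays/objects for the preview table.\nconsole.assert(formatCell(null) === \"\");\nconsole.assert(formatCell(42) === \"42\");\nconsole.assert(formatCell({ role: \"user\" }) === JSON.stringify({ role: \"user\" }));\n\n// collectPreviewImages walks a nested row value and picks out any\n// { type: \"image\", data: <base64> } payloads emitted by the backend preview.\nconst row = {\n  messages: [\n    { role: \"user\", content: [{ type: \"image\", mime: \"image/png\", data: \"iVBORw0KGgo=\" }] },\n    { role: \"assistant\", content: \"A sloth.\" },\n  ],\n};\nconst images = collectPreviewImages(row);\nconsole.assert(images.length === 1 && images[0]?.mime === \"image/png\");\n"
  },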
  {
    "path": "studio/frontend/src/features/studio/sections/dataset-preview-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { type ReactNode, useCallback, useEffect, useMemo, useRef, useState } from \"react\";\nimport { aiAssistMapping } from \"@/features/training/api/datasets-api\";\nimport type { ColumnDef } from \"@tanstack/react-table\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { DataTable } from \"@/components/ui/data-table\";\nimport { Badge } from \"@/components/ui/badge\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { useTrainingActions, useTrainingConfigStore } from \"@/features/training\";\nimport { checkDatasetFormat } from \"@/features/training/api/datasets-api\";\nimport type { CheckFormatResponse } from \"@/features/training/types/datasets\";\nimport { Database02Icon, AlertCircleIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { collectPreviewImages, formatCell } from \"./dataset-preview-dialog-utils\";\nimport {\n  DatasetMappingCard,\n  DatasetMappingFooter,\n  HeaderRolePicker,\n  deriveDefaultMapping,\n  getAvailableRoles,\n  isMappingComplete,\n  remapRolesForFormat,\n} from \"./dataset-preview-dialog-mapping\";\n\n/** Chatml → format-specific role remap (only for formats that differ from chatml). */\nconst ROLE_REMAP: Record<string, Record<string, string>> = {\n  alpaca: { user: \"instruction\", system: \"input\", assistant: \"output\" },\n  sharegpt: { user: \"human\", assistant: \"gpt\", system: \"system\" },\n};\n\ntype DatasetPreviewDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  datasetName: string | null;\n  datasetSource?: \"huggingface\" | \"upload\";\n  hfToken: string | null;\n  datasetSubset?: string | null;\n  datasetSplit?: string | null;\n  mode?: \"preview\" | \"mapping\";\n  initialData?: CheckFormatResponse | null;\n  isVlm?: boolean;\n};\n\nexport function DatasetPreviewDialog({\n  open,\n  onOpenChange,\n  datasetName,\n  datasetSource,\n  hfToken,\n  datasetSubset,\n  datasetSplit,\n  mode = \"preview\",\n  initialData,\n  isVlm = false,\n}: DatasetPreviewDialogProps) {\n  const [data, setData] = useState<CheckFormatResponse | null>(null);\n  const [loading, setLoading] = useState(false);\n  const [error, setError] = useState<string | null>(null);\n\n  const {\n    manualMapping, setManualMapping, datasetFormat,\n    setDatasetAdvisorFields, datasetAdvisorNotification,\n    datasetSystemPrompt,\n    selectedModel,\n    modelType,\n  } = useTrainingConfigStore(\n    useShallow((s) => ({\n      manualMapping: s.datasetManualMapping,\n      setManualMapping: s.setDatasetManualMapping,\n      datasetFormat: s.datasetFormat,\n      setDatasetAdvisorFields: s.setDatasetAdvisorFields,\n      datasetAdvisorNotification: s.datasetAdvisorNotification,\n      datasetSystemPrompt: s.datasetSystemPrompt,\n      selectedModel: s.selectedModel,\n      modelType: s.modelType,\n    })),\n  );\n  const { isStarting, startError, startTrainingRun } = useTrainingActions();\n\n  // If the backend reports image data, treat as VLM even if the prop\n  // hasn't caught up yet (isDatasetImage may still be null in the store).\n  const effectiveIsAudio = !!data?.is_audio;\n  const effectiveIsVlm = isVlm || !!data?.is_image;\n\n  const hasHeuristicMapping = !data?.requires_manual_mapping && 
!!data?.suggested_mapping;\n  const mappingEnabled = !!data?.requires_manual_mapping || hasHeuristicMapping;\n  const showMappingFooter = mode === \"mapping\" && mappingEnabled;\n  const mappingOk = isMappingComplete(manualMapping, effectiveIsVlm, datasetFormat, effectiveIsAudio);\n  const availableRoles = getAvailableRoles(effectiveIsVlm, datasetFormat, effectiveIsAudio);\n  const isHfDataset = datasetSource === \"huggingface\";\n\n  // ── AI Assist ──────────────────────────────────────────────────────\n  const [isAiLoading, setIsAiLoading] = useState(false);\n  const [aiError, setAiError] = useState<string | null>(null);\n\n  const handleAiAssist = useCallback(async () => {\n    if (!data?.columns || !data?.preview_samples) return;\n    setIsAiLoading(true);\n    setAiError(null);\n\n    try {\n      const result = await aiAssistMapping({\n        columns: data.columns,\n        samples: data.preview_samples,\n        datasetName: datasetName,\n        hfToken: hfToken,\n        modelName: selectedModel,\n        modelType: modelType,\n      });\n\n      if (result.success && result.suggested_mapping) {\n        // Remap from chatml roles (user/assistant/system) to format-specific roles\n        const table = ROLE_REMAP[datasetFormat];\n        const mapped: Record<string, string> = {};\n        for (const [col, role] of Object.entries(result.suggested_mapping)) {\n          mapped[col] = table ? (table[role] ?? role) : role;\n        }\n        setManualMapping(mapped);\n\n        // Store conversion advisor fields (system prompt, label mapping, notification)\n        if (result.system_prompt || result.label_mapping || result.user_notification) {\n          setDatasetAdvisorFields({\n            systemPrompt: result.system_prompt ?? undefined,\n            labelMapping: result.label_mapping ?? undefined,\n            notification: result.user_notification ?? null,\n          });\n        }\n      } else {\n        setAiError(result.warning || \"AI could not determine column roles.\");\n      }\n    } catch (err) {\n      setAiError(err instanceof Error ? 
err.message : \"AI assist failed.\");\n    } finally {\n      setIsAiLoading(false);\n    }\n  }, [data, datasetFormat, datasetName, hfToken, setManualMapping, setDatasetAdvisorFields, selectedModel, modelType]);\n\n  // When format changes, remap existing mapping roles to the new format's role names\n  const prevFormatRef = useRef(datasetFormat);\n  useEffect(() => {\n    const prev = prevFormatRef.current;\n    prevFormatRef.current = datasetFormat;\n    if (prev === datasetFormat) return;\n    if (Object.keys(manualMapping).length === 0) return;\n    setManualMapping(remapRolesForFormat(manualMapping, datasetFormat));\n  }, [datasetFormat]); // eslint-disable-line react-hooks/exhaustive-deps\n\n  // Handle role change for a column\n  const handleRoleChange = useCallback(\n    (colName: string, role: string | undefined) => {\n      const next = { ...manualMapping };\n      // Remove this column's previous role\n      delete next[colName];\n      if (role) {\n        // Remove any other column that had this role (each role can only be assigned once)\n        for (const [col, r] of Object.entries(next)) {\n          if (r === role) delete next[col];\n        }\n        next[colName] = role;\n      }\n      setManualMapping(next);\n    },\n    [manualMapping, setManualMapping],\n  );\n\n  useEffect(() => {\n    if (!open || !datasetName) {\n      setData(null);\n      setError(null);\n      return;\n    }\n\n    if (initialData) {\n      setData(initialData);\n      setError(null);\n      setLoading(false);\n      return;\n    }\n\n    let cancelled = false;\n    setLoading(true);\n    setError(null);\n\n    checkDatasetFormat({\n      datasetName,\n      hfToken,\n      subset: datasetSubset,\n      split: datasetSplit,\n      isVlm,\n    })\n      .then((res) => {\n        if (!cancelled) {\n          setData(res);\n          setError(null);\n        }\n      })\n      .catch((err) => {\n        if (!cancelled) setError(err.message || \"Failed to load preview\");\n      })\n      .finally(() => {\n        if (!cancelled) setLoading(false);\n      });\n\n    return () => {\n      cancelled = true;\n    };\n  }, [open, datasetName, hfToken, datasetSubset, datasetSplit, isVlm, initialData]);\n\n  // Pre-fill mapping from suggested_mapping when data arrives\n  useEffect(() => {\n    if (!open || !datasetName) return;\n    if (!data?.requires_manual_mapping && !data?.suggested_mapping) return;\n    // Don't overwrite if mapping already has entries\n    if (Object.keys(manualMapping).length > 0) return;\n    const derived = deriveDefaultMapping(data, effectiveIsVlm, datasetFormat, effectiveIsAudio);\n    if (Object.keys(derived).length === 0) return;\n    setManualMapping(derived);\n  }, [open, datasetName, data, effectiveIsVlm, datasetFormat, effectiveIsAudio, manualMapping, setManualMapping]);\n\n  const rows = data?.preview_samples ?? [];\n  const columns = data?.columns ?? 
[];\n\n  // Determine source label\n  const sourceLabel = useMemo(() => {\n    if (!datasetName) return \"\";\n    if (datasetSource === \"huggingface\") {\n      let label = `Hugging Face (${datasetName}`;\n      if (datasetSubset) label += ` / ${datasetSubset}`;\n      if (datasetSplit) label += ` / ${datasetSplit}`;\n      label += \")\";\n      return label;\n    }\n    return `Local Files (${datasetName})`;\n  }, [datasetName, datasetSource, datasetSubset, datasetSplit]);\n\n  // Build TanStack Table columns from the column names\n  const tableColumns = useMemo<ColumnDef<Record<string, unknown>>[]>(() => {\n    if (!columns.length) return [];\n\n    const dataCols: ColumnDef<Record<string, unknown>>[] = columns.map((colName) => ({\n      accessorKey: colName,\n      header: () => (\n        <div className=\"flex flex-col gap-2\">\n          <span className=\"font-heading text-[13px] font-semibold tracking-tight text-foreground\">\n            {colName}\n          </span>\n          {mappingEnabled && (\n            <HeaderRolePicker\n              currentRole={manualMapping[colName]}\n              onRoleChange={(role) => handleRoleChange(colName, role)}\n              availableRoles={availableRoles}\n            />\n          )}\n        </div>\n      ),\n      cell: ({ getValue }: { getValue: () => unknown }) => {\n        const value = getValue();\n        const images = collectPreviewImages(value);\n        if (images.length > 0) {\n          return (\n            <div className=\"flex flex-wrap gap-2\">\n              {images.slice(0, 4).map((image, index) => {\n                const mime = image.mime || \"image/jpeg\";\n                const src = image.data ? `data:${mime};base64,${image.data}` : \"\";\n                const width = image.width ?? 128;\n                const height = image.height ?? 128;\n                return (\n                  <img\n                    key={`${colName}-img-${index}`}\n                    src={src}\n                    alt={`preview-${index}`}\n                    className=\"h-16 w-auto max-w-40 rounded-md border object-contain bg-muted\"\n                    width={width}\n                    height={height}\n                    loading=\"lazy\"\n                  />\n                );\n              })}\n              {images.length > 4 && (\n                <span className=\"text-xs text-muted-foreground self-end\">\n                  +{images.length - 4} more\n                </span>\n              )}\n            </div>\n          );\n        }\n\n        const text = formatCell(value);\n        if (!text) {\n          return (\n            <span className=\"text-muted-foreground/40 italic text-[13px]\">\n              --\n            </span>\n          );\n        }\n        const full = typeof value === \"string\" ? 
value : JSON.stringify(value);\n        return (\n          <p\n            className=\"text-[13px] leading-relaxed line-clamp-6\"\n            title={full}\n          >\n            {text}\n          </p>\n        );\n      },\n    }));\n\n    // Prepend generated system prompt column when advisor is active\n    if (datasetSystemPrompt) {\n      dataCols.unshift({\n        id: \"__system_generated\",\n        header: () => (\n          <div className=\"flex flex-col gap-2\">\n            <span className=\"font-heading text-[13px] font-semibold tracking-tight text-foreground\">\n              System <span className=\"text-muted-foreground font-normal\">(generated)</span>\n            </span>\n            {mappingEnabled && (\n              <Badge variant=\"outline\" className=\"h-6 w-fit text-[10px] px-2 py-0 border-dashed text-muted-foreground\">\n                System\n              </Badge>\n            )}\n          </div>\n        ),\n        cell: () => (\n          <p\n            className=\"text-[13px] leading-relaxed line-clamp-6 text-muted-foreground italic\"\n            title={datasetSystemPrompt}\n          >\n            {datasetSystemPrompt}\n          </p>\n        ),\n      });\n    }\n\n    return dataCols;\n  }, [\n    columns,\n    manualMapping,\n    handleRoleChange,\n    mappingEnabled,\n    availableRoles,\n    datasetSystemPrompt,\n  ]);\n\n  return (\n    <Dialog open={open} onOpenChange={onOpenChange}>\n      <DialogContent\n        className=\"sm:max-w-5xl w-[90vw] max-h-[88vh] flex flex-col gap-0 p-0 overflow-hidden rounded-3xl corner-squircle\"\n        showCloseButton={true}\n      >\n        {/* Header */}\n        <DialogHeader className=\"px-6 pt-5 pb-4 shrink-0\">\n          <div className=\"flex items-center gap-3 pr-10\">\n            <div className=\"rounded-xl corner-squircle p-2 ring-1 ring-indigo-200 bg-indigo-50 text-indigo-600 dark:ring-indigo-800 dark:bg-indigo-950 dark:text-indigo-400 shrink-0\">\n              <HugeiconsIcon icon={Database02Icon} className=\"size-4\" />\n            </div>\n            <DialogTitle className=\"font-heading text-lg font-semibold tracking-tight\">\n              Dataset Preview\n            </DialogTitle>\n          </div>\n        </DialogHeader>\n\n        {/* Body */}\n        <div className=\"flex flex-col min-h-0 flex-1 overflow-auto px-6 pb-6\">\n          {/* Loading */}\n          {loading && (\n            <div className=\"py-24 flex flex-col items-center justify-center gap-3\">\n              <div className=\"rounded-2xl corner-squircle bg-primary/5 p-4\">\n                <Spinner className=\"size-5 text-primary\" />\n              </div>\n              <p className=\"text-sm text-muted-foreground font-medium\">\n                {isHfDataset ? 
\"Fetching dataset preview from Hugging Face...\" : \"Loading preview...\"}\n              </p>\n              {isHfDataset && (\n                <p className=\"text-xs text-muted-foreground/60\">\n                  This may take a moment for large datasets\n                </p>\n              )}\n            </div>\n          )}\n\n          {/* Error */}\n          {error && (\n            <div className=\"py-20 flex flex-col items-center justify-center gap-3\">\n              <div className=\"rounded-2xl corner-squircle bg-destructive/10 p-3\">\n                <HugeiconsIcon\n                  icon={AlertCircleIcon}\n                  className=\"size-5 text-destructive\"\n                />\n              </div>\n              <div className=\"text-center space-y-1\">\n                <p className=\"text-sm font-medium text-destructive\">{error}</p>\n                <p className=\"text-xs text-muted-foreground\">\n                  Make sure the backend is running and reachable.\n                </p>\n              </div>\n            </div>\n          )}\n\n          {/* Content */}\n          {!loading && !error && data && (\n            <>\n              {/* Metadata card */}\n              <div className=\"rounded-xl corner-squircle ring-1 ring-border/60 bg-muted/30 px-5 py-4 mb-4 space-y-2\">\n                <MetaRow label=\"Source\" value={sourceLabel} />\n                <MetaRow\n                  label=\"Format\"\n                  value={data.detected_format || \"--\"}\n                />\n                <MetaRow\n                  label=\"Total Rows\"\n                  value={\n                    data.total_rows != null\n                      ? data.total_rows.toLocaleString()\n                      : \"--\"\n                  }\n                />\n                <MetaRow\n                  label=\"Columns\"\n                  value={\n                    <span className=\"flex items-center gap-1.5 flex-wrap\">\n                      {columns.map((col) => (\n                        <Badge\n                          key={col}\n                          variant=\"outline\"\n                          className=\"text-[11px] font-mono h-5\"\n                        >\n                          {col}\n                        </Badge>\n                      ))}\n                    </span>\n                  }\n                />\n              </div>\n\n              {data.warning && (\n                <div className=\"rounded-lg border border-amber-200 bg-amber-50 px-4 py-3 text-xs text-amber-700 dark:border-amber-800 dark:bg-amber-950 dark:text-amber-400 mb-4 flex items-start gap-2.5\">\n                  <HugeiconsIcon icon={AlertCircleIcon} className=\"size-4 shrink-0 mt-0.5\" />\n                  <span>{data.warning}</span>\n                </div>\n              )}\n\n              {mappingEnabled && (\n                <DatasetMappingCard\n                  mapping={manualMapping}\n                  mappingOk={mappingOk}\n                  autoDetected={hasHeuristicMapping}\n                  isVlm={effectiveIsVlm}\n                  isAudio={effectiveIsAudio}\n                  format={datasetFormat}\n                  onAiAssist={handleAiAssist}\n                  isAiLoading={isAiLoading}\n                  aiError={aiError}\n                  advisorNotification={datasetAdvisorNotification}\n                  advisorSystemPrompt={datasetSystemPrompt || undefined}\n                />\n              )}\n\n              {/* Data table */}\n              <div 
className=\"flex-1 min-h-[250px] rounded-xl corner-squircle ring-1 ring-border/60 overflow-auto\">\n                <DataTable columns={tableColumns} data={rows} />\n              </div>\n\n              {/* Footer */}\n              <div className=\"mt-3\">\n                <p className=\"text-[11px] text-muted-foreground/60 text-center tabular-nums\">\n                  Showing {rows.length}\n                  {data.total_rows != null &&\n                    ` of ${data.total_rows.toLocaleString()}`}{\" \"}\n                  rows\n                </p>\n\n                {mode === \"preview\" && mappingEnabled && (\n                  <p className=\"mt-2 text-[11px] text-muted-foreground/70 text-center\">\n                    Mapping is saved automatically. You can start training anytime.\n                  </p>\n                )}\n\n                {showMappingFooter && (\n                  <DatasetMappingFooter\n                    mappingOk={mappingOk}\n                    isStarting={isStarting}\n                    startError={startError}\n                    onCancel={() => onOpenChange(false)}\n                    onStartTraining={async () => {\n                      const ok = await startTrainingRun();\n                      if (ok) onOpenChange(false);\n                    }}\n                  />\n                )}\n              </div>\n            </>\n          )}\n        </div>\n      </DialogContent>\n    </Dialog>\n  );\n}\n\n// ---------------------------------------------------------------------------\n// Metadata row\n// ---------------------------------------------------------------------------\n\nfunction MetaRow({\n  label,\n  value,\n}: {\n  label: string;\n  value: ReactNode;\n}) {\n  return (\n    <div className=\"flex items-baseline gap-3 text-sm\">\n      <span className=\"text-muted-foreground font-medium text-xs w-24 shrink-0\">\n        {label}:\n      </span>\n      <span className=\"text-foreground text-[13px] min-w-0\">{value}</span>\n    </div>\n  );\n}\n"
  },
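  {
    "path": "studio/frontend/src/features/studio/sections/dataset-preview-dialog.example.tsx",
    "content": "// Illustrative usage sketch (editor-added, not part of the app build).\n// Assumptions: the file name/path, this standalone wrapper, and the placeholder\n// dataset name are hypothetical. In the app the preview is typically opened via\n// useDatasetPreviewDialogStore (see dataset-section.tsx); this wrapper only\n// shows the prop surface of DatasetPreviewDialog from dataset-preview-dialog.tsx.\n\nimport { useState } from \"react\";\nimport { Button } from \"@/components/ui/button\";\nimport { DatasetPreviewDialog } from \"./dataset-preview-dialog\";\n\nexport function DatasetPreviewExample() {\n  const [open, setOpen] = useState(false);\n\n  return (\n    <>\n      <Button variant=\"outline\" size=\"sm\" onClick={() => setOpen(true)}>\n        Preview dataset\n      </Button>\n      {/* mode=\"preview\" renders the read-only table; mode=\"mapping\" also shows\n          the column-to-role footer with the Start Training action. */}\n      <DatasetPreviewDialog\n        open={open}\n        onOpenChange={setOpen}\n        datasetName=\"org/example-dataset\"\n        datasetSource=\"huggingface\"\n        hfToken={null}\n        datasetSplit=\"train\"\n        mode=\"preview\"\n      />\n    </>\n  );\n}\n"
  },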
  {
    "path": "studio/frontend/src/features/studio/sections/dataset-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport { Input } from \"@/components/ui/input\";\nimport { InputGroupAddon } from \"@/components/ui/input-group\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport { Tabs, TabsContent, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport {\n  useDebouncedValue,\n  useHfDatasetSearch,\n  useHfTokenValidation,\n  useInfiniteScroll,\n} from \"@/hooks\";\nimport {\n  HfDatasetSubsetSplitSelectors,\n  uploadTrainingDataset,\n  useDatasetPreviewDialogStore,\n  useTrainingConfigStore,\n} from \"@/features/training\";\nimport { listLocalDatasets } from \"@/features/training/api/datasets-api\";\nimport type { LocalDatasetInfo } from \"@/features/training/types/datasets\";\nimport { useNavigate } from \"@tanstack/react-router\";\nimport {\n  ArrowDown01Icon,\n  Cancel01Icon,\n  CloudUploadIcon,\n  Database02Icon,\n  FileAttachmentIcon,\n  InformationCircleIcon,\n  Search01Icon,\n  ViewIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ChangeEvent, useCallback, useEffect, useMemo, useRef, useState } from \"react\";\nimport { toast } from \"sonner\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { DocumentUploadRedirectDialog } from \"./document-upload-redirect-dialog\";\n\nconst DOCUMENT_REDIRECT_EXTENSIONS = new Set([\".pdf\", \".docx\", \".txt\"]);\n\nconst SEARCH_INPUT_REASONS = new Set([\"input-change\", \"input-paste\", \"input-clear\"]);\nconst OPEN_LEARNING_RECIPES_ON_ARRIVAL_KEY =\n  \"data-recipes:open-learning-recipes\";\n\nfunction isLikelyLocalDatasetRef(value: string) {\n  return (\n    value.startsWith(\"/\") ||\n    value.startsWith(\"./\") ||\n    value.startsWith(\"../\") ||\n    value.includes(\"\\\\\") ||\n    /\\.(jsonl|json|csv|parquet)$/i.test(value)\n  );\n}\n\nfunction deriveLocalDatasetName(path: string): string {\n  const normalized = path.replaceAll(\"\\\\\", \"/\");\n  const parts = normalized.split(\"/\").filter(Boolean);\n  const parquetIndex = parts.lastIndexOf(\"parquet-files\");\n  if (parquetIndex > 0) return parts[parquetIndex - 1];\n  const basename = parts[parts.length - 1] ?? 
path;\n  // Strip UUID prefix from uploaded files (format: {32hex}_{original})\n  const uuidPrefixMatch = basename.match(/^[a-f0-9]{32}_(.+)$/);\n  if (uuidPrefixMatch) return uuidPrefixMatch[1];\n  return basename;\n}\n\nfunction formatUpdatedDate(timestamp: number | null): string {\n  if (typeof timestamp !== \"number\") return \"--\";\n  return new Date(timestamp * 1000).toLocaleDateString();\n}\n\nfunction normalizeSliceInput(value: string): string | null {\n  const trimmed = value.trim();\n  if (!trimmed) return null;\n  if (!/^\\d+$/.test(trimmed)) return null;\n  return trimmed;\n}\n\nexport function DatasetSection() {\n  const navigate = useNavigate();\n  const {\n    dataset,\n    datasetSource,\n    selectHfDataset,\n    selectLocalDataset,\n    datasetFormat,\n    setDatasetFormat,\n    datasetSubset,\n    setDatasetSubset,\n    datasetSplit,\n    setDatasetSplit,\n    datasetEvalSplit,\n    setDatasetEvalSplit,\n    uploadedFile,\n    uploadedEvalFile,\n    setUploadedEvalFile,\n    hfToken,\n    modelType,\n    datasetSliceStart,\n    setDatasetSliceStart,\n    datasetSliceEnd,\n    setDatasetSliceEnd,\n  } = useTrainingConfigStore(\n    useShallow((s) => ({\n      dataset: s.dataset,\n      datasetSource: s.datasetSource,\n      selectHfDataset: s.selectHfDataset,\n      selectLocalDataset: s.selectLocalDataset,\n      datasetFormat: s.datasetFormat,\n      setDatasetFormat: s.setDatasetFormat,\n      datasetSubset: s.datasetSubset,\n      setDatasetSubset: s.setDatasetSubset,\n      datasetSplit: s.datasetSplit,\n      setDatasetSplit: s.setDatasetSplit,\n      datasetEvalSplit: s.datasetEvalSplit,\n      setDatasetEvalSplit: s.setDatasetEvalSplit,\n      uploadedFile: s.uploadedFile,\n      uploadedEvalFile: s.uploadedEvalFile,\n      setUploadedEvalFile: s.setUploadedEvalFile,\n      hfToken: s.hfToken,\n      modelType: s.modelType,\n      datasetSliceStart: s.datasetSliceStart,\n      setDatasetSliceStart: s.setDatasetSliceStart,\n      datasetSliceEnd: s.datasetSliceEnd,\n      setDatasetSliceEnd: s.setDatasetSliceEnd,\n    })),\n  );\n\n  const [searchQuery, setSearchQuery] = useState(\"\");\n  const [advancedOpen, setAdvancedOpen] = useState(false);\n  const [pickerTab, setPickerTab] = useState<\"huggingface\" | \"local\">(\n    datasetSource === \"upload\" ? \"local\" : \"huggingface\",\n  );\n  const [localDatasets, setLocalDatasets] = useState<LocalDatasetInfo[]>([]);\n  const [hasLoadedLocalDatasets, setHasLoadedLocalDatasets] = useState(false);\n  const [localLoading, setLocalLoading] = useState(false);\n  const [localError, setLocalError] = useState<string | null>(null);\n  const openPreview = useDatasetPreviewDialogStore((s) => s.openPreview);\n  const selectingRef = useRef(false);\n  const pendingSourceTabRef = useRef<\"huggingface\" | \"local\" | null>(null);\n  const debouncedQuery = useDebouncedValue(searchQuery);\n\n  useEffect(() => {\n    setPickerTab(datasetSource === \"upload\" ? \"local\" : \"huggingface\");\n  }, [datasetSource]);\n\n  const refreshLocalDatasets = useCallback(async () => {\n    setLocalLoading(true);\n    setLocalError(null);\n    try {\n      const response = await listLocalDatasets();\n      setLocalDatasets(response.datasets ?? []);\n    } catch (error) {\n      setLocalError(\n        error instanceof Error ? 
error.message : \"Failed to load local datasets.\",\n      );\n    } finally {\n      setHasLoadedLocalDatasets(true);\n      setLocalLoading(false);\n    }\n  }, []);\n\n  useEffect(() => {\n    if (pickerTab !== \"local\") return;\n    void refreshLocalDatasets();\n  }, [pickerTab, refreshLocalDatasets]);\n\n  useEffect(() => {\n    const handleRefresh = () => {\n      if (document.hidden) return;\n      if (pickerTab !== \"local\" && datasetSource !== \"upload\") return;\n      void refreshLocalDatasets();\n    };\n\n    window.addEventListener(\"focus\", handleRefresh);\n    document.addEventListener(\"visibilitychange\", handleRefresh);\n    return () => {\n      window.removeEventListener(\"focus\", handleRefresh);\n      document.removeEventListener(\"visibilitychange\", handleRefresh);\n    };\n  }, [datasetSource, pickerTab, refreshLocalDatasets]);\n\n  function handleDatasetSelect(id: string | null) {\n    selectingRef.current = true;\n    pendingSourceTabRef.current = \"huggingface\";\n    selectHfDataset(id);\n  }\n\n  function handleLocalDatasetSelect(path: string) {\n    selectingRef.current = true;\n    pendingSourceTabRef.current = \"local\";\n    selectLocalDataset(path);\n  }\n\n  function clearSelectionForTab(tab: \"huggingface\" | \"local\") {\n    pendingSourceTabRef.current = tab;\n    if (tab === \"huggingface\") {\n      handleDatasetSelect(null);\n      return;\n    }\n    selectingRef.current = true;\n    selectLocalDataset(null);\n  }\n\n  function handleInputChange(\n    val: string,\n    eventDetails?: {\n      reason?: string;\n    },\n  ) {\n    if (selectingRef.current) {\n      selectingRef.current = false;\n      return;\n    }\n    if (!SEARCH_INPUT_REASONS.has(eventDetails?.reason ?? \"\")) {\n      return;\n    }\n    setSearchQuery(val);\n  }\n\n  const effectiveModelType = modelType ?? \"text\";\n\n  const {\n    results: hfResults,\n    isLoading,\n    isLoadingMore,\n    fetchMore,\n    error: hfSearchError,\n  } = useHfDatasetSearch(pickerTab === \"huggingface\" ? debouncedQuery : \"\", {\n    modelType: effectiveModelType,\n    accessToken: hfToken || undefined,\n    enabled: pickerTab === \"huggingface\",\n  });\n\n  const { error: tokenValidationError, isChecking: isCheckingToken } =\n    useHfTokenValidation(hfToken);\n\n  const hfResultIds = useMemo(() => {\n    const ids = hfResults.map((r) => r.id);\n    if (dataset && !ids.includes(dataset)) {\n      ids.push(dataset);\n    }\n    return ids;\n  }, [hfResults, dataset]);\n\n  const localFilteredDatasets = useMemo(() => {\n    const query = searchQuery.trim().toLowerCase();\n    if (!query) return localDatasets;\n    return localDatasets.filter(\n      (item) =>\n        item.label.toLowerCase().includes(query) ||\n        item.path.toLowerCase().includes(query),\n    );\n  }, [localDatasets, searchQuery]);\n\n  const localPathById = useMemo(() => {\n    return new Map(localDatasets.map((item) => [item.id, item.path]));\n  }, [localDatasets]);\n\n  const localLabelById = useMemo(() => {\n    return new Map(localDatasets.map((item) => [item.id, item.label]));\n  }, [localDatasets]);\n\n  const selectedLocalDataset = useMemo(() => {\n    if (!uploadedFile) return null;\n    return localDatasets.find((item) => item.path === uploadedFile) ?? null;\n  }, [localDatasets, uploadedFile]);\n\n  const selectedLocalId = selectedLocalDataset?.id ?? 
null;\n\n  const localResultIds = useMemo(() => {\n    const ids = localFilteredDatasets.map((item) => item.id);\n    if (selectedLocalDataset && selectedLocalId && !ids.includes(selectedLocalId)) {\n      ids.push(selectedLocalId);\n    }\n    return ids;\n  }, [localFilteredDatasets, selectedLocalDataset, selectedLocalId]);\n\n  useEffect(() => {\n    if (!hasLoadedLocalDatasets) return;\n    if (localLoading) return;\n    if (localError) return;\n    if (datasetSource !== \"upload\") return;\n    if (!uploadedFile) return;\n    if (selectedLocalDataset) return;\n    // Don't clear if this is a direct file upload (e.g. user uploaded a .jsonl/.csv)\n    if (/\\.(jsonl|json|csv|parquet|arrow)$/i.test(uploadedFile)) return;\n    selectLocalDataset(null);\n  }, [\n    datasetSource,\n    hasLoadedLocalDatasets,\n    localError,\n    localLoading,\n    uploadedFile,\n    selectedLocalDataset,\n    selectLocalDataset,\n  ]);\n\n  const activeSourceTab = datasetSource === \"upload\" ? \"local\" : \"huggingface\";\n  const comboboxItems = pickerTab === \"huggingface\" ? hfResultIds : localResultIds;\n  const comboboxValue =\n    pickerTab === \"huggingface\"\n      ? datasetSource === \"huggingface\"\n        ? dataset\n        : null\n      : datasetSource === \"upload\"\n        ? selectedLocalId\n        : null;\n  const isHfDatasetSelected =\n    datasetSource === \"huggingface\" &&\n    !!dataset &&\n    !isLikelyLocalDatasetRef(dataset);\n\n  const selectedDatasetName = datasetSource === \"upload\" ? uploadedFile : dataset;\n  const selectedLocalMetadata = selectedLocalDataset?.metadata ?? null;\n  const selectedLocalColumns = selectedLocalMetadata?.columns ?? [];\n  const selectedLocalRows =\n    selectedLocalDataset?.rows ?? selectedLocalMetadata?.actual_num_records ?? null;\n  const selectedLocalUpdatedAt = selectedLocalDataset?.updated_at ?? null;\n\n  const comboboxAnchorRef = useRef<HTMLDivElement>(null);\n  const fileInputRef = useRef<HTMLInputElement>(null);\n  const evalFileInputRef = useRef<HTMLInputElement>(null);\n  const { scrollRef, sentinelRef } = useInfiniteScroll(\n    fetchMore,\n    hfResults.length,\n  );\n\n  const [isUploading, setIsUploading] = useState(false);\n  const [documentRedirectOpen, setDocumentRedirectOpen] = useState(false);\n  const [redirectFileName, setRedirectFileName] = useState<string | null>(null);\n\n  const handleUploadButtonClick = () => {\n    fileInputRef.current?.click();\n  };\n\n  const handleFileUpload = async (\n    file: File,\n    onSuccess: (storedPath: string) => void,\n    successMessage: string,\n  ) => {\n    setIsUploading(true);\n    try {\n      const uploaded = await uploadTrainingDataset(file);\n      onSuccess(uploaded.stored_path);\n      toast.success(successMessage, { description: uploaded.filename });\n    } catch (error) {\n      toast.error(\"Upload failed\", {\n        description: error instanceof Error ? 
error.message : \"Unknown error\",\n      });\n    } finally {\n      setIsUploading(false);\n    }\n  };\n\n  const handleDatasetFileChange = async (event: ChangeEvent<HTMLInputElement>) => {\n    const file = event.target.files?.[0];\n    event.target.value = \"\";\n    if (!file) return;\n\n    const extension = file.name.slice(file.name.lastIndexOf(\".\")).toLowerCase();\n    if (DOCUMENT_REDIRECT_EXTENSIONS.has(extension)) {\n      setRedirectFileName(file.name);\n      setDocumentRedirectOpen(true);\n      return;\n    }\n\n    await handleFileUpload(file, selectLocalDataset, \"Dataset uploaded\");\n  };\n\n  const handleEvalFileChange = async (event: ChangeEvent<HTMLInputElement>) => {\n    const file = event.target.files?.[0];\n    event.target.value = \"\";\n    if (!file) return;\n\n    await handleFileUpload(file, setUploadedEvalFile, \"Eval dataset uploaded\");\n  };\n\n  const handleOpenLearningRecipes = useCallback(() => {\n    sessionStorage.setItem(OPEN_LEARNING_RECIPES_ON_ARRIVAL_KEY, \"1\");\n    setDocumentRedirectOpen(false);\n    void navigate({ to: \"/data-recipes\" });\n  }, [navigate]);\n\n  return (\n    <div data-tour=\"studio-dataset\" className=\"col-span-1 xl:col-span-4\">\n      <SectionCard\n        icon={<HugeiconsIcon icon={Database02Icon} className=\"size-5\" />}\n        title=\"Dataset\"\n        description=\"Select or upload training data\"\n        accent=\"indigo\"\n        className=\"dark:shadow-border\"\n      >\n        <div className=\"flex flex-col gap-4\">\n          <div className=\"flex flex-col gap-2\">\n            <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n              Choose dataset\n              <span className=\"rounded-full border border-border/70 bg-muted/40 px-2 py-0.5 text-[10px] font-medium text-foreground/80\">\n                {datasetSource === \"upload\" ? 
\"Local\" : \"Hugging Face\"}\n              </span>\n              <Tooltip>\n                <TooltipTrigger asChild={true}>\n                  <button\n                    type=\"button\"\n                    className=\"text-foreground/70 hover:text-foreground\"\n                  >\n                    <HugeiconsIcon\n                      icon={InformationCircleIcon}\n                      className=\"size-3\"\n                    />\n                  </button>\n                </TooltipTrigger>\n                <TooltipContent>\n                  Use the popup tabs to switch between Hugging Face and local\n                  recipe outputs.{\" \"}\n                  <a\n                    href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/datasets-guide\"\n                    target=\"_blank\"\n                    rel=\"noopener noreferrer\"\n                    className=\"text-primary underline\"\n                  >\n                    Read more\n                  </a>\n                </TooltipContent>\n              </Tooltip>\n            </span>\n            <div\n              ref={comboboxAnchorRef}\n              onKeyDown={(event) => {\n                if (event.key !== \"Enter\") return;\n                if (!(event.target instanceof HTMLInputElement)) return;\n                event.preventDefault();\n                if (pickerTab === \"huggingface\") {\n                  if (hfResults.length > 0) {\n                    handleDatasetSelect(hfResults[0].id);\n                  } else {\n                    const text = event.target.value.trim();\n                    if (text) handleDatasetSelect(text);\n                  }\n                  return;\n                }\n\n                if (localResultIds.length > 0) {\n                  const selectedId = localResultIds[0];\n                  const path = localPathById.get(selectedId);\n                  if (path) {\n                    handleLocalDatasetSelect(path);\n                  }\n                }\n              }}\n            >\n              <Combobox\n                items={comboboxItems}\n                filteredItems={comboboxItems}\n                filter={null}\n                value={comboboxValue}\n                onOpenChange={(open) => {\n                  setSearchQuery(\"\");\n                  if (open && (pickerTab === \"local\" || activeSourceTab === \"local\")) {\n                    void refreshLocalDatasets();\n                  }\n                  if (!open) {\n                    setPickerTab(pendingSourceTabRef.current ?? activeSourceTab);\n                    pendingSourceTabRef.current = null;\n                  }\n                }}\n                onValueChange={(value) => {\n                  if (!value) {\n                    clearSelectionForTab(pickerTab);\n                    return;\n                  }\n                  if (pickerTab === \"huggingface\") {\n                    handleDatasetSelect(value);\n                    return;\n                  }\n                  const path = localPathById.get(value);\n                  if (path) {\n                    handleLocalDatasetSelect(path);\n                  }\n                }}\n                onInputValueChange={(value, eventDetails) =>\n                  handleInputChange(value, eventDetails)\n                }\n                itemToStringValue={(id) =>\n                  pickerTab === \"local\"\n                    ? localLabelById.get(id) ?? 
id\n                    : id\n                }\n                autoHighlight={true}\n              >\n                <ComboboxInput\n                  placeholder={\n                    pickerTab === \"huggingface\"\n                      ? \"Search Hugging Face datasets...\"\n                      : \"Search local datasets...\"\n                  }\n                  className=\"w-full\"\n                  showClear={true}\n                >\n                  <InputGroupAddon>\n                    <HugeiconsIcon icon={Search01Icon} className=\"size-4\" />\n                  </InputGroupAddon>\n                </ComboboxInput>\n                <ComboboxContent anchor={comboboxAnchorRef}>\n                  <div className=\"px-2 pt-2 pb-2\">\n                    <Tabs\n                      value={pickerTab}\n                      onValueChange={(value) => {\n                        setPickerTab(value as \"huggingface\" | \"local\");\n                        setSearchQuery(\"\");\n                      }}\n                      className=\"w-full\"\n                    >\n                      <TabsList className=\" w-full\">\n                        <TabsTrigger value=\"huggingface\">Hugging Face</TabsTrigger>\n                        <TabsTrigger value=\"local\">Local</TabsTrigger>\n                      </TabsList>\n\n                      <TabsContent value=\"huggingface\" className=\"m-0\">\n                        {isLoading ? (\n                          <div className=\"flex items-center justify-center py-4 gap-2 text-xs text-muted-foreground\">\n                            <Spinner className=\"size-4\" /> Searching...\n                          </div>\n                        ) : (\n                          <ComboboxEmpty>No datasets found</ComboboxEmpty>\n                        )}\n                        <div\n                          ref={scrollRef}\n                          className=\"max-h-64 overflow-y-auto overscroll-contain [scrollbar-width:thin]\"\n                        >\n                          <ComboboxList className=\"p-1 !max-h-none !overflow-visible\">\n                            {(id: string) => {\n                              return (\n                                <ComboboxItem key={id} value={id} className=\"gap-2\">\n                                  <Tooltip>\n                                    <TooltipTrigger asChild={true}>\n                                      <span className=\"block min-w-0 flex-1 truncate\">\n                                        {id}\n                                      </span>\n                                    </TooltipTrigger>\n                                    <TooltipContent\n                                      side=\"left\"\n                                      className=\"max-w-xs break-all\"\n                                    >\n                                      {id}\n                                    </TooltipContent>\n                                  </Tooltip>\n                                </ComboboxItem>\n                              );\n                            }}\n                          </ComboboxList>\n                          <div ref={sentinelRef} className=\"h-px\" />\n                          {isLoadingMore && (\n                            <div className=\"flex items-center justify-center py-2\">\n                              <Spinner className=\"size-3.5 text-muted-foreground\" />\n                            </div>\n                          )}\n                     
   </div>\n                      </TabsContent>\n\n                      <TabsContent value=\"local\" className=\"m-0\">\n                        {localLoading ? (\n                          <div className=\"flex items-center justify-center py-4 gap-2 text-xs text-muted-foreground\">\n                            <Spinner className=\"size-4\" /> Loading local datasets...\n                          </div>\n                        ) : (\n                          <>\n                            {localError ? (\n                              <p className=\"px-2 py-2 text-xs text-destructive\">{localError}</p>\n                            ) : (\n                              <ComboboxEmpty className=\"px-2 py-3\">\n                                <div className=\"flex w-full flex-col items-center gap-2 text-center\">\n                                  <p className=\"text-xs text-muted-foreground\">\n                                    {localDatasets.length === 0\n                                      ? \"No local datasets yet.\"\n                                      : \"No local datasets match search.\"}\n                                  </p>\n                                  {localDatasets.length === 0 ? (\n                                    <Button asChild={true} size=\"sm\" variant=\"outline\">\n                                      <a href=\"/data-recipes\">Open Data Recipes</a>\n                                    </Button>\n                                  ) : null}\n                                </div>\n                              </ComboboxEmpty>\n                            )}\n                            <div className=\"max-h-64 overflow-y-auto overscroll-contain [scrollbar-width:thin]\">\n                              <ComboboxList className=\"p-1 !max-h-none !overflow-visible\">\n                                {(id: string) => {\n                                  const label = localLabelById.get(id) ?? id;\n                                  return (\n                                    <ComboboxItem key={id} value={id} className=\"gap-2\">\n                                      <Tooltip>\n                                        <TooltipTrigger asChild={true}>\n                                          <span className=\"block min-w-0 flex-1 truncate\">\n                                            {label}\n                                          </span>\n                                        </TooltipTrigger>\n                                        <TooltipContent\n                                          side=\"left\"\n                                          className=\"max-w-xs break-all\"\n                                        >\n                                          {label}\n                                        </TooltipContent>\n                                      </Tooltip>\n                                    </ComboboxItem>\n                                  );\n                                }}\n                              </ComboboxList>\n                            </div>\n                          </>\n                        )}\n                      </TabsContent>\n                    </Tabs>\n                  </div>\n                </ComboboxContent>\n              </Combobox>\n            </div>\n            {(tokenValidationError ?? hfSearchError) && (\n              <p className=\"text-xs text-destructive\">\n                {tokenValidationError ?? 
hfSearchError}\n                {\" — \"}\n                <a\n                  href=\"https://huggingface.co/settings/tokens\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"underline\"\n                >\n                  Get or update token\n                </a>\n              </p>\n            )}\n            {isCheckingToken && (\n              <p className=\"text-xs text-muted-foreground\">Checking token…</p>\n            )}\n            {pickerTab !== activeSourceTab && (\n              <p className=\"text-[11px] text-muted-foreground\">\n                Browsing {pickerTab === \"local\" ? \"Local datasets\" : \"Hugging Face\"}.\n                Current selection stays {datasetSource === \"upload\" ? \"Local\" : \"Hugging Face\"}.\n              </p>\n            )}\n          </div>\n\n          {isHfDatasetSelected ? (\n            <HfDatasetSubsetSplitSelectors\n              variant=\"studio\"\n              enabled={true}\n              datasetName={dataset}\n              accessToken={hfToken || undefined}\n              datasetSubset={datasetSubset}\n              setDatasetSubset={setDatasetSubset}\n              datasetSplit={datasetSplit}\n              setDatasetSplit={setDatasetSplit}\n              datasetEvalSplit={datasetEvalSplit}\n              setDatasetEvalSplit={setDatasetEvalSplit}\n            />\n          ) : !selectedDatasetName ? (\n            <HfDatasetSubsetSplitSelectors\n              variant=\"studio\"\n              enabled={false}\n              datasetName={null}\n              accessToken={hfToken || undefined}\n              datasetSubset={datasetSubset}\n              setDatasetSubset={setDatasetSubset}\n              datasetSplit={datasetSplit}\n              setDatasetSplit={setDatasetSplit}\n              datasetEvalSplit={datasetEvalSplit}\n              setDatasetEvalSplit={setDatasetEvalSplit}\n            />\n          ) : datasetSource === \"upload\" && selectedLocalDataset ? (\n            <div className=\"rounded-lg border bg-muted/20 px-3.5 py-3\">\n              <div className=\"mb-3 flex items-center justify-between gap-3\">\n                <div>\n                  <p className=\"text-xs font-medium text-muted-foreground\">\n                    Local dataset metadata\n                  </p>\n                  <p className=\"text-[10px] text-muted-foreground/80\">\n                    Data Recipe output.\n                  </p>\n                </div>\n              </div>\n\n              <div className=\"flex flex-col gap-3\">\n                <div className=\"grid grid-cols-2 gap-x-4 gap-y-2 text-xs\">\n                  <MetadataRow\n                    label=\"Rows\"\n                    value={\n                      typeof selectedLocalRows === \"number\"\n                        ? selectedLocalRows.toLocaleString()\n                        : \"--\"\n                    }\n                  />\n                  <MetadataRow\n                    label=\"Columns\"\n                    value={\n                      selectedLocalColumns.length > 0\n                        ? 
String(selectedLocalColumns.length)\n                        : \"--\"\n                    }\n                  />\n                  <MetadataRow\n                    label=\"Batches\"\n                    value={\n                      typeof selectedLocalMetadata?.num_completed_batches === \"number\" &&\n                      typeof selectedLocalMetadata?.total_num_batches === \"number\"\n                        ? `${selectedLocalMetadata.num_completed_batches}/${selectedLocalMetadata.total_num_batches}`\n                        : \"--\"\n                    }\n                  />\n                  <MetadataRow\n                    label=\"Updated\"\n                    value={formatUpdatedDate(selectedLocalUpdatedAt)}\n                  />\n                </div>\n              </div>\n            </div>\n          ) : null}\n\n          {datasetSource === \"upload\" && uploadedFile && (\n            <div className=\"rounded-lg border bg-muted/20 px-3.5 py-3\">\n              <p className=\"mb-2 text-xs font-medium text-muted-foreground\">\n                Eval dataset\n              </p>\n              {uploadedEvalFile ? (\n                <div className=\"flex items-center justify-between gap-2\">\n                  <div className=\"flex items-center gap-1.5 overflow-hidden\">\n                    <HugeiconsIcon icon={FileAttachmentIcon} className=\"size-3.5 shrink-0 text-muted-foreground\" />\n                    <span className=\"truncate text-xs\">\n                      {deriveLocalDatasetName(uploadedEvalFile)}\n                    </span>\n                  </div>\n                  <Button\n                    variant=\"ghost\"\n                    size=\"sm\"\n                    className=\"h-6 w-6 shrink-0 cursor-pointer p-0\"\n                    onClick={() => setUploadedEvalFile(null)}\n                  >\n                    <HugeiconsIcon icon={Cancel01Icon} className=\"size-3.5\" />\n                  </Button>\n                </div>\n              ) : (\n                <div className=\"flex flex-col gap-1.5\">\n                  <Button\n                    variant=\"outline\"\n                    size=\"sm\"\n                    className=\"w-full cursor-pointer gap-1.5\"\n                    disabled={isUploading}\n                    onClick={() => evalFileInputRef.current?.click()}\n                  >\n                    {isUploading ? (\n                      <Spinner className=\"size-3.5\" />\n                    ) : (\n                      <HugeiconsIcon icon={CloudUploadIcon} className=\"size-3.5\" />\n                    )}\n                    {isUploading ? \"Uploading...\" : \"Upload eval file\"}\n                  </Button>\n                  <p className=\"text-[10px] text-muted-foreground/80\">\n                    Optional. If not provided, a small portion will be split from the training data.\n                  </p>\n                </div>\n              )}\n            </div>\n          )}\n\n          <Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>\n            <CollapsibleTrigger className=\"flex w-full cursor-pointer items-center gap-1.5 text-xs text-muted-foreground\">\n              <HugeiconsIcon\n                icon={ArrowDown01Icon}\n                className={`size-3.5 transition-transform ${advancedOpen ? 
\"rotate-180\" : \"\"}`}\n              />\n              Advanced\n            </CollapsibleTrigger>\n            <CollapsibleContent className=\"mt-3 data-[state=open]:overflow-visible\">\n              <div className=\"flex flex-col gap-4\">\n                <div className=\"flex flex-col gap-2\">\n                  <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                    Target Format\n                    <Tooltip>\n                      <TooltipTrigger asChild={true}>\n                        <button\n                          type=\"button\"\n                          className=\"text-foreground/70 hover:text-foreground\"\n                        >\n                          <HugeiconsIcon\n                            icon={InformationCircleIcon}\n                            className=\"size-3\"\n                          />\n                        </button>\n                      </TooltipTrigger>\n                      <TooltipContent>\n                        Format of your training data. Auto-detect works for most\n                        datasets.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/datasets-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </TooltipContent>\n                    </Tooltip>\n                  </span>\n                  <Select\n                    value={datasetFormat}\n                    onValueChange={(v) =>\n                      setDatasetFormat(v as typeof datasetFormat)\n                    }\n                  >\n                    <SelectTrigger className=\"w-full\">\n                      <SelectValue />\n                    </SelectTrigger>\n                    <SelectContent>\n                      <SelectItem value=\"auto\">Auto</SelectItem>\n                      <SelectItem value=\"alpaca\">Alpaca</SelectItem>\n                      <SelectItem value=\"chatml\">ChatML</SelectItem>\n                      <SelectItem value=\"sharegpt\">ShareGPT</SelectItem>\n                    </SelectContent>\n                  </Select>\n                </div>\n                <div className=\"grid grid-cols-2 gap-3\">\n                  <div className=\"flex flex-col gap-1.5\">\n                    <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                      Train Split Start\n                      <Tooltip>\n                        <TooltipTrigger asChild={true}>\n                          <button\n                            type=\"button\"\n                            className=\"text-foreground/70 hover:text-foreground\"\n                          >\n                            <HugeiconsIcon\n                              icon={InformationCircleIcon}\n                              className=\"size-3\"\n                            />\n                          </button>\n                        </TooltipTrigger>\n                        <TooltipContent>\n                          Only train on a subset of your training split by\n                          specifying a start row index (inclusive, 0-based).\n                          Leave empty to start from the first row.\n                        </TooltipContent>\n                 
     </Tooltip>\n                    </span>\n                    <Input\n                      type=\"number\"\n                      inputMode=\"numeric\"\n                      min={0}\n                      step={1}\n                      placeholder=\"0\"\n                      value={datasetSliceStart ?? \"\"}\n                      onChange={(e) =>\n                        setDatasetSliceStart(normalizeSliceInput(e.target.value))\n                      }\n                    />\n                  </div>\n                  <div className=\"flex flex-col gap-1.5\">\n                    <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                      Train Split End\n                      <Tooltip>\n                        <TooltipTrigger asChild={true}>\n                          <button\n                            type=\"button\"\n                            className=\"text-foreground/70 hover:text-foreground\"\n                          >\n                            <HugeiconsIcon\n                              icon={InformationCircleIcon}\n                              className=\"size-3\"\n                            />\n                          </button>\n                        </TooltipTrigger>\n                        <TooltipContent>\n                          Last row index to include from the training split\n                          (inclusive, 0-based). For example, set Start to 0 and\n                          End to 99 to train on the first 100 rows. Leave empty\n                          to use all remaining rows.\n                        </TooltipContent>\n                      </Tooltip>\n                    </span>\n                    <Input\n                      type=\"number\"\n                      inputMode=\"numeric\"\n                      min={0}\n                      step={1}\n                      placeholder=\"End\"\n                      value={datasetSliceEnd ?? \"\"}\n                      onChange={(e) =>\n                        setDatasetSliceEnd(normalizeSliceInput(e.target.value))\n                      }\n                    />\n                  </div>\n                </div>\n              </div>\n            </CollapsibleContent>\n          </Collapsible>\n\n          <div className=\"flex flex-col gap-4 pt-1\">\n            {selectedDatasetName ? (\n              <div className=\"flex items-center gap-3 rounded-lg border bg-muted/40 px-3.5 py-3\">\n                <div className=\"rounded-md bg-indigo-500/10 p-1.5\">\n                  <HugeiconsIcon\n                    icon={FileAttachmentIcon}\n                    className=\"size-4 text-indigo-500\"\n                  />\n                </div>\n                <div className=\"flex-1 min-w-0\">\n                  <p className=\"font-mono text-sm font-medium truncate\">\n                    {datasetSource === \"upload\"\n                      ? selectedLocalDataset?.label ??\n                        deriveLocalDatasetName(selectedDatasetName)\n                      : selectedDatasetName}\n                  </p>\n                  <p className=\"text-[10px] text-muted-foreground\">\n                    {datasetSource === \"upload\" ? (\n                      uploadedFile ? (\n                        <>\n                          Local dataset\n                          {selectedLocalRows != null\n                            ? 
` / ${selectedLocalRows.toLocaleString()} rows`\n                            : \"\"}\n                        </>\n                      ) : (\n                        \"Local dataset\"\n                      )\n                    ) : (\n                      <>\n                        Hugging Face Dataset\n                        {datasetSubset && ` / ${datasetSubset}`}\n                        {datasetSplit && ` / ${datasetSplit}`}\n                      </>\n                    )}\n                  </p>\n                </div>\n                <Button\n                  variant=\"ghost\"\n                  size=\"sm\"\n                  className=\"shrink-0 text-xs\"\n                  onClick={() => clearSelectionForTab(activeSourceTab)}\n                >\n                  Clear\n                </Button>\n              </div>\n            ) : (\n              <div className=\"flex items-center gap-3 rounded-lg border border-dashed bg-muted/20 px-3.5 py-3\">\n                <HugeiconsIcon\n                  icon={Database02Icon}\n                  className=\"size-4 text-muted-foreground/40\"\n                />\n                <span className=\"text-xs text-muted-foreground\">\n                  No dataset selected\n                </span>\n              </div>\n            )}\n\n            <div className=\"grid grid-cols-2 gap-2\">\n              <Button\n                variant=\"outline\"\n                size=\"sm\"\n                className=\"cursor-pointer gap-1.5\"\n                disabled={isUploading}\n                onClick={handleUploadButtonClick}\n              >\n                {isUploading ? (\n                  <Spinner className=\"size-3.5\" />\n                ) : (\n                  <HugeiconsIcon icon={CloudUploadIcon} className=\"size-3.5\" />\n                )}\n                {isUploading ? 
\"Uploading...\" : \"Upload\"}\n              </Button>\n              <Button\n                variant=\"outline\"\n                size=\"sm\"\n                className=\"cursor-pointer gap-1.5\"\n                disabled={!selectedDatasetName}\n                onClick={() => openPreview()}\n              >\n                <HugeiconsIcon icon={ViewIcon} className=\"size-3.5\" />\n                View dataset\n              </Button>\n            </div>\n          </div>\n          <input\n            ref={fileInputRef}\n            type=\"file\"\n            accept=\".json,.jsonl,.csv,.parquet,.pdf,.docx,.txt\"\n            className=\"hidden\"\n            onChange={(event) => {\n              void handleDatasetFileChange(event);\n            }}\n          />\n          <input\n            ref={evalFileInputRef}\n            type=\"file\"\n            accept=\".json,.jsonl,.csv,.parquet\"\n            className=\"hidden\"\n            onChange={(event) => {\n              void handleEvalFileChange(event);\n            }}\n          />\n          <DocumentUploadRedirectDialog\n            open={documentRedirectOpen}\n            onOpenChange={setDocumentRedirectOpen}\n            fileName={redirectFileName}\n            onOpenLearningRecipes={handleOpenLearningRecipes}\n          />\n      </div>\n      </SectionCard>\n    </div>\n  );\n}\n\nfunction MetadataRow({ label, value }: { label: string; value: string }) {\n  return (\n    <div className=\"flex items-center justify-between gap-2 rounded-md bg-background/60 px-2 py-1.5\">\n      <span className=\"text-muted-foreground\">{label}</span>\n      <span className=\"font-medium text-foreground\">{value}</span>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/document-upload-redirect-dialog.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Dialog,\n  DialogContent,\n  DialogDescription,\n  DialogFooter,\n  DialogHeader,\n  DialogTitle,\n} from \"@/components/ui/dialog\";\nimport { Badge } from \"@/components/ui/badge\";\nimport {\n  ArrowRight01Icon,\n  DocumentAttachmentIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport type { ReactElement } from \"react\";\n\ntype DocumentUploadRedirectDialogProps = {\n  open: boolean;\n  onOpenChange: (open: boolean) => void;\n  fileName: string | null;\n  onOpenLearningRecipes: () => void;\n};\n\nexport function DocumentUploadRedirectDialog({\n  open,\n  onOpenChange,\n  fileName,\n  onOpenLearningRecipes,\n}: DocumentUploadRedirectDialogProps): ReactElement {\n  return (\n    <Dialog open={open} onOpenChange={onOpenChange}>\n      <DialogContent\n        className=\"sm:max-w-lg\"\n        overlayClassName=\"bg-background/45 supports-backdrop-filter:backdrop-blur-[1px]\"\n      >\n        <DialogHeader className=\"gap-3\">\n          <div className=\"flex items-center gap-2\">\n            <div className=\"flex size-10 items-center justify-center rounded-2xl border border-border/70 bg-muted/30\">\n              <HugeiconsIcon\n                icon={DocumentAttachmentIcon}\n                className=\"size-5 text-foreground/90\"\n              />\n            </div>\n            <Badge variant=\"outline\">Recipe Studio</Badge>\n          </div>\n          <div className=\"space-y-1\">\n            <DialogTitle>This file needs conversion first</DialogTitle>\n            <DialogDescription>\n              {fileName ? (\n                <>\n                  <span className=\"font-medium text-foreground\">{fileName}</span>{\" \"}\n                  is source material, not a ready-to-train dataset.\n                </>\n              ) : (\n                \"This file is source material, not a ready-to-train dataset.\"\n              )}{\" \"}\n              Use Data Recipes to turn documents into a dataset, then bring the\n              result back here for fine-tuning.\n            </DialogDescription>\n          </div>\n        </DialogHeader>\n\n        <div className=\"corner-squircle rounded-2xl border border-border/70 bg-muted/20 p-4\">\n          <p className=\"text-sm font-medium text-foreground\">\n            Best next step\n          </p>\n          <p className=\"mt-1 text-sm text-muted-foreground\">\n            Open Learning Recipes and start from a document-based recipe like PDF\n            grounded QA.\n          </p>\n        </div>\n\n        <DialogFooter className=\"sm:justify-between\">\n          <Button\n            type=\"button\"\n            variant=\"outline\"\n            onClick={() => onOpenChange(false)}\n          >\n            Cancel\n          </Button>\n          <Button type=\"button\" onClick={onOpenLearningRecipes}>\n            Open Learning Recipes\n            <HugeiconsIcon icon={ArrowRight01Icon} className=\"size-4\" />\n          </Button>\n        </DialogFooter>\n      </DialogContent>\n    </Dialog>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/model-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  InputGroup,\n  InputGroupAddon,\n  InputGroupInput,\n} from \"@/components/ui/input-group\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport { MODEL_TYPE_TO_HF_TASK, PRIORITY_TRAINING_MODELS, applyPriorityOrdering } from \"@/config/training\";\nimport {\n  useDebouncedValue,\n  useGpuInfo,\n  useHfModelSearch,\n  useHfTokenValidation,\n  useInfiniteScroll,\n} from \"@/hooks\";\nimport { formatCompact } from \"@/lib/utils\";\nimport {\n  type TrainingMethod as VramTrainingMethod,\n  type VramFitStatus,\n  buildModelVramMap,\n} from \"@/lib/vram\";\nimport {\n  listLocalModels,\n  type LocalModelInfo,\n  useTrainingConfigStore,\n} from \"@/features/training\";\nimport type { TrainingMethod } from \"@/types/training\";\nimport {\n  ChipIcon,\n  FolderSearchIcon,\n  InformationCircleIcon,\n  Key01Icon,\n  Search01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useEffect, useMemo, useRef, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\n\nconst METHOD_DOTS: Record<string, string> = {\n  qlora: \"bg-emerald-400\",\n  lora: \"bg-blue-400\",\n  full: \"bg-amber-400\",\n};\n\nconst DARK_TRIGGER =\n  \"w-full bg-foreground text-background hover:bg-foreground/90 dark:bg-foreground dark:text-background dark:hover:bg-foreground [&_svg]:text-background/50\";\nconst DARK_CONTENT =\n  \"bg-foreground text-background shadow-xl border-background/10 [--accent:rgba(255,255,255,0.1)] [--accent-foreground:white] dark:[--accent:rgba(2,6,23,0.08)] dark:[--accent-foreground:rgb(2,6,23)] [&_[data-slot=select-item]]:text-white/80 dark:[&_[data-slot=select-item]]:text-slate-900 [&_[data-slot=select-scroll-up-button]]:bg-foreground [&_[data-slot=select-scroll-down-button]]:bg-foreground\";\nconst DARK_COMBOBOX_CONTENT =\n  \"bg-foreground text-background shadow-xl border-background/10 dark:[--accent:rgba(2,6,23,0.08)] dark:[--accent-foreground:rgb(2,6,23)] dark:[&_[data-slot=combobox-item]]:text-slate-900 dark:[&_.text-muted-foreground]:text-slate-500\";\n\nexport function ModelSection() {\n  const gpu = useGpuInfo();\n\n  const {\n    modelType,\n    selectedModel,\n    setSelectedModel,\n    trainingMethod,\n    setTrainingMethod,\n    hfToken,\n    setHfToken,\n  } = useTrainingConfigStore(\n    useShallow(\n      ({\n        modelType,\n        selectedModel,\n        setSelectedModel,\n        trainingMethod,\n        setTrainingMethod,\n        hfToken,\n        setHfToken,\n      }) => ({\n        modelType,\n        selectedModel,\n        setSelectedModel,\n        trainingMethod,\n        setTrainingMethod,\n        hfToken,\n        setHfToken,\n      }),\n    ),\n  );\n\n  const [inputValue, setInputValue] = useState(\"\");\n  const [localModelInput, setLocalModelInput] = useState(\"\");\n  const [localModels, setLocalModels] = useState<LocalModelInfo[]>([]);\n  const [isLoadingLocalModels, 
setIsLoadingLocalModels] = useState(true);\n  const [localModelsError, setLocalModelsError] = useState<string | null>(null);\n  const selectingRef = useRef(false);\n  const debouncedQuery = useDebouncedValue(inputValue);\n\n  function handleModelSelect(id: string | null) {\n    selectingRef.current = true;\n    setSelectedModel(id);\n  }\n\n  function handleInputChange(val: string) {\n    if (selectingRef.current) {\n      selectingRef.current = false;\n      return;\n    }\n    setInputValue(val);\n  }\n\n  function applyLocalModel(value: string) {\n    const next = value.trim();\n    if (!next) return;\n    setSelectedModel(next);\n  }\n\n  useEffect(() => {\n    const controller = new AbortController();\n    void listLocalModels(controller.signal)\n      .then((models) => {\n        if (controller.signal.aborted) return;\n        setLocalModels(models);\n      })\n      .catch((error) => {\n        if (controller.signal.aborted) return;\n        setLocalModelsError(\n          error instanceof Error ? error.message : \"Failed to load local models\",\n        );\n      })\n      .finally(() => {\n        if (controller.signal.aborted) return;\n        setIsLoadingLocalModels(false);\n      });\n    return () => controller.abort();\n  }, []);\n  const task = modelType ? MODEL_TYPE_TO_HF_TASK[modelType] : undefined;\n  const {\n    results: hfResults,\n    isLoading,\n    isLoadingMore,\n    fetchMore,\n    error: hfSearchError,\n  } = useHfModelSearch(debouncedQuery, {\n    task,\n    accessToken: hfToken || undefined,\n    excludeGguf: true,\n    priorityIds: PRIORITY_TRAINING_MODELS,\n  });\n\n  const { error: tokenValidationError, isChecking: isCheckingToken } =\n    useHfTokenValidation(hfToken);\n\n  const resultIds = useMemo(() => {\n    const ids = hfResults.map((r) => r.id);\n    if (selectedModel && !ids.includes(selectedModel)) {\n      ids.push(selectedModel);\n    }\n\n    return applyPriorityOrdering(ids);\n  }, [hfResults, selectedModel]);\n\n  // Filter out GGUF models — they can't be used for training\n  const trainableLocalModels = useMemo(\n    () =>\n      localModels.filter((m) => {\n        if (m.path.endsWith(\".gguf\")) return false;\n        if (m.id.toLowerCase().includes(\"-gguf\")) return false;\n        return true;\n      }),\n    [localModels],\n  );\n\n  const localMetaById = useMemo(() => {\n    const map = new Map<string, LocalModelInfo>();\n    for (const model of trainableLocalModels) map.set(model.id, model);\n    return map;\n  }, [trainableLocalModels]);\n\n  const localResultIds = useMemo(() => {\n    const ids = trainableLocalModels.map((model) => model.id);\n    const manual = localModelInput.trim();\n    if (manual && !ids.includes(manual)) {\n      ids.unshift(manual);\n    }\n    return ids;\n  }, [localModelInput, trainableLocalModels]);\n\n  const localFilteredIds = useMemo(() => {\n    const q = localModelInput.trim().toLowerCase();\n    if (!q) return localResultIds;\n    return localResultIds.filter((id) => {\n      const meta = localMetaById.get(id);\n      if (id.toLowerCase().includes(q)) return true;\n      if (meta?.display_name.toLowerCase().includes(q)) return true;\n      if (meta?.path.toLowerCase().includes(q)) return true;\n      return false;\n    });\n  }, [localMetaById, localModelInput, localResultIds]);\n\n  // Pre-compute VRAM fit status for every model in the current result 
set.\n  // Keyed by model id so the render callback is a simple O(1) lookup.\n  // Re-computes when the training method changes (QLoRA=4-bit vs LoRA/Full=fp16).\n  const vramMap = useMemo(() => {\n    const fitMap = buildModelVramMap(\n      hfResults,\n      trainingMethod as VramTrainingMethod,\n      gpu,\n    );\n    const map = new Map<\n      string,\n      { est: number; status: VramFitStatus | null; detail: string | null }\n    >();\n    for (const r of hfResults) {\n      const detail = r.totalParams ? formatCompact(r.totalParams) : null;\n      const fit = fitMap.get(r.id);\n      map.set(r.id, {\n        est: fit?.est ?? 0,\n        status: fit?.status ?? null,\n        detail,\n      });\n    }\n    return map;\n  }, [hfResults, gpu, trainingMethod]);\n\n  const comboboxAnchorRef = useRef<HTMLDivElement>(null);\n  const localComboboxAnchorRef = useRef<HTMLDivElement>(null);\n  const { scrollRef, sentinelRef } = useInfiniteScroll(\n    fetchMore,\n    hfResults.length,\n  );\n\n  return (\n    <div data-tour=\"studio-model\" className=\"col-span-1 md:col-span-2 xl:col-span-12\">\n      <SectionCard\n        icon={<HugeiconsIcon icon={ChipIcon} className=\"size-5\" />}\n        title=\"Model\"\n        description=\"Select base model and training method\"\n        accent=\"emerald\"\n        featured={true}\n        badge=\"2x Faster Training\"\n        className=\"shadow-border ring-1 ring-border\"\n      >\n        <div className=\"grid gap-4 md:grid-cols-2 xl:grid-cols-4\">\n          <div data-tour=\"studio-local-model\" className=\"flex flex-col gap-2\">\n            <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n              Local Model\n            <Tooltip>\n              <TooltipTrigger asChild={true}>\n                <button\n                  type=\"button\"\n                  className=\"text-foreground/70 hover:text-foreground\"\n                >\n                  <HugeiconsIcon\n                    icon={InformationCircleIcon}\n                    className=\"size-3\"\n                  />\n                </button>\n              </TooltipTrigger>\n              <TooltipContent>\n                Path to a locally downloaded model or a custom HF repo.\n              </TooltipContent>\n            </Tooltip>\n          </span>\n          <div ref={localComboboxAnchorRef}>\n            <Combobox\n              items={localResultIds}\n              filteredItems={localFilteredIds}\n              filter={null}\n              value={localModelInput || null}\n              onValueChange={(id) => {\n                const next = id ?? \"\";\n                setLocalModelInput(next);\n                if (next) setSelectedModel(next);\n              }}\n              onInputValueChange={setLocalModelInput}\n              itemToStringValue={(id) => id}\n              autoHighlight={true}\n            >\n              <ComboboxInput\n                placeholder={\n                  isLoadingLocalModels\n                    ? 
\"Scanning local and cached models...\"\n                    : \"./models/my-model\"\n                }\n                className=\"w-full bg-foreground text-background [&_input]:text-background [&_input]:placeholder:text-background/40 [&_svg]:text-background/50 hover:bg-foreground/90\"\n                onBlur={() => applyLocalModel(localModelInput)}\n                onKeyDown={(event) => {\n                  if (event.key !== \"Enter\") return;\n                  event.preventDefault();\n                  applyLocalModel(localModelInput);\n                }}\n              >\n                <InputGroupAddon>\n                  <HugeiconsIcon icon={FolderSearchIcon} className=\"size-4\" />\n                </InputGroupAddon>\n              </ComboboxInput>\n              <ComboboxContent\n                anchor={localComboboxAnchorRef}\n                className={DARK_COMBOBOX_CONTENT}\n              >\n                {isLoadingLocalModels ? (\n                  <div className=\"flex items-center justify-center gap-2 py-4 text-xs text-muted-foreground\">\n                    <Spinner className=\"size-4\" /> Scanning...\n                  </div>\n                ) : localModelsError ? (\n                  <div className=\"px-3 py-2 text-xs text-red-500\">\n                    {localModelsError}\n                  </div>\n                ) : (\n                  <ComboboxEmpty>No local models found</ComboboxEmpty>\n                )}\n                <ComboboxList className=\"p-1\">\n                  {(id: string) => {\n                    const model = localMetaById.get(id);\n                    const source =\n                      model?.source === \"hf_cache\" ? \"HF cache\" : \"Local dir\";\n                    return (\n                      <ComboboxItem key={id} value={id} className=\"gap-2\">\n                        <Tooltip>\n                          <TooltipTrigger asChild={true}>\n                            <span className=\"block min-w-0 flex-1 truncate\">\n                              {model?.display_name ?? id}\n                            </span>\n                          </TooltipTrigger>\n                          <TooltipContent side=\"left\" className=\"max-w-xs break-all\">\n                            {model?.path ?? id}\n                          </TooltipContent>\n                        </Tooltip>\n                        <span className=\"ml-auto shrink-0 text-[10px] text-muted-foreground\">\n                          {source}\n                        </span>\n                      </ComboboxItem>\n                    );\n                  }}\n                </ComboboxList>\n              </ComboboxContent>\n            </Combobox>\n          </div>\n          {isLoadingLocalModels ? (\n            <p className=\"text-[10px] text-muted-foreground\">Scanning local models...</p>\n          ) : localModelsError ? (\n            <p className=\"text-[10px] text-red-500\">{localModelsError}</p>\n          ) : (\n            <p className=\"text-[10px] text-muted-foreground\">\n              {trainableLocalModels.length > 0\n                ? `${trainableLocalModels.length} local/cached models found`\n                : \"No local models found. 
Enter path manually.\"}\n            </p>\n          )}\n        </div>\n\n          <div data-tour=\"studio-base-model\" className=\"flex flex-col gap-2\">\n          <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n            Hugging Face Model\n            <Tooltip>\n              <TooltipTrigger asChild={true}>\n                <button\n                  type=\"button\"\n                  className=\"text-foreground/70 hover:text-foreground\"\n                >\n                  <HugeiconsIcon\n                    icon={InformationCircleIcon}\n                    className=\"size-3\"\n                  />\n                </button>\n              </TooltipTrigger>\n              <TooltipContent>\n                Search Hugging Face models or pick from our recommended list.{\" \"}\n                <a\n                  href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/what-model-should-i-use\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"text-primary underline\"\n                >\n                  Read more\n                </a>\n              </TooltipContent>\n            </Tooltip>\n          </span>\n          <div\n            ref={comboboxAnchorRef}\n            onKeyDown={(event) => {\n              if (event.key !== \"Enter\") return;\n              if (!(event.target instanceof HTMLInputElement)) return;\n              event.preventDefault();\n              if (hfResults.length > 0) {\n                handleModelSelect(hfResults[0].id);\n              } else {\n                const text = event.target.value.trim();\n                if (text) handleModelSelect(text);\n              }\n            }}\n          >\n            <Combobox\n              items={resultIds}\n              filteredItems={resultIds}\n              filter={null}\n              value={selectedModel}\n              onValueChange={handleModelSelect}\n              onInputValueChange={handleInputChange}\n              itemToStringValue={(id) => id}\n              autoHighlight={true}\n            >\n              <ComboboxInput placeholder=\"Search models...\" className=\"w-full\">\n                <InputGroupAddon>\n                  <HugeiconsIcon icon={Search01Icon} className=\"size-4\" />\n                </InputGroupAddon>\n              </ComboboxInput>\n              <ComboboxContent anchor={comboboxAnchorRef}>\n                {isLoading ? (\n                  <div className=\"flex items-center justify-center py-4 gap-2 text-xs text-muted-foreground\">\n                    <Spinner className=\"size-4\" /> Searching…\n                  </div>\n                ) : (\n                  <ComboboxEmpty>No models found</ComboboxEmpty>\n                )}\n                <div\n                  ref={scrollRef}\n                  className=\"max-h-64 overflow-y-auto overscroll-contain [scrollbar-width:thin]\"\n                >\n                  <ComboboxList className=\"p-1 !max-h-none !overflow-visible\">\n                    {(id: string) => {\n                      const entry = vramMap.get(id);\n                      const detail = entry?.detail ?? null;\n                      const fitStatus = entry?.status ?? null;\n                      const vramEst = entry?.est ?? 
null;\n                      const exceeds = fitStatus === \"exceeds\";\n\n                      return (\n                        <ComboboxItem\n                          key={id}\n                          value={id}\n                          className={`gap-2 ${exceeds ? \"opacity-50\" : \"\"}`}\n                        >\n                          <Tooltip>\n                            <TooltipTrigger asChild={true}>\n                              <span className={`block min-w-0 flex-1 truncate ${exceeds ? \"line-through decoration-muted-foreground/50\" : \"\"}`}>\n                                {id}\n                              </span>\n                            </TooltipTrigger>\n                            <TooltipContent\n                              side=\"left\"\n                              className=\"max-w-xs break-all\"\n                            >\n                              {id}\n                              {vramEst != null && vramEst > 0 && gpu.available && (\n                                <span className=\"block text-[10px] mt-1\">\n                                  {exceeds\n                                    ? `Needs ~${vramEst}GB VRAM (GPU: ${gpu.memoryTotalGb}GB)`\n                                    : fitStatus === \"tight\"\n                                      ? `~${vramEst}GB VRAM (tight fit on ${gpu.memoryTotalGb}GB)`\n                                      : `~${vramEst}GB VRAM`}\n                                </span>\n                              )}\n                            </TooltipContent>\n                          </Tooltip>\n                          <span className=\"ml-auto flex items-center gap-1.5 shrink-0\">\n                            {fitStatus === \"exceeds\" && (\n                              <span className=\"text-[9px] font-medium text-red-400\">\n                                OOM\n                              </span>\n                            )}\n                            {fitStatus === \"tight\" && (\n                              <span className=\"text-[9px] font-medium text-amber-400\">\n                                TIGHT\n                              </span>\n                            )}\n                            {detail && (\n                              <span className=\"text-[10px] text-muted-foreground\">\n                                {detail}\n                              </span>\n                            )}\n                          </span>\n                        </ComboboxItem>\n                      );\n                    }}\n                  </ComboboxList>\n                  <div ref={sentinelRef} className=\"h-px\" />\n                  {isLoadingMore && (\n                    <div className=\"flex items-center justify-center py-2\">\n                      <Spinner className=\"size-3.5 text-muted-foreground\" />\n                    </div>\n                  )}\n                </div>\n              </ComboboxContent>\n            </Combobox>\n          </div>\n        </div>\n\n          <div data-tour=\"studio-method\" className=\"flex flex-col gap-2\">\n          <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n            Method\n            <Tooltip>\n              <TooltipTrigger asChild={true}>\n                <button\n                  type=\"button\"\n                  className=\"text-foreground/70 hover:text-foreground\"\n                >\n                  <HugeiconsIcon\n                    
icon={InformationCircleIcon}\n                    className=\"size-3\"\n                  />\n                </button>\n              </TooltipTrigger>\n              <TooltipContent className=\"max-w-xs\">\n                QLoRA uses 4-bit quantization for lowest VRAM. LoRA uses 16-bit.\n                Full updates all weights.{\" \"}\n                <a\n                  href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                  target=\"_blank\"\n                  rel=\"noopener noreferrer\"\n                  className=\"text-primary underline\"\n                >\n                  Read more\n                </a>\n              </TooltipContent>\n            </Tooltip>\n          </span>\n          <Select\n            value={trainingMethod}\n            onValueChange={(v) => setTrainingMethod(v as TrainingMethod)}\n          >\n            <SelectTrigger className={DARK_TRIGGER}>\n              <SelectValue />\n            </SelectTrigger>\n            <SelectContent\n              position=\"popper\"\n              className={`${DARK_CONTENT} w-[var(--radix-select-trigger-width)]`}\n            >\n              <SelectItem value=\"qlora\">\n                <span className=\"flex items-center gap-2\">\n                  <span\n                    className={`size-2 shrink-0 rounded-full ${METHOD_DOTS.qlora}`}\n                  />\n                  QLoRA (4-bit)\n                </span>\n              </SelectItem>\n              <SelectItem value=\"lora\">\n                <span className=\"flex items-center gap-2\">\n                  <span\n                    className={`size-2 shrink-0 rounded-full ${METHOD_DOTS.lora}`}\n                  />\n                  LoRA (16-bit)\n                </span>\n              </SelectItem>\n              <SelectItem value=\"full\">\n                <span className=\"flex items-center gap-2\">\n                  <span\n                    className={`size-2 shrink-0 rounded-full ${METHOD_DOTS.full}`}\n                  />\n                  Full Fine-tune\n                </span>\n              </SelectItem>\n            </SelectContent>\n          </Select>\n        </div>\n\n        <div className=\"flex flex-col gap-2\">\n          <span className=\"text-xs font-medium text-muted-foreground\">\n            Hugging Face Token (Optional)\n          </span>\n          <InputGroup>\n            <InputGroupAddon>\n              <HugeiconsIcon icon={Key01Icon} className=\"size-4\" />\n            </InputGroupAddon>\n            <InputGroupInput\n              type=\"password\"\n              autoComplete=\"new-password\"\n              name=\"hf-token\"\n              placeholder=\"hf_...\"\n              value={hfToken}\n              onChange={(e) => setHfToken(e.target.value)}\n            />\n          </InputGroup>\n          {(tokenValidationError ?? hfSearchError) && (\n            <p className=\"text-xs text-destructive\">\n              {tokenValidationError ?? 
hfSearchError}\n              {\" — \"}\n              <a\n                href=\"https://huggingface.co/settings/tokens\"\n                target=\"_blank\"\n                rel=\"noopener noreferrer\"\n                className=\"underline\"\n              >\n                Get or update token\n              </a>\n            </p>\n          )}\n          {isCheckingToken && (\n            <p className=\"text-xs text-muted-foreground\">Checking token…</p>\n          )}\n        </div>\n        </div>\n      </SectionCard>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/params-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport { Checkbox } from \"@/components/ui/checkbox\";\nimport {\n  Collapsible,\n  CollapsibleContent,\n  CollapsibleTrigger,\n} from \"@/components/ui/collapsible\";\nimport { Input } from \"@/components/ui/input\";\nimport {\n  Combobox,\n  ComboboxContent,\n  ComboboxEmpty,\n  ComboboxInput,\n  ComboboxItem,\n  ComboboxList,\n} from \"@/components/ui/combobox\";\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Slider } from \"@/components/ui/slider\";\nimport { Tabs, TabsContent, TabsList, TabsTrigger } from \"@/components/ui/tabs\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport {\n  CONTEXT_LENGTHS,\n  LR_SCHEDULER_OPTIONS,\n  OPTIMIZER_OPTIONS,\n  TARGET_MODULES,\n} from \"@/config/training\";\nimport { useMaxStepsEpochsToggle, useTrainingConfigStore } from \"@/features/training\";\nimport type { GradientCheckpointing } from \"@/types/training\";\nimport {\n  ArrowDown01Icon,\n  InformationCircleIcon,\n  Settings04Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, type ReactNode, useEffect, useRef, useState } from \"react\";\n\nfunction Row({\n  label,\n  tooltip,\n  children,\n}: { label: string; tooltip?: ReactNode; children: ReactNode }): ReactElement {\n  return (\n    <div className=\"flex items-center justify-between\">\n      <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n        {label}\n        {tooltip && (\n          <Tooltip>\n            <TooltipTrigger asChild={true}>\n              <button\n                type=\"button\"\n                className=\"text-foreground/70 hover:text-foreground\"\n              >\n                <HugeiconsIcon\n                  icon={InformationCircleIcon}\n                  className=\"size-3\"\n                />\n              </button>\n            </TooltipTrigger>\n            <TooltipContent>{tooltip}</TooltipContent>\n          </Tooltip>\n        )}\n      </span>\n      {children}\n    </div>\n  );\n}\n\nfunction SliderRow({\n  label,\n  tooltip,\n  value,\n  onChange,\n  min,\n  max,\n  step,\n  format,\n}: {\n  label: string;\n  tooltip?: ReactNode;\n  value: number;\n  onChange: (v: number) => void;\n  min: number;\n  max: number;\n  step: number;\n  format?: (v: number) => string;\n}): ReactElement {\n  return (\n    <Row label={label} tooltip={tooltip}>\n      <div className=\"flex items-center gap-3\">\n        <Slider\n          value={[value]}\n          onValueChange={([v]) => onChange(v)}\n          min={min}\n          max={max}\n          step={step}\n          className=\"w-32\"\n        />\n        <input\n          type=\"number\"\n          value={format ? 
format(value) : value}\n          onChange={(e) => onChange(Number(e.target.value))}\n          min={min}\n          max={max}\n          step={step}\n          className=\"w-12 text-right font-mono text-xs font-medium bg-muted/50 border border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n        />\n      </div>\n    </Row>\n  );\n}\n\nexport function ParamsSection(): ReactElement {\n  const store = useTrainingConfigStore();\n  const isLora = store.trainingMethod !== \"full\";\n  const showVisionLora = store.isVisionModel && store.isDatasetImage === true;\n  const [loraOpen, setLoraOpen] = useState(false);\n  const [hyperOpen, setHyperOpen] = useState(false);\n  const [ctxInput, setCtxInput] = useState(String(store.contextLength));\n  const ctxAnchorRef = useRef<HTMLDivElement>(null);\n  const ctxItems = CONTEXT_LENGTHS.map(String);\n\n  // Keep input in sync when the store value changes externally\n  // (e.g. model defaults being applied after model selection).\n  useEffect(() => {\n    setCtxInput(String(store.contextLength));\n  }, [store.contextLength]);\n\n  const trySetContextLength = (input: string): number | null => {\n    const n = Number(input);\n    if (Number.isInteger(n) && n > 0) {\n      store.setContextLength(n);\n      return n;\n    }\n    return null;\n  };\n\n  const { useEpochs, toggleUseEpochs } = useMaxStepsEpochsToggle({\n    maxSteps: store.maxSteps,\n    epochs: store.epochs,\n    saveSteps: store.saveSteps,\n    setMaxSteps: store.setMaxSteps,\n    setEpochs: store.setEpochs,\n    setSaveSteps: store.setSaveSteps,\n  });\n\n  const maxStepsSliderMax = Math.max(500, store.maxSteps, 30);\n  const epochsSliderMax = Math.max(20, store.epochs, 1);\n\n  return (\n    <div data-tour=\"studio-params\" className=\"col-span-1 xl:col-span-4\">\n      <SectionCard\n        icon={<HugeiconsIcon icon={Settings04Icon} className=\"size-5\" />}\n        title=\"Parameters\"\n        description=\"Configure training hyperparameters\"\n        accent=\"orange\"\n        className=\"md:min-h-[470px]\"\n      >\n        <div className=\"flex flex-col gap-4\">\n          {/* Max Steps / Epochs */}\n          <div className=\"flex flex-col gap-2\">\n            <div\n              key={useEpochs ? \"epochs\" : \"steps\"}\n              className=\"flex flex-col gap-2 animate-in fade-in-0 slide-in-from-bottom-1 duration-200\"\n            >\n              <div className=\"flex items-center justify-between\">\n                <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n                  {useEpochs ? \"Epochs\" : \"Max Steps\"}\n                  <Tooltip>\n                    <TooltipTrigger asChild={true}>\n                      <button\n                        type=\"button\"\n                        className=\"text-foreground/70 hover:text-foreground\"\n                      >\n                        <HugeiconsIcon\n                          icon={InformationCircleIcon}\n                          className=\"size-3\"\n                        />\n                      </button>\n                    </TooltipTrigger>\n                    <TooltipContent>\n                      {useEpochs\n                        ? 
\"Number of full passes over the dataset.\"\n                        : \"Override total optimizer steps.\"}{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </TooltipContent>\n                  </Tooltip>\n                </span>\n                <div className=\"flex items-center gap-3\">\n                  <button\n                    type=\"button\"\n                    onClick={toggleUseEpochs}\n                    className=\"text-xs text-primary underline cursor-pointer\"\n                  >\n                    {useEpochs ? \"Use Max Steps\" : \"Use Epochs\"}\n                  </button>\n                  <input\n                    type=\"number\"\n                    value={useEpochs ? store.epochs : store.maxSteps}\n                    onChange={(e) => {\n                      const raw = e.target.value;\n                      if (raw === \"\") return;\n\n                      const value = Number(raw);\n                      if (!Number.isFinite(value) || value < 1) return;\n\n                      if (useEpochs) {\n                        store.setEpochs(value);\n                      } else {\n                        store.setMaxSteps(value);\n                      }\n                    }}\n                    min={1}\n                    max={useEpochs ? epochsSliderMax : maxStepsSliderMax}\n                    step={1}\n                    className=\"w-16 text-right font-mono text-xs font-medium bg-muted/50 border border-border rounded-lg px-1.5 py-0.5 focus:outline-none focus:ring-1 focus:ring-primary/30 [&::-webkit-inner-spin-button]:appearance-none\"\n                  />\n                </div>\n              </div>\n              <Slider\n                value={[\n                  useEpochs\n                    ? Math.min(epochsSliderMax, Math.max(1, store.epochs))\n                    : Math.min(maxStepsSliderMax, Math.max(1, store.maxSteps)),\n                ]}\n                onValueChange={([v]) =>\n                  useEpochs ? store.setEpochs(v) : store.setMaxSteps(v)\n                }\n                min={1}\n                max={useEpochs ? epochsSliderMax : maxStepsSliderMax}\n                step={1}\n              />\n              <p className=\"text-[10px] text-muted-foreground\">\n                {useEpochs\n                  ? 
\"Each epoch is one full pass over your dataset.\"\n                  : \"Limits training to a fixed number of optimizer steps.\"}\n              </p>\n            </div>\n          </div>\n\n          {/* Context length */}\n          <div className=\"flex flex-col gap-2\">\n            <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n              Context Length\n              <Tooltip>\n                <TooltipTrigger asChild={true}>\n                  <button\n                    type=\"button\"\n                    className=\"text-foreground/70 hover:text-foreground\"\n                  >\n                    <HugeiconsIcon\n                      icon={InformationCircleIcon}\n                      className=\"size-3\"\n                    />\n                  </button>\n                </TooltipTrigger>\n                <TooltipContent>\n                  Maximum number of tokens per training sample.{\" \"}\n                  <a\n                    href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                    target=\"_blank\"\n                    rel=\"noopener noreferrer\"\n                    className=\"text-primary underline\"\n                  >\n                    Read more\n                  </a>\n                </TooltipContent>\n              </Tooltip>\n            </span>\n            <div ref={ctxAnchorRef}>\n              <Combobox\n                items={ctxItems}\n                filteredItems={ctxItems}\n                filter={null}\n                value={String(store.contextLength)}\n                onValueChange={(v) => {\n                  if (v && trySetContextLength(v)) {\n                    setCtxInput(v);\n                  }\n                }}\n                onInputValueChange={setCtxInput}\n                itemToStringValue={(id) => Number(id).toLocaleString()}\n                autoHighlight={false}\n              >\n                <ComboboxInput\n                  placeholder={String(store.contextLength)}\n                  className=\"w-full font-mono\"\n                  onBlur={() => {\n                    trySetContextLength(ctxInput);\n                    setCtxInput(String(store.contextLength));\n                  }}\n                  onKeyDown={(e) => {\n                    if (e.key !== \"Enter\") { return; }\n                    const n = trySetContextLength(ctxInput);\n                    if (n === null) { return; }\n                    if (!ctxItems.includes(ctxInput.trim())) {\n                      e.stopPropagation();\n                      e.preventDefault();\n                    }\n                    setCtxInput(String(n));\n                  }}\n                />\n                <ComboboxContent anchor={ctxAnchorRef}>\n                  <ComboboxEmpty>Enter a custom value</ComboboxEmpty>\n                  <ComboboxList className=\"p-1\">\n                    {(id: string) => (\n                      <ComboboxItem key={id} value={id} className=\"font-mono\">\n                        {Number(id).toLocaleString()}\n                      </ComboboxItem>\n                    )}\n                  </ComboboxList>\n                </ComboboxContent>\n              </Combobox>\n            </div>\n            <p className=\"text-[10px] text-muted-foreground\">\n              Max sequence length for training samples\n            </p>\n          </div>\n\n          {/* Learning Rate */}\n          <div className=\"flex flex-col gap-2\">\n   
         <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n              Learning Rate\n              <Tooltip>\n                <TooltipTrigger asChild={true}>\n                  <button\n                    type=\"button\"\n                    className=\"text-foreground/70 hover:text-foreground\"\n                  >\n                    <HugeiconsIcon\n                      icon={InformationCircleIcon}\n                      className=\"size-3\"\n                    />\n                  </button>\n                </TooltipTrigger>\n                <TooltipContent>\n                  Step size for weight updates. Lower values train slower but more\n                  stably.{\" \"}\n                  <a\n                    href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                    target=\"_blank\"\n                    rel=\"noopener noreferrer\"\n                    className=\"text-primary underline\"\n                  >\n                    Read more\n                  </a>\n                </TooltipContent>\n              </Tooltip>\n            </span>\n            <Input\n              type=\"number\"\n              step=\"0.00001\"\n              value={store.learningRate}\n              onChange={(e) => store.setLearningRate(Number(e.target.value))}\n              className=\"w-full font-mono\"\n            />\n            <p className=\"text-[10px] text-muted-foreground\">\n              Recommended: 2e-4 for LoRA, 2e-5 for full fine-tune\n            </p>\n          </div>\n\n          {/* LoRA Settings */}\n          {isLora && (\n            <div>\n              <button\n                type=\"button\"\n                onClick={() => setLoraOpen(!loraOpen)}\n                className=\"flex w-full cursor-pointer items-center gap-1.5 text-xs text-muted-foreground\"\n              >\n                <HugeiconsIcon\n                  icon={ArrowDown01Icon}\n                  className={`size-3.5 transition-transform ${loraOpen ? \"rotate-180\" : \"\"}`}\n                />\n                LoRA Settings\n              </button>\n              <div\n                className={`${loraOpen ? \"\" : \"hidden\"} pt-1.5 mt-4 flex flex-col gap-4`}\n              >\n                <SliderRow\n                  label=\"Rank\"\n                  tooltip={\n                    <>\n                      Dimension of the low-rank matrices. Higher = more capacity.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </>\n                  }\n                  value={store.loraRank}\n                  onChange={store.setLoraRank}\n                  min={4}\n                  max={128}\n                  step={4}\n                />\n                <SliderRow\n                  label=\"Alpha\"\n                  tooltip={\n                    <>\n                      Scaling factor for LoRA updates. 
Usually 2x rank.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </>\n                  }\n                  value={store.loraAlpha}\n                  onChange={store.setLoraAlpha}\n                  min={4}\n                  max={256}\n                  step={4}\n                />\n                <SliderRow\n                  label=\"Dropout\"\n                  tooltip={\n                    <>\n                      Dropout probability for LoRA layers to reduce overfitting.{\" \"}\n                      <a\n                        href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                        target=\"_blank\"\n                        rel=\"noopener noreferrer\"\n                        className=\"text-primary underline\"\n                      >\n                        Read more\n                      </a>\n                    </>\n                  }\n                  value={store.loraDropout}\n                  onChange={store.setLoraDropout}\n                  min={0}\n                  max={0.5}\n                  step={0.01}\n                  format={(v) => v.toFixed(2)}\n                />\n\n                {/* Vision checkboxes */}\n                {showVisionLora && (\n                  <div className=\"flex flex-col gap-2 pt-1\">\n                    {(\n                      [\n                        [\n                          \"finetuneVisionLayers\",\n                          \"Vision layers\",\n                          store.finetuneVisionLayers,\n                          store.setFinetuneVisionLayers,\n                        ],\n                        [\n                          \"finetuneLanguageLayers\",\n                          \"Language layers\",\n                          store.finetuneLanguageLayers,\n                          store.setFinetuneLanguageLayers,\n                        ],\n                        [\n                          \"finetuneAttentionModules\",\n                          \"Attention modules\",\n                          store.finetuneAttentionModules,\n                          store.setFinetuneAttentionModules,\n                        ],\n                        [\n                          \"finetuneMLPModules\",\n                          \"MLP modules\",\n                          store.finetuneMLPModules,\n                          store.setFinetuneMLPModules,\n                        ],\n                      ] as const\n                    ).map(([key, label, value, setter]) => (\n                      <div key={key} className=\"flex items-center gap-2\">\n                        <Checkbox\n                          id={key}\n                          checked={value as boolean}\n                          onCheckedChange={(v) =>\n                            (setter as (v: boolean) => void)(!!v)\n                          }\n                        />\n                        <label\n                          htmlFor={key}\n                          className=\"text-xs cursor-pointer text-muted-foreground\"\n                        >\n                          {label}\n                        
</label>\n                      </div>\n                    ))}\n                  </div>\n                )}\n\n                {/* Text target modules */}\n                {!showVisionLora && (\n                  <div className=\"flex flex-col gap-2 pt-1\">\n                    <span className=\"text-xs font-medium text-muted-foreground\">\n                      Target Modules\n                    </span>\n                    <div className=\"flex flex-wrap gap-1.5\">\n                      {TARGET_MODULES.map((mod) => {\n                        const active = store.targetModules.includes(mod);\n                        return (\n                          <button\n                            key={mod}\n                            type=\"button\"\n                            onClick={() => {\n                              store.setTargetModules(\n                                active\n                                  ? store.targetModules.filter((m) => m !== mod)\n                                  : [...store.targetModules, mod],\n                              );\n                            }}\n                            className={`cursor-pointer rounded-full border px-2.5 py-0.5 text-[11px] font-mono transition-colors ${active\n                                ? \"border-orange-300 bg-orange-50 text-orange-700 dark:border-orange-700 dark:bg-orange-950 dark:text-orange-300\"\n                                : \"text-muted-foreground hover:bg-muted/50\"\n                              }`}\n                          >\n                            {mod}\n                          </button>\n                        );\n                      })}\n                    </div>\n                  </div>\n                )}\n\n                {/* LoRA variant */}\n                <div className=\"flex gap-2\">\n                  {(\n                    [\n                      {\n                        value: \"lora\",\n                        label: \"Enable LoRA\",\n                        desc: \"Train with LoRA\",\n                      },\n                      { value: \"rslora\", label: \"RS-LoRA\", desc: \"Stable Rank\" },\n                      {\n                        value: \"loftq\",\n                        label: \"LoftQ\",\n                        desc: \"Memory Efficient\",\n                      },\n                    ] as const\n                  ).map((opt) => (\n                    <button\n                      key={opt.value}\n                      type=\"button\"\n                      onClick={() => store.setLoraVariant(opt.value)}\n                      className={`flex-1 corner-squircle rounded-xl border px-3 py-2 text-left transition-colors cursor-pointer ${store.loraVariant === opt.value\n                          ? 
\"border-primary/50 bg-primary/5 ring-1 ring-primary/20\"\n                          : \"border-border hover:border-foreground/20\"\n                        }`}\n                    >\n                      <p className=\"text-xs font-medium\">{opt.label}</p>\n                      <p className=\"text-[10px] text-muted-foreground\">\n                        {opt.desc}\n                      </p>\n                    </button>\n                  ))}\n                </div>\n              </div>\n            </div>\n          )}\n\n          {/* Training Hyperparams */}\n          <Collapsible open={hyperOpen} onOpenChange={setHyperOpen}>\n            <CollapsibleTrigger className=\"flex w-full cursor-pointer items-center gap-1.5 text-xs text-muted-foreground\">\n              <HugeiconsIcon\n                icon={ArrowDown01Icon}\n                className={`size-3.5 transition-transform ${hyperOpen ? \"rotate-180\" : \"\"}`}\n              />\n              Training Hyperparameters\n            </CollapsibleTrigger>\n            <CollapsibleContent className=\"mt-3 data-[state=open]:overflow-visible\">\n              <Tabs defaultValue=\"optimization\" className=\"w-full\">\n                <TabsList className=\"w-full\">\n                  <TabsTrigger\n                    value=\"optimization\"\n                    className=\"flex-1 !corner-squircle text-xs cursor-pointer\"\n                  >\n                    Optimization\n                  </TabsTrigger>\n                  <TabsTrigger\n                    value=\"schedule\"\n                    className=\"flex-1 text-xs cursor-pointer\"\n                  >\n                    Schedule\n                  </TabsTrigger>\n                  <TabsTrigger\n                    value=\"memory\"\n                    className=\"flex-1 text-xs cursor-pointer\"\n                  >\n                    Memory\n                  </TabsTrigger>\n                </TabsList>\n\n                <TabsContent\n                  value=\"optimization\"\n                  className=\"mt-3 flex flex-col gap-3\"\n                >\n                  <Row\n                    label=\"Optimizer\"\n                    tooltip={\n                      <>\n                        Optimization algorithm. 
8-bit variants reduce memory usage.\n                        Fused is recommended for vision models.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                  >\n                    <Select\n                      value={store.optimizerType}\n                      onValueChange={(v) => store.setOptimizerType(v)}\n                    >\n                      <SelectTrigger className=\"w-48\">\n                        <SelectValue />\n                      </SelectTrigger>\n                      <SelectContent>\n                        {OPTIMIZER_OPTIONS.map((opt) => (\n                          <SelectItem\n                            key={opt.value}\n                            value={opt.value}\n                          >\n                            {opt.label}\n                          </SelectItem>\n                        ))}\n                      </SelectContent>\n                    </Select>\n                  </Row>\n                  <Row\n                    label=\"LR scheduler\"\n                    tooltip={\n                      <>\n                        How the learning rate changes over training. Linear decays\n                        steadily; cosine decays in a curve.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                  >\n                    <Select\n                      value={store.lrSchedulerType}\n                      onValueChange={(v) => store.setLrSchedulerType(v)}\n                    >\n                      <SelectTrigger className=\"w-48\">\n                        <SelectValue />\n                      </SelectTrigger>\n                      <SelectContent>\n                        {LR_SCHEDULER_OPTIONS.map((opt) => (\n                          <SelectItem\n                            key={opt.value}\n                            value={opt.value}\n                          >\n                            {opt.label}\n                          </SelectItem>\n                        ))}\n                      </SelectContent>\n                    </Select>\n                  </Row>\n                  <SliderRow\n                    label=\"Batch Size\"\n                    tooltip={\n                      <>\n                        Samples processed per step. 
Higher uses more VRAM.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                    value={store.batchSize}\n                    onChange={store.setBatchSize}\n                    min={1}\n                    max={32}\n                    step={1}\n                  />\n                  <SliderRow\n                    label=\"Grad Accum\"\n                    tooltip={\n                      <>\n                        Simulates larger batch sizes without extra VRAM.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                    value={store.gradientAccumulation}\n                    onChange={store.setGradientAccumulation}\n                    min={1}\n                    max={64}\n                    step={1}\n                  />\n                  <Row\n                    label=\"Weight Decay\"\n                    tooltip={\n                      <>\n                        L2 regularization to prevent overfitting.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                  >\n                    <Input\n                      type=\"number\"\n                      step=\"0.001\"\n                      value={store.weightDecay}\n                      onChange={(e) =>\n                        store.setWeightDecay(Number(e.target.value))\n                      }\n                      className=\"w-28 font-mono\"\n                    />\n                  </Row>\n                </TabsContent>\n\n                <TabsContent\n                  value=\"schedule\"\n                  className=\"mt-3 flex flex-col gap-3\"\n                >\n                  <SliderRow\n                    label=\"Warmup Steps\"\n                    tooltip={\n                      <>\n                        Gradually increase LR at training start for stability.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                    
value={store.warmupSteps}\n                    onChange={store.setWarmupSteps}\n                    min={0}\n                    max={100}\n                    step={1}\n                  />\n                  {!useEpochs && (\n                    <SliderRow\n                      label=\"Epochs\"\n                      tooltip={\n                        <>\n                          Number of full passes over the dataset. Set 0 to run by\n                          max steps.{\" \"}\n                          <a\n                            href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                            target=\"_blank\"\n                            rel=\"noopener noreferrer\"\n                            className=\"text-primary underline\"\n                          >\n                            Read more\n                          </a>\n                        </>\n                      }\n                      value={store.epochs}\n                      onChange={store.setEpochs}\n                      min={0}\n                      max={epochsSliderMax}\n                      step={1}\n                    />\n                  )}\n                  <Row\n                    label=\"Save Steps\"\n                    tooltip={\n                      <>\n                        Save a checkpoint every N steps. 0 to disable.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                  >\n                    <Input\n                      type=\"number\"\n                      value={store.saveSteps}\n                      onChange={(e) => store.setSaveSteps(Number(e.target.value))}\n                      className=\"w-28 font-mono\"\n                    />\n                  </Row>\n                  <Row\n                    label=\"Eval Steps\"\n                    tooltip=\"Fraction of total training steps between evaluations (0-1). Set to 0 to disable evaluation. E.g. 
0.01 = evaluate every 1% of steps.\"\n                  >\n                    <Input\n                      type=\"number\"\n                      step=\"0.01\"\n                      min=\"0.0\"\n                      max=\"1.0\"\n                      value={store.evalSteps}\n                      onChange={(e) => store.setEvalSteps(Number(e.target.value))}\n                      className=\"w-28 font-mono\"\n                    />\n                  </Row>\n                  <Row label=\"Seed\" tooltip=\"Random seed for reproducibility.\">\n                    <Input\n                      type=\"number\"\n                      value={store.randomSeed}\n                      onChange={(e) =>\n                        store.setRandomSeed(Number(e.target.value))\n                      }\n                      className=\"w-28 font-mono\"\n                    />\n                  </Row>\n                </TabsContent>\n\n                <TabsContent value=\"memory\" className=\"mt-3 flex flex-col gap-3\">\n                  <Row\n                    label=\"Grad Checkpoint\"\n                    tooltip={\n                      <>\n                        Trade compute for memory by recomputing activations.{\" \"}\n                        <a\n                          href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide\"\n                          target=\"_blank\"\n                          rel=\"noopener noreferrer\"\n                          className=\"text-primary underline\"\n                        >\n                          Read more\n                        </a>\n                      </>\n                    }\n                  >\n                    <Select\n                      value={store.gradientCheckpointing}\n                      onValueChange={(v) =>\n                        store.setGradientCheckpointing(v as GradientCheckpointing)\n                      }\n                    >\n                      <SelectTrigger className=\"w-32\">\n                        <SelectValue />\n                      </SelectTrigger>\n                      <SelectContent>\n                        <SelectItem value=\"none\">None</SelectItem>\n                        <SelectItem value=\"true\">Standard</SelectItem>\n                        <SelectItem value=\"unsloth\">Unsloth</SelectItem>\n                      </SelectContent>\n                    </Select>\n                  </Row>\n                  {!showVisionLora && !store.isEmbeddingModel && (\n                    <div className=\"flex items-center gap-2\">\n                      <Checkbox\n                        id=\"packing\"\n                        checked={store.packing}\n                        onCheckedChange={(v) => store.setPacking(!!v)}\n                      />\n                      <label\n                        htmlFor=\"packing\"\n                        className=\"text-xs cursor-pointer text-muted-foreground\"\n                      >\n                        Enable packing\n                      </label>\n                    </div>\n                  )}\n                  {!store.isEmbeddingModel && (\n                    <div className=\"flex items-center gap-2\">\n                      <Checkbox\n                        id=\"trainOnCompletions\"\n                        checked={store.trainOnCompletions}\n                        onCheckedChange={(v) => store.setTrainOnCompletions(!!v)}\n                      />\n                      <label\n                        
htmlFor=\"trainOnCompletions\"\n                        className=\"text-xs cursor-pointer text-muted-foreground\"\n                      >\n                        Assistant completions only\n                      </label>\n                    </div>\n                  )}\n                </TabsContent>\n              </Tabs>\n            </CollapsibleContent>\n          </Collapsible>\n        </div>\n      </SectionCard>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/progress-section-lib.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TrainingPhase } from \"@/features/training\";\n\nexport const phaseLabel: Record<TrainingPhase, string> = {\n  idle: \"Idle\",\n  downloading_model: \"Downloading model\",\n  downloading_dataset: \"Downloading dataset\",\n  loading_model: \"Loading model\",\n  loading_dataset: \"Loading dataset\",\n  configuring: \"Configuring\",\n  training: \"Training\",\n  completed: \"Completed\",\n  error: \"Error\",\n  stopped: \"Stopped\",\n};\n\nexport const phaseColors: Record<TrainingPhase, string> = {\n  idle: \"bg-muted text-muted-foreground\",\n  downloading_model:\n    \"bg-sky-100 text-sky-700 dark:bg-sky-900 dark:text-sky-300\",\n  downloading_dataset:\n    \"bg-sky-100 text-sky-700 dark:bg-sky-900 dark:text-sky-300\",\n  loading_model:\n    \"bg-amber-100 text-amber-700 dark:bg-amber-900 dark:text-amber-300\",\n  loading_dataset:\n    \"bg-amber-100 text-amber-700 dark:bg-amber-900 dark:text-amber-300\",\n  configuring: \"bg-blue-100 text-blue-700 dark:bg-blue-900 dark:text-blue-300\",\n  training:\n    \"bg-emerald-100 text-emerald-700 dark:bg-emerald-900 dark:text-emerald-300\",\n  completed:\n    \"bg-emerald-100 text-emerald-700 dark:bg-emerald-900 dark:text-emerald-300\",\n  error: \"bg-red-100 text-red-700 dark:bg-red-900 dark:text-red-300\",\n  stopped: \"bg-muted text-muted-foreground\",\n};\n\nexport function formatDuration(seconds: number | null): string {\n  if (seconds == null || seconds < 0) return \"--\";\n  const total = Math.floor(seconds);\n  const min = Math.floor(total / 60);\n  const sec = total % 60;\n  return `${min}m ${sec}s`;\n}\n\nexport function formatNumber(value: number | null | undefined, digits: number): string {\n  if (value == null || !Number.isFinite(value)) return \"--\";\n  return value.toFixed(digits);\n}\n\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/progress-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport {\n  AlertDialog,\n  AlertDialogAction,\n  AlertDialogCancel,\n  AlertDialogContent,\n  AlertDialogDescription,\n  AlertDialogFooter,\n  AlertDialogHeader,\n  AlertDialogTitle,\n} from \"@/components/ui/alert-dialog\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  Popover,\n  PopoverContent,\n  PopoverTrigger,\n} from \"@/components/ui/popover\";\nimport { Progress } from \"@/components/ui/progress\";\nimport { OPTIMIZER_OPTIONS } from \"@/config/training\";\nimport { setTrainingCompareHandoff } from \"@/features/chat\";\nimport {\n  useTrainingActions,\n  useTrainingConfigStore,\n  useTrainingRuntimeStore,\n} from \"@/features/training\";\nimport { useGpuUtilization } from \"@/hooks\";\nimport { cn } from \"@/lib/utils\";\nimport {\n  ChartAverageIcon,\n  DashboardSpeed01Icon,\n  Notebook01Icon,\n  RamMemoryIcon,\n  StopIcon,\n  TemperatureIcon,\n  ZapIcon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { Link, useNavigate } from \"@tanstack/react-router\";\nimport { type ReactElement, type ReactNode, useEffect, useState } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { ChartSettingsSheet } from \"./charts/chart-settings-sheet\";\nimport {\n  formatDuration,\n  formatNumber,\n  phaseColors,\n  phaseLabel,\n} from \"./progress-section-lib\";\n\ntype ConfigGroup = {\n  section: string;\n  rows: [string, string | number | null | undefined][];\n};\n\nfunction configRow(\n  label: string,\n  value: string | number | null | undefined,\n): [string, string | number | null | undefined] {\n  return [label, value];\n}\n\nexport function ProgressSection(): ReactElement {\n  const navigate = useNavigate();\n  const runtime = useTrainingRuntimeStore(\n    useShallow((state) => ({\n      phase: state.phase,\n      message: state.message,\n      error: state.error,\n      currentStep: state.currentStep,\n      totalSteps: state.totalSteps,\n      currentEpoch: state.currentEpoch,\n      currentLoss: state.currentLoss,\n      currentLearningRate: state.currentLearningRate,\n      currentGradNorm: state.currentGradNorm,\n      progressPercent: state.progressPercent,\n      elapsedSeconds: state.elapsedSeconds,\n      etaSeconds: state.etaSeconds,\n      currentNumTokens: state.currentNumTokens,\n      isTrainingRunning: state.isTrainingRunning,\n      lossHistory: state.lossHistory,\n      lrHistory: state.lrHistory,\n      gradNormHistory: state.gradNormHistory,\n    })),\n  );\n\n  const config = useTrainingConfigStore(\n    useShallow((state) => ({\n      selectedModel: state.selectedModel,\n      trainingMethod: state.trainingMethod,\n      epochs: state.epochs,\n      batchSize: state.batchSize,\n      learningRate: state.learningRate,\n      maxSteps: state.maxSteps,\n      contextLength: state.contextLength,\n      warmupSteps: state.warmupSteps,\n      optimizerType: state.optimizerType,\n      loraRank: state.loraRank,\n      loraAlpha: state.loraAlpha,\n      loraDropout: state.loraDropout,\n      loraVariant: state.loraVariant,\n    })),\n  );\n\n  const { stopTrainingRun } = useTrainingActions();\n  const gpu = useGpuUtilization(runtime.isTrainingRunning);\n  const [stopDialogOpen, setStopDialogOpen] = useState(false);\n  const [stopRequested, setStopRequested] = 
useState(false);\n\n  useEffect(() => {\n    if (!runtime.isTrainingRunning) {\n      setStopRequested(false);\n    }\n  }, [runtime.isTrainingRunning]);\n\n  const pct =\n    runtime.totalSteps > 0\n      ? Math.min(\n          100,\n          Math.max(\n            0,\n            Math.round((runtime.currentStep / runtime.totalSteps) * 100),\n          ),\n        )\n      : Math.round(runtime.progressPercent);\n\n  const elapsed = runtime.elapsedSeconds;\n  const derivedEta =\n    elapsed != null && pct > 0\n      ? Math.round((elapsed * (100 - pct)) / Math.max(pct, 1))\n      : null;\n  const eta = runtime.etaSeconds ?? derivedEta;\n\n  const stepsPerSecond =\n    elapsed != null && elapsed > 0 ? runtime.currentStep / elapsed : null;\n  const showHalfwayHint =\n    runtime.phase === \"training\" && pct >= 50 && pct < 100;\n  const showCompletedHint = runtime.phase === \"completed\";\n  const handleCompareInChat = async () => {\n    setTrainingCompareHandoff(config.selectedModel);\n    await navigate({ to: \"/chat\" });\n  };\n  const requestStop = async (saveCheckpoint: boolean) => {\n    setStopRequested(true);\n    setStopDialogOpen(false);\n    useTrainingRuntimeStore.getState().setStopRequested(true);\n    try {\n      const ok = await stopTrainingRun(saveCheckpoint);\n      if (!ok) {\n        setStopRequested(false);\n      }\n    } catch {\n      setStopRequested(false);\n    }\n  };\n\n  const stoppedLoss = getDisplayMetric(\n    runtime.isTrainingRunning,\n    runtime.currentLoss,\n    runtime.lossHistory,\n  );\n  const stoppedLr = getDisplayMetric(\n    runtime.isTrainingRunning,\n    runtime.currentLearningRate,\n    runtime.lrHistory,\n  );\n  const stoppedGradNorm = runtime.isTrainingRunning\n    ? runtime.currentGradNorm\n    : (lastNonZeroValue(runtime.gradNormHistory) ?? runtime.currentGradNorm);\n\n  const optimizerLabel =\n    OPTIMIZER_OPTIONS.find((o) => o.value === config.optimizerType)?.label ??\n    config.optimizerType;\n\n  const configItems: ConfigGroup[] = [\n    {\n      section: \"Hyperparams\",\n      rows: [\n        configRow(\"Epochs\", config.epochs),\n        configRow(\"Batch size\", config.batchSize),\n        configRow(\"Learning rate\", config.learningRate),\n        configRow(\"Optimizer\", optimizerLabel),\n        configRow(\"Max steps\", config.maxSteps),\n        configRow(\"Context length\", config.contextLength),\n        configRow(\"Warmup steps\", config.warmupSteps),\n      ],\n    },\n    ...(config.trainingMethod !== \"full\"\n      ? 
[\n          {\n            section: \"LoRA\",\n            rows: [\n              configRow(\"Rank\", config.loraRank),\n              configRow(\"Alpha\", config.loraAlpha),\n              configRow(\"Dropout\", config.loraDropout),\n              configRow(\"Variant\", config.loraVariant),\n            ],\n          },\n        ]\n      : []),\n  ];\n\n  return (\n    <SectionCard\n      icon={<HugeiconsIcon icon={ChartAverageIcon} className=\"size-5\" />}\n      title=\"Training Progress\"\n      description={runtime.message || \"Live training metrics\"}\n      accent=\"emerald\"\n      className=\"shadow-border border border-border/60 bg-card/90 ring-0 backdrop-blur-sm\"\n      headerAction={\n        <TrainingHeaderActions\n          configItems={configItems}\n          isTrainingRunning={runtime.isTrainingRunning}\n          onOpenStopDialog={setStopDialogOpen}\n          onRequestStop={requestStop}\n          stopDialogOpen={stopDialogOpen}\n          stopRequested={stopRequested}\n        />\n      }\n    >\n      <div className=\"grid grid-cols-1 gap-5 lg:grid-cols-[minmax(0,1.2fr)_minmax(18rem,0.8fr)]\">\n        <div className=\"flex flex-col gap-4\">\n          <div className=\"flex flex-wrap items-center gap-2\">\n            <span\n              className={`rounded-full px-2.5 py-1 text-[10px] font-semibold ${phaseColors[runtime.phase]}`}\n            >\n              {phaseLabel[runtime.phase]}\n            </span>\n            <span className=\"text-[10px] tabular-nums text-muted-foreground\">\n              Epoch {runtime.currentEpoch.toFixed(2)}\n            </span>\n            <span className=\"rounded-full border border-border/60 px-2.5 py-1 text-[10px] font-medium tabular-nums text-muted-foreground\">\n              {pct}% complete\n            </span>\n          </div>\n\n          <div className=\"flex flex-col gap-2\">\n            <div className=\"flex justify-between text-xs text-muted-foreground\">\n              <span>\n                Step {runtime.currentStep} / {runtime.totalSteps || \"--\"}\n              </span>\n              <span>{pct}%</span>\n            </div>\n            <Progress value={pct} className=\"h-2 bg-foreground/[0.05]\" />\n          </div>\n\n          <MilestoneCallout\n            showCompletedHint={showCompletedHint}\n            showHalfwayHint={showHalfwayHint}\n            onCompareInChat={handleCompareInChat}\n          />\n\n          {runtime.error && (\n            <p className=\"rounded-2xl border border-destructive/30 bg-destructive/5 px-3 py-2 text-xs text-red-500 leading-relaxed\">\n              {runtime.error}\n            </p>\n          )}\n\n          <div className=\"grid gap-x-4 gap-y-3 pt-1 sm:grid-cols-2 xl:grid-cols-5\">\n            <MetricStat\n              label=\"Loss\"\n              valueClassName=\"text-2xl font-bold tracking-tight\"\n            >\n              {stoppedLoss.toFixed(4)}\n            </MetricStat>\n            <MetricStat label=\"LR\">{stoppedLr.toExponential(2)}</MetricStat>\n            <MetricStat label=\"Grad Norm\">\n              {formatNumber(stoppedGradNorm, 3)}\n            </MetricStat>\n            <MetricStat label=\"Model\" valueClassName=\"truncate\">\n              {config.selectedModel ?? \"--\"}\n            </MetricStat>\n            <MetricStat label=\"Method\">\n              {config.trainingMethod === \"qlora\" ? \"QLoRA\" : config.trainingMethod === \"lora\" ? 
\"LoRA\" : \"Full\"}\n            </MetricStat>\n          </div>\n\n          <div className=\"flex flex-wrap gap-x-4 gap-y-1 text-xs text-muted-foreground\">\n            <span>Elapsed: {formatDuration(elapsed)}</span>\n            <span>ETA: {formatDuration(eta)}</span>\n            <span>\n              {stepsPerSecond == null\n                ? \"-- steps/s\"\n                : `${stepsPerSecond.toFixed(2)} steps/s`}\n            </span>\n            {runtime.currentNumTokens != null && (\n              <span>Tokens: {runtime.currentNumTokens}</span>\n            )}\n          </div>\n        </div>\n\n        <div className=\"flex flex-col gap-3\">\n          <div className=\"flex items-center justify-between\">\n            <p className=\"text-xs font-medium text-muted-foreground\">\n              GPU Monitor\n            </p>\n            <span className=\"text-[11px] text-muted-foreground\">Live</span>\n          </div>\n          <div className=\"grid grid-cols-2 gap-2.5\">\n            <GpuStat\n              label=\"Utilization\"\n              icon={\n                <HugeiconsIcon\n                  icon={DashboardSpeed01Icon}\n                  className=\"size-3.5\"\n                />\n              }\n              value={\n                gpu.gpu_utilization_pct != null\n                  ? `${gpu.gpu_utilization_pct}%`\n                  : \"--\"\n              }\n              pct={gpu.gpu_utilization_pct ?? 0}\n            />\n            <GpuStat\n              label=\"Temperature\"\n              icon={\n                <HugeiconsIcon icon={TemperatureIcon} className=\"size-3.5\" />\n              }\n              value={\n                gpu.temperature_c != null ? `${gpu.temperature_c}°C` : \"--\"\n              }\n              pct={gpu.temperature_c ?? 0}\n              max={100}\n            />\n            <GpuStat\n              label=\"VRAM\"\n              icon={<HugeiconsIcon icon={RamMemoryIcon} className=\"size-3.5\" />}\n              value={\n                gpu.vram_used_gb != null && gpu.vram_total_gb != null\n                  ? `${gpu.vram_used_gb} / ${gpu.vram_total_gb} GB`\n                  : \"--\"\n              }\n              pct={gpu.vram_utilization_pct ?? 0}\n            />\n            <GpuStat\n              label=\"Power\"\n              icon={<HugeiconsIcon icon={ZapIcon} className=\"size-3.5\" />}\n              value={\n                gpu.power_draw_w != null\n                  ? gpu.power_limit_w != null\n                    ? `${gpu.power_draw_w} / ${gpu.power_limit_w} W`\n                    : `${gpu.power_draw_w} W`\n                  : \"--\"\n              }\n              pct={gpu.power_utilization_pct ?? 
0}\n            />\n          </div>\n        </div>\n      </div>\n    </SectionCard>\n  );\n}\n\nfunction TrainingHeaderActions({\n  configItems,\n  isTrainingRunning,\n  onOpenStopDialog,\n  onRequestStop,\n  stopDialogOpen,\n  stopRequested,\n}: {\n  configItems: ConfigGroup[];\n  isTrainingRunning: boolean;\n  onOpenStopDialog: (open: boolean) => void;\n  onRequestStop: (saveCheckpoint: boolean) => Promise<void>;\n  stopDialogOpen: boolean;\n  stopRequested: boolean;\n}): ReactElement {\n  return (\n    <div className=\"flex items-center gap-2\">\n      <Popover>\n        <PopoverTrigger asChild={true}>\n          <Button\n            type=\"button\"\n            variant=\"ghost\"\n            size=\"icon-sm\"\n            className=\"rounded-full text-muted-foreground hover:bg-muted hover:text-foreground\"\n            aria-label=\"Open training config\"\n          >\n            <HugeiconsIcon icon={Notebook01Icon} className=\"size-4\" />\n          </Button>\n        </PopoverTrigger>\n        <PopoverContent className=\"w-72\" align=\"end\">\n          <div className=\"flex flex-col gap-3\">\n            <p className=\"text-xs font-semibold\">Training Config</p>\n            {configItems.map((group) => (\n              <div key={group.section} className=\"flex flex-col gap-1\">\n                <p className=\"text-[10px] font-semibold uppercase tracking-wider text-muted-foreground\">\n                  {group.section}\n                </p>\n                {group.rows.map(([label, value]) => (\n                  <div key={label} className=\"flex justify-between text-xs\">\n                    <span className=\"text-muted-foreground\">{label}</span>\n                    <span className=\"font-medium tabular-nums\">\n                      {String(value)}\n                    </span>\n                  </div>\n                ))}\n              </div>\n            ))}\n          </div>\n        </PopoverContent>\n      </Popover>\n      <ChartSettingsSheet />\n      <AlertDialog open={stopDialogOpen} onOpenChange={onOpenStopDialog}>\n        <Button\n          data-tour=\"studio-training-stop\"\n          variant=\"destructive\"\n          size=\"sm\"\n          className={cn(\n            \"h-8 rounded-full px-3.5 text-xs shadow-sm\",\n            stopRequested ? \"cursor-not-allowed opacity-60\" : \"cursor-pointer\",\n          )}\n          onClick={() => onOpenStopDialog(true)}\n          disabled={!isTrainingRunning || stopRequested}\n        >\n          <HugeiconsIcon icon={StopIcon} className=\"size-3\" />\n          {stopRequested ? 
\"Stopping…\" : \"Stop\"}\n        </Button>\n        <AlertDialogContent overlayClassName=\"bg-background/40 supports-backdrop-filter:backdrop-blur-[1px]\">\n          <AlertDialogHeader>\n            <AlertDialogTitle>Stop Training</AlertDialogTitle>\n            <AlertDialogDescription>\n              Choose how you want to stop the current training run.\n            </AlertDialogDescription>\n          </AlertDialogHeader>\n          <AlertDialogFooter>\n            <AlertDialogCancel>Continue Training</AlertDialogCancel>\n            <AlertDialogAction\n              variant=\"destructive\"\n              onClick={() => onRequestStop(false)}\n            >\n              Cancel Training\n            </AlertDialogAction>\n            <AlertDialogAction onClick={() => onRequestStop(true)}>\n              Stop and Save\n            </AlertDialogAction>\n          </AlertDialogFooter>\n        </AlertDialogContent>\n      </AlertDialog>\n    </div>\n  );\n}\n\nfunction MilestoneCallout({\n  showCompletedHint,\n  showHalfwayHint,\n  onCompareInChat,\n}: {\n  showCompletedHint: boolean;\n  showHalfwayHint: boolean;\n  onCompareInChat: () => Promise<void>;\n}): ReactElement | null {\n  if (!(showHalfwayHint || showCompletedHint)) {\n    return null;\n  }\n\n  return (\n    <div className=\"corner-squircle rounded-2xl border border-border/60 bg-muted/30 px-3 py-2.5\">\n      <div className=\"flex items-start justify-between gap-3\">\n        <div className=\"min-w-0\">\n          {!showCompletedHint && (\n            <p className=\"text-[10px] font-medium uppercase tracking-[0.12em] text-muted-foreground\">\n              Milestone\n            </p>\n          )}\n          <p\n            className={cn(\n              \"text-xs text-foreground/85\",\n              !showCompletedHint && \"mt-1\",\n            )}\n          >\n            {showCompletedHint\n              ? \"Training done. Next step: compare base vs fine-tuned outputs.\"\n              : \"Halfway done. Training is past 50%.\"}\n          </p>\n        </div>\n        {!showCompletedHint && (\n          <span className=\"rounded-full border border-border/60 bg-background/80 px-2 py-0.5 text-[10px] font-medium text-muted-foreground\">\n            50%+\n          </span>\n        )}\n      </div>\n      {showCompletedHint && (\n        <div className=\"mt-2 flex flex-wrap gap-2\">\n          <Button size=\"xs\" onClick={onCompareInChat}>\n            Compare in Chat\n          </Button>\n          <Button asChild={true} size=\"xs\" variant=\"outline\">\n            <Link to=\"/export\">Export Model</Link>\n          </Button>\n        </div>\n      )}\n    </div>\n  );\n}\n\nfunction MetricStat({\n  label,\n  children,\n  valueClassName,\n}: {\n  label: string;\n  children: ReactNode;\n  valueClassName?: string;\n}): ReactElement {\n  return (\n    <div className=\"min-w-0\">\n      <p className=\"text-[11px] text-muted-foreground\">{label}</p>\n      <p\n        className={`mt-1 text-base font-semibold tabular-nums ${valueClassName ?? 
\"\"}`}\n      >\n        {children}\n      </p>\n    </div>\n  );\n}\n\nfunction lastNonZeroValue(points: { value: number }[]): number | null {\n  for (let i = points.length - 1; i >= 0; i -= 1) {\n    const value = points[i]?.value;\n    if (Number.isFinite(value) && value !== 0) {\n      return value;\n    }\n  }\n  return null;\n}\n\nfunction getDisplayMetric(\n  isTrainingRunning: boolean,\n  currentValue: number,\n  history: { value: number }[],\n): number {\n  if (isTrainingRunning) {\n    return currentValue;\n  }\n  return lastNonZeroValue(history) ?? currentValue;\n}\n\nfunction GpuStat({\n  label,\n  icon,\n  value,\n  pct,\n  max,\n}: {\n  label: string;\n  icon: ReactNode;\n  value: string;\n  pct: number;\n  max?: number;\n}): ReactElement {\n  const clamped = Math.max(0, Math.min(pct, max ?? 100));\n  let barColor = \"bg-red-500\";\n  if (clamped < 60) {\n    barColor = \"bg-emerald-500\";\n  } else if (clamped < 95) {\n    barColor = \"bg-amber-500\";\n  }\n\n  return (\n    <div className=\"corner-squircle flex flex-col gap-2 rounded-2xl border border-border/50 bg-background/60 p-3\">\n      <div className=\"flex items-center justify-between text-xs\">\n        <span className=\"flex items-center gap-1.5 text-muted-foreground\">\n          {icon}\n          {label}\n        </span>\n        <span className=\"font-medium tabular-nums\">{value}</span>\n      </div>\n      <div className=\"h-2 w-full overflow-hidden rounded-full bg-muted/80\">\n        <div\n          className={`h-full rounded-full ${barColor} transition-all duration-300`}\n          style={{ width: `${clamped}%` }}\n        />\n      </div>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/sections/training-section.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { SectionCard } from \"@/components/section-card\";\nimport { Button } from \"@/components/ui/button\";\nimport { ChartContainer } from \"@/components/ui/chart\";\nimport type { ChartConfig } from \"@/components/ui/chart\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport {\n  parseYamlConfig,\n  serializeConfigToYaml,\n  useTrainingActions,\n  useTrainingConfigStore,\n  validateTrainingConfig,\n} from \"@/features/training\";\nimport {\n  Archive04Icon,\n  ChartAverageIcon,\n  CleanIcon,\n  CloudUploadIcon,\n  Rocket01Icon,\n} from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useRef } from \"react\";\nimport { toast } from \"sonner\";\nimport { CartesianGrid, Line, LineChart, XAxis, YAxis } from \"recharts\";\n\nconst chartConfig = {\n  loss: { label: \"Loss\", color: \"#3b82f6\" },\n} satisfies ChartConfig;\n\nconst placeholderData = [\n  { step: 0, loss: 2.5 },\n  { step: 10, loss: 2.1 },\n  { step: 20, loss: 1.7 },\n  { step: 30, loss: 1.3 },\n  { step: 40, loss: 1.0 },\n  { step: 50, loss: 0.8 },\n];\n\nexport function TrainingSection() {\n  const store = useTrainingConfigStore();\n  const { isStarting, startError, startTrainingRun } = useTrainingActions();\n  const isIncompatible =\n    (!store.isVisionModel && store.isDatasetImage === true) ||\n    (!store.isAudioModel && store.isDatasetAudio === true);\n  const configValidation = validateTrainingConfig(store);\n  const fileInputRef = useRef<HTMLInputElement>(null);\n\n  const handleFileUpload = (e: React.ChangeEvent<HTMLInputElement>) => {\n    const file = e.target.files?.[0];\n    if (!file) return;\n    e.target.value = \"\";\n\n    const reader = new FileReader();\n    reader.onload = () => {\n      try {\n        const config = parseYamlConfig(reader.result as string);\n        store.applyConfigPatch(config);\n        toast.success(\"Config loaded\", { description: file.name });\n      } catch (err) {\n        toast.error(\"Failed to load config\", {\n          description:\n            err instanceof Error ? err.message : \"Invalid YAML file\",\n        });\n      }\n    };\n    reader.onerror = () => {\n      toast.error(\"Failed to read file\");\n    };\n    reader.readAsText(file);\n  };\n\n  const handleSaveConfig = () => {\n    const yamlStr = serializeConfigToYaml(store, store.isVisionModel);\n    const blob = new Blob([yamlStr], { type: \"text/yaml\" });\n    const url = URL.createObjectURL(blob);\n    const a = document.createElement(\"a\");\n    a.href = url;\n\n    const model = (store.selectedModel ?? \"model\").split(\"/\").pop();\n    const method = store.trainingMethod ?? \"qlora\";\n    const dataset = (store.dataset ?? 
\"dataset\").split(\"/\").pop();\n    const timestamp = new Date().toISOString().replace(/[:T]/g, \"-\").slice(0, 19);\n    a.download = `${model}_${method}_${dataset}_${timestamp}.yaml`;\n\n    a.click();\n    URL.revokeObjectURL(url);\n  };\n\n  const handleResetConfig = () => {\n    store.resetToModelDefaults();\n    toast.success(\"Parameters reset to model defaults\");\n  };\n\n  return (\n    <div data-tour=\"studio-training\" className=\"col-span-1 xl:col-span-4\">\n      <SectionCard\n        icon={<HugeiconsIcon icon={ChartAverageIcon} className=\"size-5\" />}\n        title=\"Training\"\n        description=\"Monitor and control training\"\n        accent=\"blue\"\n        className=\"md:min-h-[470px]\"\n      >\n        <div className=\"flex flex-col gap-4\">\n        {/* Loss chart */}\n        <div className=\"relative  \">\n          <ChartContainer\n            config={chartConfig}\n            className=\"h-[180px] w-full relative right-8 blur\"\n          >\n            <LineChart data={placeholderData} accessibilityLayer={true}>\n              <CartesianGrid vertical={false} strokeDasharray=\"3 3\" />\n              <XAxis\n                dataKey=\"step\"\n                tickLine={false}\n                axisLine={false}\n                tickMargin={8}\n                fontSize={10}\n              />\n              <YAxis\n                tickLine={false}\n                axisLine={false}\n                tickMargin={8}\n                fontSize={10}\n              />\n              <Line\n                type=\"monotone\"\n                dataKey=\"loss\"\n                stroke=\"var(--color-loss)\"\n                strokeWidth={2}\n                dot={false}\n              />\n            </LineChart>\n          </ChartContainer>\n          <div className=\"absolute inset-0 flex flex-col items-center justify-center gap-1\">\n            <HugeiconsIcon\n              icon={ChartAverageIcon}\n              className=\"size-5 text-muted-foreground/50\"\n            />\n            <p className=\"text-sm font-medium text-muted-foreground\">\n              No training data yet\n            </p>\n            <p className=\"text-xs text-muted-foreground/60\">\n              Start training to see loss progress\n            </p>\n          </div>\n        </div>\n\n        {/* Start/Stop */}\n        <Button\n          data-tour=\"studio-start\"\n          className=\"w-full cursor-pointer bg-gradient-to-r from-emerald-500 to-teal-500 text-white hover:from-emerald-600 hover:to-teal-600\"\n          onClick={() => void startTrainingRun()}\n          disabled={isStarting || isIncompatible || store.isCheckingDataset || !configValidation.ok}\n        >\n          <HugeiconsIcon icon={Rocket01Icon} className=\"size-4\" />\n          {isStarting ? \"Starting...\" : store.isCheckingDataset ? \"Checking dataset...\" : \"Start Training\"}\n        </Button>\n        {startError && (\n          <p className=\"text-xs text-red-500 leading-relaxed\">{startError}</p>\n        )}\n        {isIncompatible && (\n          <p className=\"text-xs text-red-500 leading-relaxed\">\n            Text model is not compatible with a multimodal dataset. 
Switch to a vision model or choose a text-only dataset.\n          </p>\n        )}\n        {!configValidation.ok && configValidation.message && !isIncompatible && (\n          <p className=\"text-xs text-red-500 leading-relaxed\">{configValidation.message}</p>\n        )}\n\n        {/* Upload / Save / Reset */}\n        <p className=\"text-xs text-muted-foreground\">Training Config</p>\n        <div className=\"grid grid-cols-3 gap-2\">\n          <Tooltip>\n            <TooltipTrigger asChild>\n              <Button\n                variant=\"outline\"\n                size=\"sm\"\n                className=\"cursor-pointer\"\n                onClick={() => fileInputRef.current?.click()}\n              >\n                <HugeiconsIcon icon={CloudUploadIcon} className=\"size-3.5\" />\n                Upload\n              </Button>\n            </TooltipTrigger>\n            <TooltipContent>Load a saved YAML config</TooltipContent>\n          </Tooltip>\n          <Tooltip>\n            <TooltipTrigger asChild>\n              <Button\n                data-tour=\"studio-save\"\n                variant=\"outline\"\n                size=\"sm\"\n                className=\"cursor-pointer\"\n                onClick={handleSaveConfig}\n              >\n                <HugeiconsIcon icon={Archive04Icon} className=\"size-3.5\" />\n                Save\n              </Button>\n            </TooltipTrigger>\n            <TooltipContent>Download current config as YAML</TooltipContent>\n          </Tooltip>\n          <Tooltip>\n            <TooltipTrigger asChild>\n              <Button\n                variant=\"outline\"\n                size=\"sm\"\n                className=\"cursor-pointer\"\n                onClick={handleResetConfig}\n                disabled={!store.selectedModel}\n              >\n                <HugeiconsIcon icon={CleanIcon} className=\"size-3.5\" />\n                Reset\n              </Button>\n            </TooltipTrigger>\n            <TooltipContent>Reset to model defaults</TooltipContent>\n          </Tooltip>\n        </div>\n        <input\n          ref={fileInputRef}\n          type=\"file\"\n          accept=\".yaml,.yml\"\n          className=\"hidden\"\n          onChange={handleFileUpload}\n        />\n        </div>\n      </SectionCard>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/studio-page.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport {\n  shouldShowTrainingView,\n  useDatasetPreviewDialogStore,\n  useTrainingActions,\n  useTrainingConfigStore,\n  useTrainingRuntimeLifecycle,\n  useTrainingRuntimeStore,\n} from \"@/features/training\";\nimport { GuidedTour, useGuidedTourController } from \"@/features/tour\";\nimport { studioTourSteps, studioTrainingTourSteps } from \"./tour\";\nimport { ArrowLeft01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { type ReactElement, useEffect } from \"react\";\nimport { DatasetPreviewDialog } from \"./sections/dataset-preview-dialog\";\nimport { DatasetSection } from \"./sections/dataset-section\";\nimport { ModelSection } from \"./sections/model-section\";\nimport { ParamsSection } from \"./sections/params-section\";\nimport { TrainingSection } from \"./sections/training-section\";\nimport { TrainingView } from \"./training-view\";\n\nconst STUDIO_TOUR_KEY = \"tour:studio:v1\";\n\nexport function StudioPage(): ReactElement {\n  useTrainingRuntimeLifecycle();\n  const showTrainingView = useTrainingRuntimeStore(shouldShowTrainingView);\n  const isTrainingRunning = useTrainingRuntimeStore((state) => state.isTrainingRunning);\n  const runtimeMessage = useTrainingRuntimeStore((state) => state.message);\n  const runtimePhase = useTrainingRuntimeStore((state) => state.phase);\n  const isHydratingRuntime = useTrainingRuntimeStore((state) => state.isHydrating);\n  const hasHydratedRuntime = useTrainingRuntimeStore((state) => state.hasHydrated);\n  const { dismissTrainingRun } = useTrainingActions();\n\n  const config = useTrainingConfigStore();\n  const selectedModel = useTrainingConfigStore((s) => s.selectedModel);\n  const ensureModelDefaultsLoaded = useTrainingConfigStore(\n    (s) => s.ensureModelDefaultsLoaded,\n  );\n  const ensureDatasetChecked = useTrainingConfigStore(\n    (s) => s.ensureDatasetChecked,\n  );\n  const dialogOpen = useDatasetPreviewDialogStore((s) => s.open);\n  const dialogMode = useDatasetPreviewDialogStore((s) => s.mode);\n  const dialogInitial = useDatasetPreviewDialogStore((s) => s.initialData);\n  const closeDialog = useDatasetPreviewDialogStore((s) => s.close);\n\n  const stopRequested = useTrainingRuntimeStore((state) => state.stopRequested);\n  const canGoBack =\n    showTrainingView &&\n    !isHydratingRuntime &&\n    (stopRequested ||\n      (!isTrainingRunning &&\n        (runtimePhase === \"stopped\" ||\n          runtimePhase === \"error\" ||\n          runtimePhase === \"completed\" ||\n          runtimePhase === \"idle\")));\n  const tourEnabled = hasHydratedRuntime && !isHydratingRuntime;\n  const isConfigTour = !showTrainingView;\n  const tourSteps = showTrainingView ? studioTrainingTourSteps : studioTourSteps;\n  const tour = useGuidedTourController({\n    id: \"studio\",\n    steps: tourSteps,\n    enabled: tourEnabled,\n    autoKey: isConfigTour ? 
STUDIO_TOUR_KEY : undefined,\n    autoWhen: isConfigTour,\n  });\n\n  const setTourOpen = tour.setOpen;\n  useEffect(() => {\n    setTourOpen(false);\n  }, [showTrainingView, setTourOpen]);\n\n  useEffect(() => {\n    ensureModelDefaultsLoaded();\n    ensureDatasetChecked();\n  }, [selectedModel, ensureModelDefaultsLoaded, ensureDatasetChecked]);\n\n  return (\n    <div className=\"relative min-h-screen overflow-hidden bg-background\">\n      <main className=\"relative z-10 mx-auto max-w-7xl px-4 py-4 sm:px-6\">\n        <GuidedTour {...tour.tourProps} celebrate={isConfigTour} />\n\n        <DatasetPreviewDialog\n          open={dialogOpen}\n          onOpenChange={(open) => {\n            if (!open) closeDialog();\n          }}\n          datasetSource={config.datasetSource}\n          datasetName={\n            config.datasetSource === \"huggingface\" ? config.dataset : config.uploadedFile\n          }\n          hfToken={config.hfToken.trim() || null}\n          datasetSubset={config.datasetSubset}\n          datasetSplit={config.datasetSplit}\n          mode={dialogMode}\n          initialData={dialogInitial}\n          isVlm={config.isVisionModel && config.isDatasetImage === true}\n        />\n\n        {canGoBack && (\n          <Button\n            variant=\"ghost\"\n            size=\"sm\"\n            className=\"mb-2 cursor-pointer gap-1.5 text-muted-foreground\"\n            onClick={() => void dismissTrainingRun()}\n          >\n            <HugeiconsIcon icon={ArrowLeft01Icon} className=\"size-4\" />\n            Back to configuration\n          </Button>\n        )}\n\n        <div className=\"mb-6 flex flex-col gap-0.5 sm:mb-8\">\n          <h1 className=\"text-2xl font-semibold tracking-tight\">\n            Fine-tuning Studio\n          </h1>\n          <p className=\"text-sm text-muted-foreground\">\n            {showTrainingView\n              ? runtimeMessage || \"Training in progress\"\n              : \"Configure and start training\"}\n          </p>\n        </div>\n\n        {!hasHydratedRuntime && isHydratingRuntime ? (\n          <div className=\"rounded-xl border bg-card p-8 text-sm text-muted-foreground\">\n            Loading training runtime...\n          </div>\n        ) : showTrainingView ? (\n          <TrainingView />\n        ) : (\n          <div className=\"grid grid-cols-1 items-start gap-4 md:grid-cols-2 md:gap-6 xl:grid-cols-12\">\n            <ModelSection />\n            <DatasetSection />\n            <ParamsSection />\n            <TrainingSection />\n          </div>\n        )}\n      </main>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { studioTourSteps } from \"./steps\";\nexport { studioTrainingTourSteps } from \"./training\";\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/base-model.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioBaseModelStep: TourStep = {\n  id: \"base-model\",\n  target: \"studio-base-model\",\n  title: \"Hugging Face Model\",\n  body: (\n    <>\n      Paste <span className=\"font-mono\">org/model</span> or search. Pick a base\n      model close to your task (chat/instruct vs base). Smaller models iterate\n      faster; scale up once prompts + data look good.{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/what-model-should-i-use\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/dataset.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioDatasetStep: TourStep = {\n  id: \"dataset\",\n  target: \"studio-dataset\",\n  title: \"Dataset\",\n  body: (\n    <>\n      Search Hub or paste <span className=\"font-mono\">user/dataset</span>. Preview\n      a few rows: formatting matters more than size. We’ll try to auto-convert\n      your dataset into a supported training format. If we can’t infer it\n      cleanly, we’ll prompt you to map the fields manually. If outputs look off\n      in Chat later, dataset formatting/template is the first thing to check.{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/get-started/fine-tuning-llms-guide/datasets-guide\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/index.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\nimport { studioBaseModelStep } from \"./base-model\";\nimport { studioDatasetStep } from \"./dataset\";\nimport { studioLocalModelStep } from \"./local-model\";\nimport { studioMethodStep } from \"./method\";\nimport { studioNavStep } from \"./nav\";\nimport { studioParamsStep } from \"./params\";\nimport { studioSaveStep } from \"./save\";\nimport { studioStartStep } from \"./start\";\n\nexport const studioTourSteps: TourStep[] = [\n  studioNavStep,\n  studioLocalModelStep,\n  studioBaseModelStep,\n  studioMethodStep,\n  studioDatasetStep,\n  studioParamsStep,\n  studioStartStep,\n  studioSaveStep,\n];\n\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/local-model.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioLocalModelStep: TourStep = {\n  id: \"local-model\",\n  target: \"studio-local-model\",\n  title: \"Local model path\",\n  body: (\n    <>\n      Use this if you already downloaded weights locally (eg{\" \"}\n      <span className=\"font-mono\">./models/...</span>) to avoid re-downloading.\n      Folder should look like a Hugging Face model (config + tokenizer + weights).{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/basics/fine-tuning-llms-guide\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/method.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioMethodStep: TourStep = {\n  id: \"method\",\n  target: \"studio-method\",\n  title: \"Method: QLoRA vs LoRA vs Full\",\n  body: (\n    <>\n      LoRA: trains small adapter weights (fast, common default). QLoRA: LoRA on\n      4-bit base weights (much lower VRAM). Full: updates all weights (highest\n      cost, usually needs more data to be worth it).{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/basics/lora-hyperparameters-guide\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/nav.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioNavStep: TourStep = {\n  id: \"nav\",\n  target: \"navbar\",\n  title: \"Quick orientation\",\n  body: (\n    <>\n      Studio: pick base model, dataset, hyperparams, then start training. After\n      you start, you’ll see a Training view with live loss/metrics. Chat is for\n      testing base vs LoRA adapters. Export packages checkpoints for deployment.{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/get-started/fine-tuning-for-beginners\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/params.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { ReadMore, type TourStep } from \"@/features/tour\";\n\nexport const studioParamsStep: TourStep = {\n  id: \"params\",\n  target: \"studio-params\",\n  title: \"Dial hyperparams\",\n  body: (\n    <>\n      Start boring, then iterate. We usually recommend starting with 1-3 epochs\n      (higher can overfit fast). If you’re unsure, change 1 knob at a time, and\n      watch train vs eval loss.{\" \"}\n      <ReadMore href=\"https://unsloth.ai/docs/basics/lora-hyperparameters-guide\" />\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/save.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\n\nexport const studioSaveStep: TourStep = {\n  id: \"save\",\n  target: \"studio-save\",\n  title: \"Save config\",\n  body: (\n    <>\n      Save your training config as a YAML file. Re-running the same baseline\n      makes it obvious if a change helped (or if you just got lucky).\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/steps/start.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\n\nexport const studioStartStep: TourStep = {\n  id: \"start\",\n  target: \"studio-start\",\n  title: \"Start training\",\n  body: (\n    <>\n      Kick off training. If it errors immediately, check HF token / local paths\n      / dataset access first. Start with a small run to sanity-check loss + sample\n      outputs before burning hours.\n    </>\n  ),\n};\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/training/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { studioTrainingTourSteps } from \"./steps\";\n\n"
  },
  {
    "path": "studio/frontend/src/features/studio/tour/training/steps.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TourStep } from \"@/features/tour\";\n\nexport const studioTrainingTourSteps: TourStep[] = [\n  {\n    id: \"nav\",\n    target: \"navbar\",\n    title: \"Training view\",\n    body: (\n      <>\n        This view updates live as training runs. Watch loss, speed, and ETA, and\n        use Stop if you need to bail out or save.\n      </>\n    ),\n  },\n  {\n    id: \"progress\",\n    target: \"studio-training-progress\",\n    title: \"Progress + ETA\",\n    body: (\n      <>\n        Phase shows what we’re doing (loading model/dataset, configuring,\n        training). ETA is rough early on; it stabilizes after a few steps.\n      </>\n    ),\n  },\n  {\n    id: \"train-loss\",\n    target: \"studio-training-loss\",\n    title: \"Training loss\",\n    body: (\n      <>\n        Training loss should generally trend down. Absolute values vary by\n        dataset + tokenizer, so use it for direction more than “a magic number”.\n        If loss goes very low (eg below ~0.2), that can be a sign you’re\n        overfitting. If loss plateaus high, you likely need better data\n        formatting, more data, or different hyperparams.\n      </>\n    ),\n  },\n  {\n    id: \"eval-loss\",\n    target: \"studio-eval-loss\",\n    title: \"Eval loss (validation)\",\n    body: (\n      <>\n        Eval loss is your sanity check. If training loss keeps dropping but eval\n        loss goes up, you’re likely overfitting. To track it, set an eval dataset\n        and `eval_steps` (setting `eval_steps=1` can be very slow).\n      </>\n    ),\n  },\n  {\n    id: \"stop\",\n    target: \"studio-training-stop\",\n    title: \"Stop / save\",\n    body: (\n      <>\n        Stop training any time. “Stop and Save” keeps the checkpoint/adapters so\n        you can export or compare later.\n      </>\n    ),\n  },\n];\n"
  },
  {
    "path": "studio/frontend/src/features/studio/training-start-overlay.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  AlertDialog,\n  AlertDialogAction,\n  AlertDialogCancel,\n  AlertDialogContent,\n  AlertDialogDescription,\n  AlertDialogFooter,\n  AlertDialogHeader,\n  AlertDialogTitle,\n} from \"@/components/ui/alert-dialog\";\nimport { Button } from \"@/components/ui/button\";\nimport {\n  AnimatedSpan,\n  Terminal,\n  TypingAnimation,\n} from \"@/components/ui/terminal\";\nimport { useTrainingActions, useTrainingRuntimeStore } from \"@/features/training\";\nimport { Cancel01Icon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useEffect, useState, type ReactElement } from \"react\";\n\ntype TrainingStartOverlayProps = {\n  message: string\n  currentStep: number\n}\n\nexport function TrainingStartOverlay({\n  message,\n  currentStep,\n}: TrainingStartOverlayProps): ReactElement {\n  const { stopTrainingRun, dismissTrainingRun } = useTrainingActions();\n  const isStarting = useTrainingRuntimeStore((s) => s.isStarting);\n  const [cancelDialogOpen, setCancelDialogOpen] = useState(false);\n  const [cancelRequested, setCancelRequested] = useState(false);\n\n  useEffect(() => {\n    if (!isStarting) {\n      setCancelRequested(false);\n    }\n  }, [isStarting]);\n\n  return (\n    <div className=\"pointer-events-none absolute inset-0 z-30 flex items-center justify-center rounded-2xl bg-background/45 backdrop-blur-[1px]\">\n      <div className=\"pointer-events-auto relative flex w-[860px] max-w-[calc(100%-2rem)] flex-col items-center gap-4\">\n        <img\n          src=\"/unsloth-gem.png\"\n          alt=\"Unsloth mascot\"\n          className=\"size-24 object-contain\"\n        />\n        <div className=\"relative w-full\">\n          <AlertDialog open={cancelDialogOpen} onOpenChange={setCancelDialogOpen}>\n            <Button\n              variant=\"ghost\"\n              size=\"icon\"\n              className=\"absolute right-3 top-3 z-10 size-7 cursor-pointer rounded-full text-muted-foreground/60 hover:bg-destructive/10 hover:text-destructive\"\n              onClick={() => setCancelDialogOpen(true)}\n              disabled={cancelRequested}\n            >\n              <HugeiconsIcon icon={Cancel01Icon} className=\"size-3.5\" />\n            </Button>\n            <AlertDialogContent overlayClassName=\"bg-background/40 supports-backdrop-filter:backdrop-blur-[1px]\">\n              <AlertDialogHeader>\n                <AlertDialogTitle>Cancel Training</AlertDialogTitle>\n                <AlertDialogDescription>\n                  Do you want to cancel the current training run?\n                </AlertDialogDescription>\n              </AlertDialogHeader>\n              <AlertDialogFooter>\n                <AlertDialogCancel>Continue Training</AlertDialogCancel>\n                <AlertDialogAction\n                  variant=\"destructive\"\n                  onClick={() => {\n                    setCancelRequested(true);\n                    setCancelDialogOpen(false);\n                    useTrainingRuntimeStore.getState().setStopRequested(true);\n                    void stopTrainingRun(false).then((ok) => {\n                      if (ok) {\n                        void dismissTrainingRun();\n                      } else {\n                        setCancelRequested(false);\n                      }\n                    });\n                  }}\n                >\n         
         Cancel Training\n                </AlertDialogAction>\n              </AlertDialogFooter>\n            </AlertDialogContent>\n          </AlertDialog>\n          <Terminal\n            className=\"w-full min-h-[390px] rounded-2xl px-7 py-6 text-left\"\n            startOnView={false}\n          >\n          <TypingAnimation\n            duration={36}\n            className=\"bg-gradient-to-r from-emerald-300 via-lime-300 to-teal-300 bg-clip-text font-semibold text-transparent\"\n          >\n            {\"> unsloth training starts...\"}\n          </TypingAnimation>\n          <AnimatedSpan className=\"my-2\">\n            <pre className=\"whitespace-pre text-muted-foreground inline-block\">{`==((====))==\\n   \\\\\\\\   /|\\nO^O/ \\\\_/ \\\\\\n\\\\        /\\n \"-____-\"`}</pre>\n          </AnimatedSpan>\n          <TypingAnimation duration={44}>\n            {\"> Preparing model and dataset...\"}\n          </TypingAnimation>\n          <TypingAnimation duration={44}>\n            {\"> We are getting everything ready for your run...\"}\n          </TypingAnimation>\n          <AnimatedSpan className=\"mt-2 text-muted-foreground\">\n            {`> ${message || \"starting training...\"} | waiting for first step... (${currentStep})`}\n          </AnimatedSpan>\n          </Terminal>\n        </div>\n      </div>\n    </div>\n  )\n}\n"
  },
  {
    "path": "studio/frontend/src/features/studio/training-view.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { cn } from \"@/lib/utils\";\nimport { useTrainingRuntimeStore } from \"@/features/training\";\nimport type { ReactElement } from \"react\";\nimport { useShallow } from \"zustand/react/shallow\";\nimport { ChartsSection } from \"./sections/charts-section\";\nimport { ProgressSection } from \"./sections/progress-section\";\nimport { TrainingStartOverlay } from \"./training-start-overlay\";\n\nexport function TrainingView(): ReactElement {\n  const runtime = useTrainingRuntimeStore(\n    useShallow((state) => ({\n      phase: state.phase,\n      message: state.message,\n      currentStep: state.currentStep,\n      firstStepReceived: state.firstStepReceived,\n      isStarting: state.isStarting,\n    })),\n  );\n\n  const isPreparingPhase =\n    runtime.phase === \"downloading_model\" ||\n    runtime.phase === \"downloading_dataset\" ||\n    runtime.phase === \"loading_model\" ||\n    runtime.phase === \"loading_dataset\" ||\n    runtime.phase === \"configuring\";\n  const isWaitingForFirstStep =\n    runtime.phase === \"training\" && !runtime.firstStepReceived;\n  const showOverlay =\n    runtime.isStarting ||\n    isPreparingPhase ||\n    (isWaitingForFirstStep && runtime.currentStep <= 0);\n\n  return (\n    <div className={cn(\"relative\", showOverlay && \"min-h-[72vh]\")}>\n      <div\n        className={cn(\n          \"relative z-10 flex flex-col gap-6 transition-[filter]\",\n          showOverlay && \"blur\",\n        )}\n      >\n        <div data-tour=\"studio-training-progress\">\n          <ProgressSection />\n        </div>\n        <ChartsSection />\n      </div>\n      {showOverlay ? (\n        <TrainingStartOverlay\n          message={runtime.message}\n          currentStep={runtime.currentStep}\n        />\n      ) : null}\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/tour/components/guided-tour.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { Button } from \"@/components/ui/button\";\nimport { cn } from \"@/lib/utils\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { ArrowLeft01Icon, ArrowRight01Icon, Cancel01Icon, CheckmarkCircle01Icon } from \"@hugeicons/core-free-icons\";\nimport { Dialog as DialogPrimitive } from \"radix-ui\";\nimport { AnimatePresence, motion } from \"motion/react\";\nimport { useEffect, useId, useLayoutEffect, useMemo, useRef, useState } from \"react\";\nimport { cssEscape, toRect } from \"../lib/dom\";\nimport { fireConfettiFireworks } from \"../lib/confetti-fireworks\";\nimport { computeCardPos, padded, pickPlacement } from \"../lib/layout\";\nimport { SpotlightOverlay } from \"./spotlight-overlay\";\nimport type { Placement, Rect, TourStep } from \"../types\";\n\ntype GuidedTourProps = { open: boolean; onOpenChange: (open: boolean) => void; steps: TourStep[]; onSkip: () => void; onComplete: () => void; celebrate?: boolean }; // confetti on complete only\n\nexport function GuidedTour({\n  open,\n  onOpenChange,\n  steps,\n  onSkip,\n  onComplete,\n  celebrate = false,\n}: GuidedTourProps) {\n  const maskId = `${useId()}-tour-mask`;\n  const [idx, setIdx] = useState(0);\n  const [vw, setVw] = useState(0);\n  const [vh, setVh] = useState(0);\n  const [targetRect, setTargetRect] = useState<Rect | null>(null);\n  const [placement, setPlacement] = useState<Placement>(\"right\");\n  const [cardPos, setCardPos] = useState<{ left: number; top: number }>({\n    left: 12,\n    top: 12,\n  });\n  const cardRef = useRef<HTMLDivElement>(null);\n  const closeLockRef = useRef(false);\n  const rafRef = useRef<number | null>(null);\n  const lastRectRef = useRef<Rect | null>(null);\n  const activeStepRef = useRef<TourStep | null>(null);\n\n  const step = steps[idx] ?? null;\n  const total = steps.length;\n  const isLast = idx === total - 1;\n\n  const spotlightRect = useMemo(() => {\n    if (!targetRect || !vw || !vh) return null;\n    const pad = step?.target === \"navbar\" ? 
4 : 14;\n    return padded(targetRect, pad, vw, vh);\n  }, [step?.target, targetRect, vw, vh]);\n\n  useEffect(() => {\n    if (!open) return;\n    const prev = activeStepRef.current;\n    if (prev && prev.id !== step?.id) {\n      void prev.onExit?.();\n    }\n    activeStepRef.current = step;\n    if (step) {\n      void step.onEnter?.();\n    }\n  }, [open, step?.id]); // run before target lookup effect below\n\n  useEffect(() => {\n    if (open) return;\n    const prev = activeStepRef.current;\n    activeStepRef.current = null;\n    if (prev) {\n      void prev.onExit?.();\n    }\n  }, [open]);\n\n  useEffect(() => {\n    if (!open) return;\n    setIdx(0);\n    setTargetRect(null);\n    closeLockRef.current = false;\n    lastRectRef.current = null;\n  }, [open]);\n\n  useEffect(() => {\n    if (!open) return;\n    function onResize() {\n      setVw(window.innerWidth);\n      setVh(window.innerHeight);\n    }\n    onResize();\n    window.addEventListener(\"resize\", onResize);\n    return () => window.removeEventListener(\"resize\", onResize);\n  }, [open]);\n\n  useEffect(() => {\n    if (!open || !step) return;\n\n    const sel = `[data-tour=\"${cssEscape(step.target)}\"]`;\n    let el: HTMLElement | null = null;\n    let ro: ResizeObserver | null = null;\n    let retryTimer = 0;\n    let retries = 0;\n\n    let raf = 0;\n    let t = 0;\n\n    function findTarget(): HTMLElement | null {\n      const found = document.querySelector(sel);\n      if (!(found instanceof HTMLElement)) return null;\n      return found;\n    }\n\n    function isUsableTarget(candidate: HTMLElement): boolean {\n      const r = candidate.getBoundingClientRect();\n      return r.width >= 6 && r.height >= 6;\n    }\n\n    function rectChanged(a: Rect | null, b: Rect): boolean {\n      if (!a) return true;\n      return (\n        Math.abs(a.x - b.x) > 0.5 ||\n        Math.abs(a.y - b.y) > 0.5 ||\n        Math.abs(a.w - b.w) > 0.5 ||\n        Math.abs(a.h - b.h) > 0.5\n      );\n    }\n\n    function read(candidate: HTMLElement) {\n      const r = candidate.getBoundingClientRect();\n      const next = toRect(r);\n      const prev = lastRectRef.current;\n      if (rectChanged(prev, next)) {\n        lastRectRef.current = next;\n        setTargetRect(next);\n      }\n    }\n\n    function schedule() {\n      if (rafRef.current != null) return;\n      rafRef.current = window.requestAnimationFrame(() => {\n        rafRef.current = null;\n        if (el) read(el);\n      });\n    }\n\n    function attach(candidate: HTMLElement) {\n      el = candidate;\n\n      if (step.target !== \"navbar\") {\n        el.scrollIntoView({\n          block: \"center\",\n          inline: \"center\",\n          behavior: \"smooth\",\n        });\n      }\n\n      raf = window.requestAnimationFrame(() => read(el!));\n      t = window.setTimeout(schedule, 240);\n\n      ro = new ResizeObserver(() => schedule());\n      ro.observe(el);\n      window.addEventListener(\"scroll\", schedule, { capture: true, passive: true });\n      window.addEventListener(\"resize\", schedule, { passive: true });\n    }\n\n    function tryAttach(): boolean {\n      const candidate = findTarget();\n      if (!candidate) return false;\n      if (!isUsableTarget(candidate)) return false;\n      attach(candidate);\n      return true;\n    }\n\n    if (!tryAttach()) {\n      setTargetRect(null);\n      retryTimer = window.setInterval(() => {\n        retries += 1;\n        if (tryAttach() || retries > 40) {\n          window.clearInterval(retryTimer);\n        }\n  
    }, 50);\n    }\n\n    return () => {\n      window.cancelAnimationFrame(raf);\n      window.clearTimeout(t);\n      if (retryTimer) window.clearInterval(retryTimer);\n      ro?.disconnect();\n      window.removeEventListener(\"scroll\", schedule, true);\n      window.removeEventListener(\"resize\", schedule);\n      if (rafRef.current != null) {\n        window.cancelAnimationFrame(rafRef.current);\n        rafRef.current = null;\n      }\n    };\n  }, [open, step?.id]);\n\n  useLayoutEffect(() => {\n    if (!open || !spotlightRect || !vw || !vh) return;\n    const card = cardRef.current?.getBoundingClientRect();\n    if (!card) return;\n\n    const gap = 14;\n    const picked = pickPlacement(spotlightRect, { w: card.width, h: card.height }, vw, vh, gap);\n    setPlacement(picked);\n    setCardPos(\n      computeCardPos(\n        picked,\n        spotlightRect,\n        { w: card.width, h: card.height },\n        vw,\n        vh,\n        gap,\n      ),\n    );\n  }, [open, spotlightRect, vw, vh, idx]);\n\n  function requestClose(reason: \"skip\" | \"complete\") {\n    if (closeLockRef.current) return;\n    closeLockRef.current = true;\n    if (reason === \"skip\") {\n      onSkip();\n    } else {\n      if (celebrate) void fireConfettiFireworks();\n      onComplete();\n    }\n    onOpenChange(false);\n  }\n\n  return (\n    <DialogPrimitive.Root\n      open={open}\n      onOpenChange={(v) => {\n        if (v) onOpenChange(true);\n        else requestClose(\"skip\");\n      }}\n      modal={true}\n    >\n      <DialogPrimitive.Portal>\n        <AnimatePresence>\n          {open && (\n            <>\n              <DialogPrimitive.Overlay asChild>\n                <motion.div\n                  className=\"fixed inset-0 z-50\"\n                  initial={{ opacity: 0 }}\n                  animate={{ opacity: 1 }}\n                  exit={{ opacity: 0 }}\n                  transition={{ duration: 0.18 }}\n                >\n                  <SpotlightOverlay rect={spotlightRect} vw={vw} vh={vh} maskId={maskId} />\n                  {spotlightRect && (\n                    <motion.div\n                      className=\"fixed z-[51] pointer-events-none rounded-[22px] ring-1 ring-white/10\"\n                      initial={false}\n                      animate={{\n                        left: spotlightRect.x,\n                        top: spotlightRect.y,\n                        width: spotlightRect.w,\n                        height: spotlightRect.h,\n                        boxShadow:\n                          \"0 0 0 1px rgba(34, 211, 238, 0.12), 0 0 0 6px rgba(16, 185, 129, 0.08), 0 18px 90px rgba(0,0,0,0.55)\",\n                      }}\n                      transition={{ type: \"spring\", stiffness: 260, damping: 30 }}\n                    />\n                  )}\n                </motion.div>\n              </DialogPrimitive.Overlay>\n\n              <DialogPrimitive.Content\n                onPointerDownOutside={(e) => e.preventDefault()}\n                onInteractOutside={(e) => e.preventDefault()}\n                className={cn(\n                  \"fixed z-[52] outline-none\",\n                  \"w-[min(420px,calc(100vw-1.5rem))]\",\n                )}\n                style={{\n                  left: cardPos.left,\n                  top: cardPos.top,\n                }}\n              >\n                <motion.div\n                  ref={cardRef}\n                  initial={{ opacity: 0, scale: 0.985, y: 8 }}\n                  animate={{ opacity: 1, scale: 1, y: 0 }}\n 
                 exit={{ opacity: 0, scale: 0.99, y: 10 }}\n                  transition={{ duration: 0.22, ease: [0.165, 0.84, 0.44, 1] }}\n                  className={cn(\n                    \"relative overflow-hidden rounded-[28px] corner-squircle\",\n                    \"bg-white/95 text-foreground ring-1 ring-black/10 dark:bg-zinc-900/96 dark:text-zinc-100 dark:ring-white/12\",\n                    \"shadow-[0_30px_120px_rgba(0,0,0,0.35)]\",\n                  )}\n                  style={{\n                    fontFamily: \"'Figtree Variable', ui-sans-serif, sans-serif\",\n                  }}\n                >\n                  <div\n                    className={cn(\n                      \"absolute z-10 size-3 rotate-45 rounded-[3px] bg-white/95 ring-1 ring-black/10 dark:bg-zinc-900/96 dark:ring-white/12\",\n                      placement === \"right\" &&\n                        \"-left-1 top-1/2 -translate-y-1/2\",\n                      placement === \"left\" &&\n                        \"-right-1 top-1/2 -translate-y-1/2\",\n                      placement === \"bottom\" &&\n                        \"left-1/2 -top-1 -translate-x-1/2\",\n                      placement === \"top\" &&\n                        \"left-1/2 -bottom-1 -translate-x-1/2\",\n                    )}\n                    aria-hidden={true}\n                  />\n                  <div className=\"absolute inset-x-0 top-0 h-20 bg-gradient-to-b from-emerald-400/18 via-cyan-300/6 to-transparent dark:from-emerald-400/24 dark:via-cyan-300/12\" />\n                  <div className=\"absolute -left-14 -top-16 size-44 rounded-full bg-emerald-400/20 blur-2xl dark:bg-emerald-400/26\" />\n                  <div className=\"absolute -right-14 -bottom-16 size-44 rounded-full bg-cyan-300/18 blur-2xl dark:bg-cyan-300/24\" />\n\n                  <div className=\"relative p-5\">\n                    <div className=\"flex items-start justify-between gap-3\">\n                      <div className=\"min-w-0\">\n                        <div className=\"inline-flex items-center gap-2 rounded-full bg-black/[0.04] px-2.5 py-1 text-[10px] font-mono text-foreground/60 ring-1 ring-black/10 dark:bg-white/[0.04] dark:text-zinc-200/75 dark:ring-white/14\">\n                          {idx + 1}/{total}\n                          <span className=\"size-1 rounded-full bg-emerald-500/70\" />\n                          guided tour\n                        </div>\n                        <DialogPrimitive.Title\n                          className=\"mt-2 text-[18px] leading-tight\"\n                          style={{ fontFamily: \"var(--font-serif)\" }}\n                        >\n                          {step?.title ?? \"Quick tour\"}\n                        </DialogPrimitive.Title>\n                        <DialogPrimitive.Description className=\"mt-1.5 text-sm leading-relaxed text-foreground/70 dark:text-zinc-200/75\">\n                          {step?.body ?? 
\"Let’s get you oriented.\"}\n                        </DialogPrimitive.Description>\n                      </div>\n\n                      <Button\n                        variant=\"ghost\"\n                        size=\"icon-sm\"\n                        className=\"text-foreground/60 hover:text-foreground hover:bg-black/[0.05] dark:text-zinc-300/70 dark:hover:text-zinc-100 dark:hover:bg-white/[0.08]\"\n                        onClick={() => requestClose(\"skip\")}\n                        aria-label=\"Skip tour\"\n                      >\n                        <HugeiconsIcon icon={Cancel01Icon} className=\"size-4\" />\n                      </Button>\n                    </div>\n\n                    <div className=\"mt-5 flex items-center justify-between gap-3\">\n                      <Button\n                        variant=\"ghost\"\n                        className=\"text-foreground/60 hover:text-foreground hover:bg-black/[0.05] dark:text-zinc-300/70 dark:hover:text-zinc-100 dark:hover:bg-white/[0.08]\"\n                        onClick={() => requestClose(\"skip\")}\n                      >\n                        Skip\n                      </Button>\n\n                      <div className=\"flex items-center gap-2\">\n                        <Button\n                          variant=\"outline\"\n                          className=\"border-black/10 bg-white/70 text-foreground hover:bg-white hover:text-foreground dark:border-white/15 dark:bg-white/[0.07] dark:text-zinc-100 dark:hover:bg-white/[0.12]\"\n                          disabled={idx === 0}\n                          onClick={() => setIdx((i) => Math.max(0, i - 1))}\n                        >\n                          <HugeiconsIcon icon={ArrowLeft01Icon} className=\"size-4\" />\n                          Back\n                        </Button>\n                        {isLast ? (\n                          <Button\n                            variant=\"dark\"\n                            className=\"bg-gradient-to-r from-emerald-500 to-cyan-400 text-white hover:from-emerald-600 hover:to-cyan-500\"\n                            onClick={() => requestClose(\"complete\")}\n                          >\n                            <HugeiconsIcon icon={CheckmarkCircle01Icon} className=\"size-4\" />\n                            Done\n                          </Button>\n                        ) : (\n                          <Button\n                            variant=\"dark\"\n                            className=\"bg-gradient-to-r from-emerald-500 to-cyan-400 text-white hover:from-emerald-600 hover:to-cyan-500\"\n                            onClick={() => setIdx((i) => Math.min(total - 1, i + 1))}\n                          >\n                            Next\n                            <HugeiconsIcon icon={ArrowRight01Icon} className=\"size-4\" />\n                          </Button>\n                        )}\n                      </div>\n                    </div>\n                  </div>\n\n                  <div className=\"h-px bg-gradient-to-r from-transparent via-black/10 to-transparent dark:via-white/14\" />\n                  <div className=\"px-5 py-3 text-[11px] text-foreground/55 dark:text-zinc-300/65\">\n                    Tip: `Esc` skips. 
Tour blocks clicks so you can read.\n                  </div>\n                </motion.div>\n              </DialogPrimitive.Content>\n            </>\n          )}\n        </AnimatePresence>\n      </DialogPrimitive.Portal>\n    </DialogPrimitive.Root>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/tour/components/read-more.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport function ReadMore({ href = \"#\" }: { href?: string }) {\n  return (\n    <a\n      href={href}\n      onClick={(e) => {\n        if (href === \"#\") e.preventDefault();\n      }}\n      className=\"text-emerald-600 underline underline-offset-2 hover:text-emerald-700\"\n    >\n      Read more\n    </a>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/tour/components/spotlight-overlay.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { motion } from \"motion/react\";\nimport type { Rect } from \"../types\";\n\ntype SpotlightOverlayProps = {\n  rect: Rect | null;\n  vw: number;\n  vh: number;\n  maskId: string;\n};\n\nexport function SpotlightOverlay({ rect, vw, vh, maskId }: SpotlightOverlayProps) {\n  const hole = rect ?? { x: vw / 2 - 140, y: vh / 2 - 90, w: 280, h: 180 };\n  const r = 22;\n\n  return (\n    <svg\n      className=\"absolute inset-0 size-full\"\n      viewBox={`0 0 ${vw} ${vh}`}\n      preserveAspectRatio=\"none\"\n      aria-hidden={true}\n    >\n      <defs>\n        <radialGradient id={`${maskId}-v`} cx=\"50%\" cy=\"45%\" r=\"80%\">\n          <stop offset=\"0%\" stopColor=\"rgba(6, 9, 15, 0.35)\" />\n          <stop offset=\"55%\" stopColor=\"rgba(6, 9, 15, 0.65)\" />\n          <stop offset=\"100%\" stopColor=\"rgba(6, 9, 15, 0.88)\" />\n        </radialGradient>\n        <mask id={maskId}>\n          <rect x=\"0\" y=\"0\" width={vw} height={vh} fill=\"white\" />\n          <motion.rect\n            x={hole.x}\n            y={hole.y}\n            width={hole.w}\n            height={hole.h}\n            rx={r}\n            fill=\"black\"\n            transition={{ type: \"spring\", stiffness: 260, damping: 30 }}\n          />\n        </mask>\n      </defs>\n      <rect\n        x=\"0\"\n        y=\"0\"\n        width={vw}\n        height={vh}\n        fill={`url(#${maskId}-v)`}\n        mask={`url(#${maskId})`}\n      />\n    </svg>\n  );\n}\n\n"
  },
  {
    "path": "studio/frontend/src/features/tour/hooks/use-guided-tour-controller.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useMemo, useState } from \"react\";\nimport type { TourStep } from \"../types\";\n\nexport const TOUR_OPEN_EVENT = \"omx:tour:open\";\n\nexport type TourOpenDetail = {\n  id?: string;\n};\n\nexport function useGuidedTourController({\n  id,\n  steps,\n  enabled = true,\n  autoKey,\n  autoWhen = false,\n}: {\n  id: string;\n  steps: TourStep[];\n  enabled?: boolean;\n  autoKey?: string;\n  autoWhen?: boolean;\n}) {\n  const [open, setOpen] = useState(false);\n  const [hasRuntime, setHasRuntime] = useState(false);\n\n  useEffect(() => setHasRuntime(true), []);\n\n  useEffect(() => {\n    if (!hasRuntime || !enabled) return;\n    if (!autoKey || !autoWhen) return;\n    if (steps.length === 0) return;\n    if (localStorage.getItem(autoKey)) return;\n    setOpen(true);\n  }, [autoKey, autoWhen, enabled, hasRuntime, steps.length]);\n\n  useEffect(() => {\n    if (!hasRuntime || !enabled) return;\n    function onOpen(e: Event) {\n      const ce = e as CustomEvent<TourOpenDetail>;\n      if (ce.detail?.id && ce.detail.id !== id) return;\n      if (steps.length === 0) return;\n      setOpen(true);\n    }\n    window.addEventListener(TOUR_OPEN_EVENT, onOpen);\n    return () => window.removeEventListener(TOUR_OPEN_EVENT, onOpen);\n  }, [enabled, hasRuntime, id, steps.length]);\n\n  const onSkip = useCallback(() => {\n    if (!autoKey) return;\n    localStorage.setItem(autoKey, \"skipped\");\n  }, [autoKey]);\n\n  const onComplete = useCallback(() => {\n    if (!autoKey) return;\n    localStorage.setItem(autoKey, \"done\");\n  }, [autoKey]);\n\n  const tourProps = useMemo(\n    () => ({\n      open,\n      onOpenChange: setOpen,\n      steps,\n      onSkip,\n      onComplete,\n    }),\n    [onComplete, onSkip, open, steps],\n  );\n\n  return { open, setOpen, onSkip, onComplete, tourProps };\n}\n\n"
  },
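The controller above opens a tour either automatically (via `autoKey`/`autoWhen`) or when a `TOUR_OPEN_EVENT` custom event fires on `window`. A minimal trigger sketch, assuming the `@/` alias maps to `src/` as in the imports above; `openTour` and the `"studio"` id are illustrative names, not part of the codebase:

```ts
// Illustrative helper: dispatches the window event that
// useGuidedTourController listens for. Any mounted controller whose id
// matches (or whose event detail omits an id) sets its tour open.
import { TOUR_OPEN_EVENT } from "@/features/tour/hooks/use-guided-tour-controller";

export function openTour(tourId?: string): void {
  window.dispatchEvent(
    new CustomEvent(TOUR_OPEN_EVENT, { detail: { id: tourId } }),
  );
}

// e.g. openTour("studio") from a hypothetical help-menu click handler.
```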
  {
    "path": "studio/frontend/src/features/tour/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { GuidedTour } from \"./components/guided-tour\";\nexport { ReadMore } from \"./components/read-more\";\nexport { TOUR_OPEN_EVENT, useGuidedTourController } from \"./hooks/use-guided-tour-controller\";\nexport type { TourStep } from \"./types\";\n"
  },
  {
    "path": "studio/frontend/src/features/tour/types.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { ReactNode } from \"react\";\n\nexport type TourStep = {\n  id: string;\n  target: string; // data-tour=\"<target>\"\n  title: string;\n  body: ReactNode;\n  onEnter?: () => void | Promise<void>;\n  onExit?: () => void | Promise<void>;\n};\n\nexport type Rect = { x: number; y: number; w: number; h: number };\n\nexport type Placement = \"right\" | \"left\" | \"top\" | \"bottom\";\n"
  },
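Putting the tour pieces together: a tour is a `TourStep[]` whose `target` values match `data-tour` attributes already in the DOM (for example `studio-training-progress` in `training-view.tsx` above). A rough wiring sketch; the step copy, component name, and `autoKey` are placeholders:

```tsx
import { GuidedTour, useGuidedTourController, type TourStep } from "@/features/tour";

const steps: TourStep[] = [
  {
    id: "training-progress",
    target: "studio-training-progress", // matches data-tour in training-view.tsx
    title: "Training progress",
    body: "Step count and loss appear here once the first step lands.",
  },
];

export function StudioTour() {
  const { tourProps } = useGuidedTourController({
    id: "studio",
    steps,
    autoKey: "unsloth_studio_tour_seen", // hypothetical localStorage key
    autoWhen: true,
  });
  // tourProps carries open/onOpenChange/steps/onSkip/onComplete for GuidedTour.
  return <GuidedTour {...tourProps} celebrate />;
}
```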
  {
    "path": "studio/frontend/src/features/training/api/datasets-api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  CheckFormatResponse,\n  LocalDatasetsResponse,\n  UploadDatasetResponse,\n} from \"../types/datasets\";\nimport { authFetch } from \"@/features/auth\";\n\ntype CheckDatasetFormatArgs = {\n  datasetName: string;\n  hfToken: string | null;\n  subset?: string | null;\n  split?: string | null;\n  isVlm?: boolean;\n};\n\nexport async function checkDatasetFormat({\n  datasetName,\n  hfToken,\n  subset,\n  split,\n  isVlm,\n}: CheckDatasetFormatArgs): Promise<CheckFormatResponse> {\n  const res = await authFetch(\"/api/datasets/check-format\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({\n      dataset_name: datasetName,\n      hf_token: hfToken || undefined,\n      subset: subset || undefined,\n      split: split || \"train\",\n      is_vlm: !!isVlm,\n    }),\n  });\n\n  if (!res.ok) {\n    const body = await res.json().catch(() => null);\n    throw new Error(body?.detail || `Request failed (${res.status})`);\n  }\n\n  return res.json();\n}\n\nexport async function uploadTrainingDataset(\n  file: File,\n): Promise<UploadDatasetResponse> {\n  const form = new FormData();\n  form.append(\"file\", file);\n\n  const res = await authFetch(\"/api/datasets/upload\", {\n    method: \"POST\",\n    body: form,\n  });\n\n  if (!res.ok) {\n    const body = await res.json().catch(() => null);\n    throw new Error(body?.detail || `Upload failed (${res.status})`);\n  }\n\n  return res.json();\n}\n\n// ── AI Assist ────────────────────────────────────────────────────────\n\ntype AiAssistMappingArgs = {\n  columns: string[];\n  samples: Record<string, unknown>[];\n  datasetName?: string | null;\n  hfToken?: string | null;\n  modelName?: string | null;\n  modelType?: \"text\" | \"vision\" | \"audio\" | \"embeddings\" | null;\n};\n\nexport type AiAssistMappingResponse = {\n  success: boolean;\n  suggested_mapping?: Record<string, string> | null;\n  warning?: string | null;\n  // Conversion advisor fields\n  system_prompt?: string | null;\n  label_mapping?: Record<string, Record<string, string>> | null;\n  dataset_type?: string | null;\n  is_conversational?: boolean | null;\n  user_notification?: string | null;\n};\n\nexport async function aiAssistMapping({\n  columns,\n  samples,\n  datasetName,\n  hfToken,\n  modelName,\n  modelType,\n}: AiAssistMappingArgs): Promise<AiAssistMappingResponse> {\n  const res = await authFetch(\"/api/datasets/ai-assist-mapping\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({\n      columns,\n      samples: samples.slice(0, 5),\n      dataset_name: datasetName || undefined,\n      hf_token: hfToken || undefined,\n      model_name: modelName || undefined,\n      model_type: modelType || undefined,\n    }),\n  });\n\n  if (!res.ok) {\n    const body = await res.json().catch(() => null);\n    throw new Error(body?.detail || `AI assist failed (${res.status})`);\n  }\n\n  return res.json();\n}\n\nexport async function listLocalDatasets(): Promise<LocalDatasetsResponse> {\n  const res = await authFetch(\"/api/datasets/local\");\n  if (!res.ok) {\n    const body = await res.json().catch(() => null);\n    throw new Error(body?.detail || `Request failed (${res.status})`);\n  }\n  return res.json();\n}\n"
  },
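A hedged usage sketch for the format check above, mirroring how `use-training-actions.ts` later in this listing consumes it; the wrapper name is invented and only response fields this frontend actually reads are shown:

```ts
import { checkDatasetFormat } from "@/features/training/api/datasets-api";

// Invented wrapper for illustration: reports whether the user must confirm
// a column mapping and returns the backend's best-guess mapping, if any.
export async function needsColumnMapping(datasetName: string, hfToken: string | null) {
  const check = await checkDatasetFormat({
    datasetName, // e.g. a Hub dataset id; hfToken unlocks gated/private datasets
    hfToken,
    split: "train",
  });
  return {
    requiresMapping: !!check.requires_manual_mapping,
    suggestion: check.suggested_mapping ?? null, // dataset column -> chatml role
  };
}
```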
  {
    "path": "studio/frontend/src/features/training/api/mappers.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { TrainingConfigState } from \"../types/config\";\nimport type { TrainingStartRequest } from \"../types/api\";\n\nconst BACKEND_LORA_TYPE = \"LoRA/QLoRA\";\nconst BACKEND_FULL_TYPE = \"Full Finetuning\";\n\nfunction parseSliceValue(value: string | null): number | null {\n  if (value == null) return null;\n  const trimmed = value.trim();\n  if (!trimmed) return null;\n  const num = Number(trimmed);\n  if (!Number.isFinite(num) || !Number.isInteger(num) || num < 0) return null;\n  return num;\n}\n\nexport function toBackendTrainingType(trainingMethod: string): string {\n  return trainingMethod === \"full\" ? BACKEND_FULL_TYPE : BACKEND_LORA_TYPE;\n}\n\nexport function buildTrainingStartPayload(\n  config: TrainingConfigState,\n): TrainingStartRequest {\n  const adapterMethod = config.trainingMethod !== \"full\";\n  const isQloraMethod = config.trainingMethod === \"qlora\";\n  const isEmbedding = config.isEmbeddingModel;\n  const hfDataset = config.datasetSource === \"huggingface\" ? config.dataset : null;\n  const localDatasets =\n    config.datasetSource === \"upload\" && config.uploadedFile\n      ? [config.uploadedFile]\n      : [];\n  let customFormatMapping: Record<string, unknown> | undefined =\n    Object.keys(config.datasetManualMapping).length > 0\n      ? { ...config.datasetManualMapping }\n      : undefined;\n\n  // Inject conversion advisor metadata into the mapping (__ prefix keys)\n  const hasAdvisorMeta =\n    config.datasetSystemPrompt ||\n    Object.keys(config.datasetLabelMapping).length > 0;\n  if (customFormatMapping && hasAdvisorMeta) {\n    if (config.datasetSystemPrompt) {\n      customFormatMapping.__system_prompt = config.datasetSystemPrompt;\n    }\n    if (Object.keys(config.datasetLabelMapping).length > 0) {\n      customFormatMapping.__label_mapping = config.datasetLabelMapping;\n    }\n  }\n\n  return {\n    model_name: config.selectedModel ?? \"\",\n    training_type: toBackendTrainingType(config.trainingMethod),\n    hf_token: config.hfToken.trim() || null,\n    load_in_4bit: adapterMethod ? isQloraMethod : false,\n    max_seq_length: config.contextLength,\n    trust_remote_code: config.trustRemoteCode ?? false,\n    hf_dataset: hfDataset,\n    subset: hfDataset ? config.datasetSubset : null,\n    train_split: hfDataset ? config.datasetSplit : null,\n    eval_split: hfDataset ? config.datasetEvalSplit : null,\n    dataset_slice_start: parseSliceValue(config.datasetSliceStart),\n    dataset_slice_end: parseSliceValue(config.datasetSliceEnd),\n    local_datasets: localDatasets,\n    local_eval_datasets:\n      config.datasetSource === \"upload\" && config.uploadedEvalFile\n        ? [config.uploadedEvalFile]\n        : [],\n    format_type: config.datasetFormat,\n    custom_format_mapping: customFormatMapping,\n    num_epochs: config.epochs,\n    learning_rate: String(config.learningRate),\n    batch_size: config.batchSize,\n    gradient_accumulation_steps: config.gradientAccumulation,\n    warmup_steps: isEmbedding ? null : config.warmupSteps,\n    warmup_ratio: isEmbedding ? 0.03 : null,\n    max_steps: config.maxSteps,\n    save_steps: config.saveSteps,\n    eval_steps: config.evalSteps,\n    weight_decay: config.weightDecay,\n    random_seed: config.randomSeed,\n    packing: isEmbedding ? 
false : config.packing,\n    optim: config.optimizerType,\n    lr_scheduler_type: config.lrSchedulerType,\n    use_lora: adapterMethod,\n    lora_r: config.loraRank,\n    lora_alpha: config.loraAlpha,\n    lora_dropout: config.loraDropout,\n    target_modules: adapterMethod ? config.targetModules : [],\n    gradient_checkpointing: config.gradientCheckpointing,\n    use_rslora: config.loraVariant === \"rslora\",\n    use_loftq: config.loraVariant === \"loftq\",\n    train_on_completions: isEmbedding ? false : config.trainOnCompletions,\n    finetune_vision_layers: config.finetuneVisionLayers,\n    finetune_language_layers: config.finetuneLanguageLayers,\n    finetune_attention_modules: config.finetuneAttentionModules,\n    finetune_mlp_modules: config.finetuneMLPModules,\n    is_dataset_image: isEmbedding ? false : !!config.isDatasetImage,\n    is_dataset_audio: isEmbedding ? false : config.isDatasetAudio,\n    is_embedding: isEmbedding,\n    enable_wandb: config.enableWandb,\n    wandb_token: config.enableWandb ? config.wandbToken.trim() || null : null,\n    wandb_project: config.enableWandb\n      ? config.wandbProject.trim() || null\n      : null,\n    enable_tensorboard: config.enableTensorboard,\n    tensorboard_dir: config.enableTensorboard\n      ? config.tensorboardDir.trim() || null\n      : null,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/api/models-api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\n\ninterface VisionCheckResponse {\n  model_name: string;\n  is_vision: boolean;\n}\n\ninterface EmbeddingCheckResponse {\n  model_name: string;\n  is_embedding: boolean;\n}\n\ninterface BackendTrainingDefaults {\n  max_seq_length?: number;\n  num_epochs?: number;\n  learning_rate?: number | string;\n  optim?: string;\n  lr_scheduler_type?: string;\n  batch_size?: number;\n  gradient_accumulation_steps?: number;\n  warmup_steps?: number;\n  max_steps?: number;\n  save_steps?: number;\n  eval_steps?: number;\n  weight_decay?: number;\n  random_seed?: number;\n  packing?: boolean;\n  train_on_completions?: boolean;\n  gradient_checkpointing?: \"none\" | \"true\" | \"unsloth\";\n  trust_remote_code?: boolean;\n}\n\ninterface BackendLoraDefaults {\n  lora_r?: number;\n  lora_alpha?: number;\n  lora_dropout?: number;\n  target_modules?: string[];\n  use_rslora?: boolean;\n  use_loftq?: boolean;\n  finetune_vision_layers?: boolean;\n  finetune_language_layers?: boolean;\n  finetune_attention_modules?: boolean;\n  finetune_mlp_modules?: boolean;\n}\n\ninterface BackendLoggingDefaults {\n  enable_wandb?: boolean;\n  wandb_project?: string;\n  enable_tensorboard?: boolean;\n  tensorboard_dir?: string;\n  log_frequency?: number;\n}\n\nexport interface BackendModelConfig {\n  audio_type?: string | null;\n  training?: BackendTrainingDefaults;\n  lora?: BackendLoraDefaults;\n  logging?: BackendLoggingDefaults;\n}\n\nexport interface ModelConfigResponse {\n  id: string;\n  model_name?: string | null;\n  config?: BackendModelConfig | null;\n  is_vision: boolean;\n  is_embedding?: boolean;\n  is_audio: boolean;\n  is_lora: boolean;\n  base_model?: string | null;\n  model_type?: \"text\" | \"vision\" | \"audio\" | \"embeddings\" | null;\n  max_position_embeddings?: number | null;\n  model_size_bytes?: number | null;\n}\n\nexport interface LocalModelInfo {\n  id: string;\n  display_name: string;\n  path: string;\n  source: \"models_dir\" | \"hf_cache\";\n  model_id?: string | null;\n  updated_at?: number | null;\n}\n\ninterface LocalModelListResponse {\n  models_dir: string;\n  hf_cache_dir?: string | null;\n  models: LocalModelInfo[];\n}\n\n/**\n * Check whether a model is a vision model by asking the backend.\n * Calls GET /api/models/check-vision/{model_name}.\n */\nexport async function checkVisionModel(modelName: string): Promise<boolean> {\n  const encoded = encodeURIComponent(modelName);\n  const response = await authFetch(`/api/models/check-vision/${encoded}`);\n  if (!response.ok) {\n    // If the check fails (e.g. network error), default to non-vision\n    return false;\n  }\n  const data = (await response.json()) as VisionCheckResponse;\n  return data.is_vision;\n}\n\n/**\n * Check whether a model is an embedding model by asking the backend.\n * Calls GET /api/models/check-embedding/{model_name}.\n */\nexport async function checkEmbeddingModel(\n  modelName: string,\n): Promise<boolean> {\n  const encoded = encodeURIComponent(modelName);\n  const response = await authFetch(`/api/models/check-embedding/${encoded}`);\n  if (!response.ok) {\n    // If the check fails (e.g. 
network error), default to non-embedding\n    return false;\n  }\n  const data = (await response.json()) as EmbeddingCheckResponse;\n  return data.is_embedding;\n}\n\nexport async function getModelConfig(\n  modelName: string,\n  signal?: AbortSignal,\n  hfToken?: string,\n): Promise<ModelConfigResponse> {\n  const encoded = encodeURIComponent(modelName);\n  const params = hfToken ? `?hf_token=${encodeURIComponent(hfToken)}` : \"\";\n  const response = await authFetch(`/api/models/config/${encoded}${params}`, { signal });\n  if (!response.ok) {\n    throw new Error(`Failed to fetch model config (${response.status})`);\n  }\n  return (await response.json()) as ModelConfigResponse;\n}\n\nexport async function listLocalModels(\n  signal?: AbortSignal,\n): Promise<LocalModelInfo[]> {\n  const response = await authFetch(\"/api/models/local\", { signal });\n  if (!response.ok) {\n    throw new Error(`Failed to fetch local models (${response.status})`);\n  }\n  const data = (await response.json()) as LocalModelListResponse;\n  return data.models;\n}\n"
  },
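One plausible way to use these helpers when prefilling the config form; the function name and the 2048 fallback are assumptions, not values from the backend:

```ts
import { getModelConfig } from "@/features/training/api/models-api";

export async function loadModelDefaults(modelName: string, hfToken?: string) {
  const controller = new AbortController();
  const info = await getModelConfig(modelName, controller.signal, hfToken);

  // Backend defaults are optional; fall back to the model's reported context
  // window, then to an assumed 2048 if neither is present.
  const maxSeqLength =
    info.config?.training?.max_seq_length ??
    info.max_position_embeddings ??
    2048;

  return {
    isVision: info.is_vision,
    isEmbedding: info.is_embedding ?? false,
    isLora: info.is_lora,
    maxSeqLength,
  };
}
```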
  {
    "path": "studio/frontend/src/features/training/api/train-api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\nimport type {\n  TrainingStartRequest,\n  TrainingStartResponse,\n  TrainingStopResponse,\n} from \"../types/api\";\nimport type {\n  TrainingMetricsResponse,\n  TrainingProgressPayload,\n  TrainingStatusResponse,\n} from \"../types/runtime\";\n\nfunction isAbortError(error: unknown): boolean {\n  return error instanceof DOMException && error.name === \"AbortError\";\n}\n\nasync function readError(response: Response): Promise<string> {\n  try {\n    const payload = (await response.json()) as { detail?: string; message?: string };\n    return payload.detail || payload.message || `Request failed (${response.status})`;\n  } catch {\n    return `Request failed (${response.status})`;\n  }\n}\n\nasync function parseJson<T>(response: Response): Promise<T> {\n  if (!response.ok) {\n    throw new Error(await readError(response));\n  }\n  return (await response.json()) as T;\n}\n\nexport async function startTraining(\n  payload: TrainingStartRequest,\n): Promise<TrainingStartResponse> {\n  const response = await authFetch(\"/api/train/start\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify(payload),\n  });\n  return parseJson<TrainingStartResponse>(response);\n}\n\nexport async function stopTraining(save = true): Promise<TrainingStopResponse> {\n  const response = await authFetch(\"/api/train/stop\", {\n    method: \"POST\",\n    headers: { \"Content-Type\": \"application/json\" },\n    body: JSON.stringify({ save }),\n  });\n  return parseJson<TrainingStopResponse>(response);\n}\n\nexport async function resetTraining(): Promise<void> {\n  const response = await authFetch(\"/api/train/reset\", { method: \"POST\" });\n  if (!response.ok) {\n    throw new Error(await readError(response));\n  }\n}\n\nexport async function getTrainingStatus(): Promise<TrainingStatusResponse> {\n  const response = await authFetch(\"/api/train/status\");\n  return parseJson<TrainingStatusResponse>(response);\n}\n\nexport async function getTrainingMetrics(): Promise<TrainingMetricsResponse> {\n  const response = await authFetch(\"/api/train/metrics\");\n  return parseJson<TrainingMetricsResponse>(response);\n}\n\ntype ProgressEventName = \"progress\" | \"heartbeat\" | \"complete\" | \"error\";\n\ninterface ParsedSseEvent {\n  event: ProgressEventName;\n  payload: TrainingProgressPayload;\n  id: number | null;\n}\n\nfunction parseSseEvent(rawEvent: string): ParsedSseEvent | null {\n  const lines = rawEvent.split(/\\r?\\n/);\n  let eventName: ProgressEventName = \"progress\";\n  let id: number | null = null;\n  const dataLines: string[] = [];\n\n  for (const line of lines) {\n    if (!line) {\n      continue;\n    }\n    if (line.startsWith(\"event:\")) {\n      const value = line.slice(6).trim();\n      if (\n        value === \"progress\" ||\n        value === \"heartbeat\" ||\n        value === \"complete\" ||\n        value === \"error\"\n      ) {\n        eventName = value;\n      }\n      continue;\n    }\n    if (line.startsWith(\"id:\")) {\n      const value = Number(line.slice(3).trim());\n      id = Number.isFinite(value) ? 
value : null;\n      continue;\n    }\n    if (line.startsWith(\"data:\")) {\n      dataLines.push(line.slice(5).trimStart());\n    }\n  }\n\n  if (dataLines.length === 0) {\n    return null;\n  }\n\n  const parsed = JSON.parse(dataLines.join(\"\\n\")) as TrainingProgressPayload;\n  return { event: eventName, payload: parsed, id };\n}\n\nexport async function streamTrainingProgress(options: {\n  signal: AbortSignal;\n  lastEventId?: number | null;\n  onOpen?: () => void;\n  onEvent: (event: ParsedSseEvent) => void;\n}): Promise<void> {\n  const headers = new Headers();\n  if (typeof options.lastEventId === \"number\") {\n    headers.set(\"Last-Event-ID\", String(options.lastEventId));\n  }\n\n  const response = await authFetch(\"/api/train/progress\", {\n    method: \"GET\",\n    headers,\n    signal: options.signal,\n  });\n\n  if (!response.ok) {\n    throw new Error(await readError(response));\n  }\n\n  if (!response.body) {\n    throw new Error(\"Progress stream unavailable\");\n  }\n\n  options.onOpen?.();\n\n  const reader = response.body.getReader();\n  const decoder = new TextDecoder();\n  let buffer = \"\";\n\n  while (true) {\n    const { value, done } = await reader.read();\n    if (done) {\n      break;\n    }\n\n    buffer += decoder.decode(value, { stream: true });\n\n    let separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    while (separatorIndex >= 0) {\n      const rawEvent = buffer.slice(0, separatorIndex);\n      const separatorLength = buffer[separatorIndex] === \"\\r\" ? 4 : 2;\n      buffer = buffer.slice(separatorIndex + separatorLength);\n\n      if (rawEvent.startsWith(\"retry:\")) {\n        separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n        continue;\n      }\n\n      try {\n        const event = parseSseEvent(rawEvent);\n        if (event) {\n          options.onEvent(event);\n        }\n      } catch (error) {\n        if (!isAbortError(error)) {\n          throw error;\n        }\n      }\n\n      separatorIndex = buffer.search(/\\r?\\n\\r?\\n/);\n    }\n  }\n}\n\nexport { isAbortError };\n"
  },
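`streamTrainingProgress` hands back parsed SSE events instead of exposing an `EventSource`. A minimal consumer sketch; the function name, logging, and abort policy are illustrative choices:

```ts
import { isAbortError, streamTrainingProgress } from "@/features/training/api/train-api";

export async function followTrainingProgress(lastSeenId: number | null = null) {
  const controller = new AbortController();
  try {
    await streamTrainingProgress({
      signal: controller.signal,
      lastEventId: lastSeenId, // sent as Last-Event-ID so the backend can resume
      onOpen: () => console.log("progress stream connected"),
      onEvent: ({ event, id, payload }) => {
        console.log(event, id, payload);
        if (event === "complete" || event === "error") {
          controller.abort(); // stop reading once the run has finished or failed
        }
      },
    });
  } catch (error) {
    if (!isAbortError(error)) throw error; // aborts are expected on teardown
  }
}
```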
  {
    "path": "studio/frontend/src/features/training/components/hf-dataset-subset-split-selectors.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport {\n  Select,\n  SelectContent,\n  SelectItem,\n  SelectTrigger,\n  SelectValue,\n} from \"@/components/ui/select\";\nimport { Spinner } from \"@/components/ui/spinner\";\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from \"@/components/ui/tooltip\";\nimport {\n  Field,\n  FieldLabel,\n} from \"@/components/ui/field\";\nimport { useHfDatasetSplits } from \"@/hooks\";\nimport { InformationCircleIcon } from \"@hugeicons/core-free-icons\";\nimport { HugeiconsIcon } from \"@hugeicons/react\";\nimport { useEffect } from \"react\";\n\ntype Props = {\n  variant: \"wizard\" | \"studio\";\n  enabled: boolean;\n  datasetName: string | null;\n  accessToken?: string;\n  datasetSubset: string | null;\n  setDatasetSubset: (v: string | null) => void;\n  datasetSplit: string | null;\n  setDatasetSplit: (v: string | null) => void;\n  datasetEvalSplit: string | null;\n  setDatasetEvalSplit: (v: string | null) => void;\n};\n\nexport function HfDatasetSubsetSplitSelectors({\n  variant,\n  enabled,\n  datasetName,\n  accessToken,\n  datasetSubset,\n  setDatasetSubset,\n  datasetSplit,\n  setDatasetSplit,\n  datasetEvalSplit,\n  setDatasetEvalSplit,\n}: Props) {\n  const {\n    subsets: hfSubsets,\n    splits: hfSplits,\n    isLoading,\n    error,\n  } = useHfDatasetSplits(enabled ? datasetName : null, datasetSubset, {\n    accessToken,\n  });\n  const showPlaceholderDropdowns =\n    variant === \"studio\" && !enabled && !datasetName;\n\n  // Auto-select subset and split in one pass to avoid racing effects\n  useEffect(() => {\n    if (hfSubsets.length === 0) return;\n\n    // --- subset ---\n    if (!datasetSubset || !hfSubsets.includes(datasetSubset)) {\n      const pick = hfSubsets.includes(\"default\") ? \"default\" : hfSubsets[0];\n      setDatasetSubset(pick);\n      return;\n    }\n\n    // --- split (only once subset is settled) ---\n    if (hfSplits.length === 0) return;\n    if (!datasetSplit || !hfSplits.includes(datasetSplit)) {\n      const pick = hfSplits.includes(\"train\") ? \"train\" : hfSplits[0];\n      setDatasetSplit(pick);\n    }\n  }, [\n    hfSubsets,\n    hfSplits,\n    datasetSubset,\n    setDatasetSubset,\n    datasetSplit,\n    setDatasetSplit,\n  ]);\n\n  const showDropdowns = !isLoading && !error && hfSubsets.length > 0;\n\n  return (\n    <>\n      {showPlaceholderDropdowns && (\n        <>\n          <div className=\"grid gap-3 sm:grid-cols-2\">\n            <SelectorDropdown\n              variant={variant}\n              label=\"Subset\"\n              tooltip=\"Select which subset (config) of the dataset to use.\"\n              value={null}\n              onChange={setDatasetSubset}\n              options={[]}\n              placeholder=\"Select a subset...\"\n              disabled={true}\n            />\n            <SelectorDropdown\n              variant={variant}\n              label=\"Train Split\"\n              tooltip=\"Select which split to use for training.\"\n              value={null}\n              onChange={setDatasetSplit}\n              options={[]}\n              placeholder=\"Select a split...\"\n              disabled={true}\n            />\n          </div>\n          <SelectorDropdown\n            variant={variant}\n            label=\"Evaluation Split\"\n            tooltip=\"Select which split to use for evaluation. 
None means no evaluation during training.\"\n            value={null}\n            onChange={setDatasetEvalSplit}\n            options={[]}\n            placeholder=\"None\"\n            allowNone\n            disabled={true}\n          />\n        </>\n      )}\n\n      {isLoading && (\n        <div\n          className={\n            variant === \"wizard\"\n              ? \"flex items-center gap-2 text-xs text-muted-foreground py-1\"\n              : \"flex items-center gap-2 rounded-lg border bg-muted/20 px-3.5 py-3 text-xs text-muted-foreground\"\n          }\n        >\n          <Spinner className=\"size-3.5\" />\n          Loading dataset configs and splits...\n        </div>\n      )}\n\n      {error && (\n        <div\n          className={\n            variant === \"wizard\"\n              ? \"rounded-lg border border-amber-200 bg-amber-50 px-3 py-2 text-xs text-amber-700 dark:border-amber-800 dark:bg-amber-950 dark:text-amber-400\"\n              : \"rounded-lg border border-amber-200 bg-amber-50 px-3.5 py-2.5 text-xs text-amber-700 dark:border-amber-800 dark:bg-amber-950 dark:text-amber-400\"\n          }\n        >\n          {error}\n        </div>\n      )}\n\n      {showDropdowns && (\n        <>\n          {variant === \"studio\" ? (\n            <div className=\"grid gap-3 sm:grid-cols-2\">\n              <SelectorDropdown\n                variant={variant}\n                label=\"Subset\"\n                tooltip=\"Select which subset (config) of the dataset to use.\"\n                value={datasetSubset}\n                onChange={setDatasetSubset}\n                options={hfSubsets}\n                placeholder=\"Select a subset...\"\n              />\n              <SelectorDropdown\n                variant={variant}\n                label=\"Train Split\"\n                tooltip=\"Select which split to use for training.\"\n                value={datasetSplit}\n                onChange={setDatasetSplit}\n                options={hfSplits}\n                placeholder=\"Select a split...\"\n              />\n            </div>\n          ) : (\n            <>\n              <SelectorDropdown\n                variant={variant}\n                label=\"Subset\"\n                tooltip=\"Select which subset (config) of the dataset to use.\"\n                value={datasetSubset}\n                onChange={setDatasetSubset}\n                options={hfSubsets}\n                placeholder=\"Select a subset...\"\n              />\n              <SelectorDropdown\n                variant={variant}\n                label=\"Train Split\"\n                tooltip=\"Select which split to use for training.\"\n                value={datasetSplit}\n                onChange={setDatasetSplit}\n                options={hfSplits}\n                placeholder=\"Select a split...\"\n              />\n            </>\n          )}\n          <SelectorDropdown\n            variant={variant}\n            label=\"Evaluation Split\"\n            tooltip=\"Select which split to use for evaluation. 
None means no evaluation during training.\"\n            value={datasetEvalSplit}\n            onChange={setDatasetEvalSplit}\n            options={hfSplits}\n            placeholder=\"None\"\n            allowNone\n          />\n        </>\n      )}\n    </>\n  );\n}\n\nfunction SelectorDropdown({\n  variant,\n  label,\n  tooltip,\n  value,\n  onChange,\n  options,\n  placeholder,\n  allowNone = false,\n  disabled = false,\n}: {\n  variant: \"wizard\" | \"studio\";\n  label: string;\n  tooltip: string;\n  value: string | null;\n  onChange: (v: string | null) => void;\n  options: string[];\n  placeholder: string;\n  allowNone?: boolean;\n  disabled?: boolean;\n}) {\n  const selectValue =\n    value ?? (allowNone && !disabled ? \"_none\" : undefined);\n\n  if (variant === \"wizard\") {\n    return (\n      <Field>\n        <FieldLabel className=\"flex items-center gap-1.5\">\n          {label}\n          <Tooltip>\n            <TooltipTrigger asChild={true}>\n              <button\n                type=\"button\"\n                className=\"text-muted-foreground/50 hover:text-muted-foreground\"\n              >\n                <HugeiconsIcon\n                  icon={InformationCircleIcon}\n                  className=\"size-3.5\"\n                />\n              </button>\n            </TooltipTrigger>\n            <TooltipContent className=\"max-w-xs\">\n              {tooltip}\n            </TooltipContent>\n          </Tooltip>\n        </FieldLabel>\n        <Select\n          value={selectValue}\n          onValueChange={(v) => onChange(v === \"_none\" ? null : v)}\n          disabled={disabled}\n        >\n          <SelectTrigger className=\"w-full\">\n            <SelectValue placeholder={placeholder} />\n          </SelectTrigger>\n          <SelectContent>\n            {allowNone && (\n              <SelectItem value=\"_none\">None</SelectItem>\n            )}\n            {options.map((opt) => (\n              <SelectItem key={opt} value={opt}>\n                {opt}\n              </SelectItem>\n            ))}\n          </SelectContent>\n        </Select>\n      </Field>\n    );\n  }\n\n  return (\n    <div className=\"flex flex-col gap-1.5\">\n      <span className=\"flex items-center gap-1.5 text-xs font-medium text-muted-foreground\">\n        {label}\n        <Tooltip>\n          <TooltipTrigger asChild={true}>\n            <button\n              type=\"button\"\n              className=\"text-foreground/70 hover:text-foreground\"\n            >\n              <HugeiconsIcon\n                icon={InformationCircleIcon}\n                className=\"size-3\"\n              />\n            </button>\n          </TooltipTrigger>\n          <TooltipContent>\n            {tooltip}\n          </TooltipContent>\n        </Tooltip>\n      </span>\n      <Select\n        value={selectValue}\n        onValueChange={(v) => onChange(v === \"_none\" ? null : v)}\n        disabled={disabled}\n      >\n        <SelectTrigger className=\"w-full\">\n          <SelectValue placeholder={placeholder} />\n        </SelectTrigger>\n        <SelectContent>\n          {allowNone && (\n            <SelectItem value=\"_none\">None</SelectItem>\n          )}\n          {options.map((opt) => (\n            <SelectItem key={opt} value={opt}>\n              {opt}\n            </SelectItem>\n          ))}\n        </SelectContent>\n      </Select>\n    </div>\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/hooks/use-max-steps-epochs-toggle.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useState } from \"react\";\n\nconst PREV_MAX_STEPS_KEY = \"unsloth_prev_max_steps\";\nconst PREV_SAVE_STEPS_KEY = \"unsloth_prev_save_steps\";\nconst DEFAULT_MAX_STEPS = 60;\nconst DEFAULT_EPOCHS = 3;\n\nfunction readStoredNumber(key: string, fallback: number): number {\n  if (typeof window === \"undefined\") return fallback;\n  try {\n    const value = window.localStorage.getItem(key);\n    if (value === null) return fallback;\n    const parsed = Number(value);\n    return Number.isFinite(parsed) ? parsed : fallback;\n  } catch {\n    return fallback;\n  }\n}\n\nfunction writeStoredNumber(key: string, value: number): void {\n  if (typeof window === \"undefined\") return;\n  try {\n    window.localStorage.setItem(key, String(value));\n  } catch {\n    // Best effort only; ignore storage errors in restricted environments.\n  }\n}\n\nfunction normalizePrevMaxSteps(value: number): number {\n  return Number.isFinite(value) && value > 0 ? value : DEFAULT_MAX_STEPS;\n}\n\nfunction normalizePrevSaveSteps(value: number): number {\n  return Number.isFinite(value) && value >= 0 ? value : 0;\n}\n\ntype UseMaxStepsEpochsToggleParams = {\n  maxSteps: number;\n  epochs: number;\n  saveSteps: number;\n  setMaxSteps: (value: number) => void;\n  setEpochs: (value: number) => void;\n  setSaveSteps: (value: number) => void;\n  defaultEpochs?: number;\n};\n\ntype UseMaxStepsEpochsToggleResult = {\n  useEpochs: boolean;\n  toggleUseEpochs: () => void;\n};\n\nexport function useMaxStepsEpochsToggle({\n  maxSteps,\n  epochs,\n  saveSteps,\n  setMaxSteps,\n  setEpochs,\n  setSaveSteps,\n  defaultEpochs = DEFAULT_EPOCHS,\n}: UseMaxStepsEpochsToggleParams): UseMaxStepsEpochsToggleResult {\n  const useEpochs = maxSteps === 0;\n  const [prevMaxSteps, setPrevMaxSteps] = useState(() =>\n    normalizePrevMaxSteps(readStoredNumber(PREV_MAX_STEPS_KEY, DEFAULT_MAX_STEPS)),\n  );\n  const [prevSaveSteps, setPrevSaveSteps] = useState(() => {\n    if (maxSteps === 0 && saveSteps > 0) {\n      return normalizePrevSaveSteps(saveSteps);\n    }\n    return normalizePrevSaveSteps(readStoredNumber(PREV_SAVE_STEPS_KEY, 0));\n  });\n\n  useEffect(() => {\n    if (maxSteps > 0) {\n      const normalized = normalizePrevMaxSteps(maxSteps);\n      setPrevMaxSteps(normalized);\n      writeStoredNumber(PREV_MAX_STEPS_KEY, normalized);\n    }\n  }, [maxSteps]);\n\n  useEffect(() => {\n    if (!useEpochs) {\n      const normalized = normalizePrevSaveSteps(saveSteps);\n      setPrevSaveSteps(normalized);\n      writeStoredNumber(PREV_SAVE_STEPS_KEY, normalized);\n    }\n  }, [saveSteps, useEpochs]);\n\n  const toggleUseEpochs = useCallback(() => {\n    if (useEpochs) {\n      setMaxSteps(normalizePrevMaxSteps(prevMaxSteps));\n      setSaveSteps(normalizePrevSaveSteps(prevSaveSteps));\n      return;\n    }\n\n    setMaxSteps(0);\n    setEpochs(epochs || defaultEpochs);\n  }, [\n    defaultEpochs,\n    epochs,\n    prevMaxSteps,\n    prevSaveSteps,\n    setEpochs,\n    setMaxSteps,\n    setSaveSteps,\n    useEpochs,\n  ]);\n\n  return { useEpochs, toggleUseEpochs };\n}\n"
  },
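A small wiring sketch for the toggle above; in the app these values come from the training config store, so the local `useState` here is purely illustrative:

```ts
import { useState } from "react";
import { useMaxStepsEpochsToggle } from "@/features/training/hooks/use-max-steps-epochs-toggle";

export function useStepsOrEpochs() {
  const [maxSteps, setMaxSteps] = useState(60);
  const [epochs, setEpochs] = useState(3);
  const [saveSteps, setSaveSteps] = useState(0);

  // useEpochs is derived from maxSteps === 0; toggling back restores the last
  // non-zero maxSteps/saveSteps pair remembered in localStorage.
  const { useEpochs, toggleUseEpochs } = useMaxStepsEpochsToggle({
    maxSteps,
    epochs,
    saveSteps,
    setMaxSteps,
    setEpochs,
    setSaveSteps,
  });

  return { maxSteps, epochs, saveSteps, useEpochs, toggleUseEpochs };
}
```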
  {
    "path": "studio/frontend/src/features/training/hooks/use-training-actions.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback } from \"react\";\nimport { checkDatasetFormat } from \"../api/datasets-api\";\nimport { buildTrainingStartPayload } from \"../api/mappers\";\nimport { startTraining, stopTraining, resetTraining } from \"../api/train-api\";\nimport { syncTrainingRuntimeFromBackend } from \"../lib/sync-runtime\";\nimport { validateTrainingConfig } from \"../lib/validation\";\nimport { useDatasetPreviewDialogStore } from \"../stores/dataset-preview-dialog-store\";\nimport { useTrainingConfigStore } from \"../stores/training-config-store\";\nimport { useTrainingRuntimeStore } from \"../stores/training-runtime-store\";\nimport type { TrainingConfigState } from \"../types/config\";\nimport { toast } from \"sonner\";\n\n/** Chatml → format-specific role remap (only for formats that differ from chatml). */\nconst ROLE_REMAP: Record<string, Record<string, string>> = {\n  alpaca: { user: \"instruction\", system: \"input\", assistant: \"output\" },\n  sharegpt: { user: \"human\", assistant: \"gpt\", system: \"system\" },\n};\n\nfunction normalizeTrainingStartError(message: string): string {\n  const normalized = message.toLowerCase();\n  const isLegacyDatasetScriptError =\n    normalized.includes(\"failed to check dataset format\") &&\n    normalized.includes(\"dataset scripts are no longer supported\");\n\n  if (isLegacyDatasetScriptError) {\n    return \"This Hub dataset relies on a legacy custom script and isn’t supported in this training flow.\";\n  }\n\n  return message;\n}\n\nexport function useTrainingActions() {\n  const isStarting = useTrainingRuntimeStore((state) => state.isStarting);\n  const startError = useTrainingRuntimeStore((state) => state.startError);\n\n  const startTrainingRun = useCallback(async (): Promise<boolean> => {\n    const config = useTrainingConfigStore.getState();\n    const runtimeStore = useTrainingRuntimeStore.getState();\n    const dialogStore = useDatasetPreviewDialogStore.getState();\n\n    runtimeStore.setStartError(null);\n    const validation = validateTrainingConfig(config);\n    if (!validation.ok) {\n      runtimeStore.setStartError(validation.message);\n      return false;\n    }\n\n    runtimeStore.setStarting(true);\n\n    try {\n      const datasetName = getDatasetName(config);\n      let isVlm = config.isVisionModel && config.isDatasetImage === true;\n\n      if (datasetName) {\n        const check = await checkDatasetFormat({\n          datasetName,\n          hfToken: config.hfToken.trim() || null,\n          subset: config.datasetSubset,\n          split: config.datasetSplit,\n          isVlm,\n        });\n\n        // Backend auto-detects image/audio from dataset content.\n        // Sync these flags into the store so buildTrainingStartPayload picks them up.\n        const isAudio = !!check.is_audio;\n        const isImage = !!check.is_image;\n\n        if (isImage && config.isVisionModel) {\n          isVlm = true;\n        }\n        if (isImage !== config.isDatasetImage || isAudio !== config.isDatasetAudio) {\n          useTrainingConfigStore.setState({\n            isDatasetImage: isImage,\n            isDatasetAudio: isAudio,\n          });\n        }\n\n        const needsReview = check.requires_manual_mapping || check.detected_format === \"custom_heuristic\";\n        if (needsReview && !hasManualMapping(config, isVlm, isAudio)) {\n          // Pre-fill from 
suggested_mapping or VLM detected columns\n          const hint: Record<string, string> = {};\n          if (check.suggested_mapping) {\n            const table = ROLE_REMAP[config.datasetFormat];\n            for (const [col, role] of Object.entries(check.suggested_mapping)) {\n              hint[col] = table ? (table[role] ?? role) : role;\n            }\n          } else if (isAudio) {\n            if (check.detected_audio_column) hint[check.detected_audio_column] = \"audio\";\n            if (check.detected_text_column) hint[check.detected_text_column] = \"text\";\n            if (check.detected_speaker_column) hint[check.detected_speaker_column] = \"speaker_id\";\n          } else if (isVlm) {\n            if (check.detected_image_column) hint[check.detected_image_column] = \"image\";\n            if (check.detected_text_column) hint[check.detected_text_column] = \"text\";\n          }\n\n          if (Object.keys(hint).length > 0) {\n            useTrainingConfigStore.getState().setDatasetManualMapping(hint);\n          }\n\n          runtimeStore.setStarting(false);\n          dialogStore.openMapping(check);\n          return false;\n        }\n      }\n\n      // Abort if cancel was requested during dataset check\n      if (useTrainingRuntimeStore.getState().stopRequested) {\n        runtimeStore.setStarting(false);\n        return false;\n      }\n\n      // Re-read config after potential store updates from dataset check\n      const payload = buildTrainingStartPayload(useTrainingConfigStore.getState());\n      const response = await startTraining(payload);\n\n      if (response.status === \"error\") {\n        const rawMessage = response.error || response.message;\n        const safeMessage = normalizeTrainingStartError(rawMessage);\n        runtimeStore.setStartError(safeMessage);\n        runtimeStore.setStarting(false);\n        return false;\n      }\n\n      runtimeStore.setStartQueued(response.job_id, response.message);\n      await syncTrainingRuntimeFromBackend();\n      return true;\n    } catch (error) {\n      const rawMessage =\n        error instanceof Error ? error.message : \"Failed to start training\";\n      const safeMessage = normalizeTrainingStartError(rawMessage);\n      runtimeStore.setStartError(safeMessage);\n      runtimeStore.setStarting(false);\n      return false;\n    }\n  }, []);\n\n  const stopTrainingRun = useCallback(async (save = true): Promise<boolean> => {\n    const runtimeStore = useTrainingRuntimeStore.getState();\n    runtimeStore.setStartError(null);\n\n    try {\n      await stopTraining(save);\n      await syncTrainingRuntimeFromBackend();\n      return true;\n    } catch (error) {\n      const message =\n        error instanceof Error ? error.message : \"Failed to stop training\";\n      runtimeStore.setRuntimeError(message);\n      return false;\n    }\n  }, []);\n\n  const dismissTrainingRun = useCallback(async (): Promise<void> => {\n    try {\n      await resetTraining();\n      useTrainingRuntimeStore.getState().resetRuntime();\n    } catch (error) {\n      const message =\n        error instanceof Error\n          ? 
error.message\n          : \"Stop training first, then return to configuration.\";\n      toast.error(\"Training still active\", {\n        description: message,\n      });\n      await syncTrainingRuntimeFromBackend();\n    }\n  }, []);\n\n  return {\n    isStarting,\n    startError,\n    startTrainingRun,\n    stopTrainingRun,\n    dismissTrainingRun,\n  };\n}\n\nfunction getDatasetName(config: TrainingConfigState): string | null {\n  return config.datasetSource === \"huggingface\"\n    ? config.dataset\n    : config.uploadedFile;\n}\n\nfunction hasManualMapping(config: TrainingConfigState, isVlm = false, isAudio = false): boolean {\n  const mapping = config.datasetManualMapping;\n  const roles = new Set(Object.values(mapping));\n  if (isAudio) return roles.has(\"audio\") && roles.has(\"text\");\n  if (isVlm) return roles.has(\"image\") && roles.has(\"text\");\n  const fmt = config.datasetFormat;\n  if (fmt === \"alpaca\") return roles.has(\"instruction\") && roles.has(\"output\");\n  if (fmt === \"sharegpt\") return roles.has(\"human\") && roles.has(\"gpt\");\n  return roles.has(\"user\") && roles.has(\"assistant\");\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/hooks/use-training-runtime-lifecycle.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { hasAuthToken } from \"@/features/auth\";\nimport { useEffect } from \"react\";\nimport {\n  getTrainingMetrics,\n  getTrainingStatus,\n  isAbortError,\n  streamTrainingProgress,\n} from \"../api/train-api\";\nimport { useTrainingRuntimeStore } from \"../stores/training-runtime-store\";\nimport type { TrainingRuntimeStore } from \"../types/runtime\";\n\nconst STATUS_POLL_INTERVAL_MS = 3000;\nconst METRICS_POLL_INTERVAL_MS = 5000;\nconst STREAM_RECONNECT_DELAY_MS = 1500;\n\nfunction shouldUseLiveSync(state: TrainingRuntimeStore): boolean {\n  return state.isTrainingRunning || state.phase === \"training\";\n}\n\nexport function useTrainingRuntimeLifecycle(): void {\n  useEffect(() => {\n    let disposed = false;\n    let openingStream = false;\n    let streamController: AbortController | null = null;\n    let reconnectTimer: ReturnType<typeof setTimeout> | null = null;\n\n    const runtimeStore = useTrainingRuntimeStore;\n\n    const clearReconnect = () => {\n      if (reconnectTimer) {\n        clearTimeout(reconnectTimer);\n        reconnectTimer = null;\n      }\n    };\n\n    const stopStream = () => {\n      clearReconnect();\n      if (streamController) {\n        streamController.abort();\n        streamController = null;\n      }\n      runtimeStore.getState().setSseConnected(false);\n    };\n\n    const pollMetrics = async () => {\n      if (!hasAuthToken()) return;\n      const gen = runtimeStore.getState().resetGeneration;\n      try {\n        const metrics = await getTrainingMetrics();\n        if (disposed || runtimeStore.getState().resetGeneration !== gen) {\n          return;\n        }\n        runtimeStore.getState().applyMetrics(metrics);\n      } catch (error) {\n        if (!isAbortError(error) && !disposed && hasAuthToken()) {\n          runtimeStore.getState().setSseConnected(false);\n        }\n      }\n    };\n\n    const pollStatus = async () => {\n      if (!hasAuthToken()) return;\n      const gen = runtimeStore.getState().resetGeneration;\n      try {\n        const status = await getTrainingStatus();\n        if (disposed || runtimeStore.getState().resetGeneration !== gen) {\n          return;\n        }\n\n        runtimeStore.getState().applyStatus(status);\n\n        const nextState = runtimeStore.getState();\n        if (shouldUseLiveSync(nextState)) {\n          void ensureStream();\n        } else {\n          stopStream();\n        }\n      } catch (error) {\n        if (!isAbortError(error) && !disposed && hasAuthToken()) {\n          runtimeStore.getState().setSseConnected(false);\n        }\n      }\n    };\n\n    const ensureStream = async () => {\n      const state = runtimeStore.getState();\n      if (\n        disposed ||\n        openingStream ||\n        streamController ||\n        !shouldUseLiveSync(state)\n      ) {\n        return;\n      }\n\n      clearReconnect();\n      openingStream = true;\n      const controller = new AbortController();\n      streamController = controller;\n\n      try {\n        await streamTrainingProgress({\n          signal: controller.signal,\n          lastEventId: state.lastEventId,\n          onOpen: () => {\n            runtimeStore.getState().setSseConnected(true);\n          },\n          onEvent: (event) => {\n            const liveStore = runtimeStore.getState();\n            if (typeof event.id === \"number\") {\n              
liveStore.setLastEventId(event.id);\n            }\n\n            liveStore.applyProgress(event.payload, event.id ?? undefined);\n\n            if (event.event === \"complete\") {\n              void pollStatus();\n              void pollMetrics();\n              stopStream();\n            }\n\n            if (event.event === \"error\") {\n              liveStore.setRuntimeError(\"Training stream error\");\n              stopStream();\n            }\n          },\n        });\n      } catch (error) {\n        if (!disposed && !controller.signal.aborted && !isAbortError(error)) {\n          runtimeStore.getState().setSseConnected(false);\n        }\n      } finally {\n        openingStream = false;\n        if (streamController === controller) {\n          streamController = null;\n        }\n        runtimeStore.getState().setSseConnected(false);\n\n        if (!disposed && !controller.signal.aborted) {\n          const liveState = runtimeStore.getState();\n          if (shouldUseLiveSync(liveState)) {\n            reconnectTimer = setTimeout(() => {\n              void ensureStream();\n            }, STREAM_RECONNECT_DELAY_MS);\n          }\n        }\n      }\n    };\n\n    const hydrate = async () => {\n      runtimeStore.getState().setHydrating(true);\n      try {\n        await Promise.all([pollStatus(), pollMetrics()]);\n      } finally {\n        if (!disposed) {\n          runtimeStore.getState().setHydrating(false);\n          runtimeStore.getState().setHasHydrated(true);\n        }\n      }\n    };\n\n    void hydrate();\n\n    const statusTimer = setInterval(() => {\n      void pollStatus();\n    }, STATUS_POLL_INTERVAL_MS);\n\n    const metricsTimer = setInterval(() => {\n      const state = runtimeStore.getState();\n      if (shouldUseLiveSync(state) || state.currentStep > 0) {\n        void pollMetrics();\n      }\n    }, METRICS_POLL_INTERVAL_MS);\n\n    return () => {\n      disposed = true;\n      clearInterval(statusTimer);\n      clearInterval(metricsTimer);\n      stopStream();\n    };\n  }, []);\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { useTrainingConfigStore } from \"./stores/training-config-store\";\nexport {\n  shouldShowTrainingView,\n  useTrainingRuntimeStore,\n} from \"./stores/training-runtime-store\";\nexport { useTrainingActions } from \"./hooks/use-training-actions\";\nexport { useTrainingRuntimeLifecycle } from \"./hooks/use-training-runtime-lifecycle\";\nexport { useMaxStepsEpochsToggle } from \"./hooks/use-max-steps-epochs-toggle\";\nexport { HfDatasetSubsetSplitSelectors } from \"./components/hf-dataset-subset-split-selectors\";\nexport { useDatasetPreviewDialogStore } from \"./stores/dataset-preview-dialog-store\";\nexport { uploadTrainingDataset } from \"./api/datasets-api\";\nexport { listLocalModels } from \"./api/models-api\";\nexport type { LocalModelInfo } from \"./api/models-api\";\nexport type { TrainingPhase } from \"./types/runtime\";\nexport { parseYamlConfig, serializeConfigToYaml } from \"./lib/yaml-config\";\nexport { validateTrainingConfig } from \"./lib/validation\";\n"
  },
  {
    "path": "studio/frontend/src/features/training/stores/dataset-preview-dialog-store.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\nimport type { CheckFormatResponse } from \"../types/datasets\";\n\nexport type DatasetPreviewDialogMode = \"preview\" | \"mapping\";\n\ntype DatasetPreviewDialogState = {\n  open: boolean;\n  mode: DatasetPreviewDialogMode;\n  initialData: CheckFormatResponse | null;\n};\n\ntype DatasetPreviewDialogActions = {\n  openPreview: () => void;\n  openMapping: (data: CheckFormatResponse) => void;\n  close: () => void;\n};\n\nconst initialState: DatasetPreviewDialogState = {\n  open: false,\n  mode: \"preview\",\n  initialData: null,\n};\n\nexport const useDatasetPreviewDialogStore = create<\n  DatasetPreviewDialogState & DatasetPreviewDialogActions\n>()((set) => ({\n  ...initialState,\n\n  openPreview: () => set({ open: true, mode: \"preview\", initialData: null }),\n  openMapping: (data) => set({ open: true, mode: \"mapping\", initialData: data }),\n  close: () => set({ open: false, initialData: null, mode: \"preview\" }),\n}));\n\n"
  },
  {
    "path": "studio/frontend/src/features/training/stores/training-config-store.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { DEFAULT_HYPERPARAMS, STEPS } from \"@/config/training\";\nimport { authFetch } from \"@/features/auth\";\nimport type { ModelType, StepNumber, TrainingMethod } from \"@/types/training\";\nimport { create } from \"zustand\";\nimport { persist } from \"zustand/middleware\";\nimport { checkDatasetFormat } from \"../api/datasets-api\";\nimport { checkVisionModel, getModelConfig } from \"../api/models-api\";\nimport { mapBackendModelConfigToTrainingPatch } from \"../lib/model-defaults\";\nimport type { BackendModelConfig } from \"../api/models-api\";\nimport type { TrainingConfigState, TrainingConfigStore } from \"../types/config\";\n\nconst MIN_STEP: StepNumber = 1;\nconst MAX_STEP: StepNumber = STEPS.length as StepNumber;\n\n/**\n * Auto-select LoRA (16-bit) vs QLoRA (4-bit) based on model size and GPU memory.\n *\n * Rule: if model_size_gb * 1.5 * context_scale fits in free VRAM, use \"lora\" (16-bit).\n * Otherwise use \"qlora\" (4-bit).\n *\n * Context scale: <=8192 = 1.0, >8192 = 1.7, >=16384 = 2.0, >=32768 = 4.0\n */\nasync function autoSelectTrainingMethod(\n  modelSizeBytes: number,\n  contextLength: number,\n): Promise<TrainingMethod | null> {\n  try {\n    const res = await authFetch(\"/api/system/hardware\");\n    if (!res.ok) return null;\n    const data = await res.json();\n    const freeGb: number | null = data?.gpu?.vram_free_gb ?? null;\n    if (freeGb == null) return null;\n\n    const modelSizeGb = modelSizeBytes / (1024 ** 3);\n\n    let contextScale = 1.0;\n    if (contextLength >= 32768) contextScale = 4.0;\n    else if (contextLength >= 16384) contextScale = 2.0;\n    else if (contextLength > 8192) contextScale = 1.7;\n\n    const estimatedUsage = modelSizeGb * 1.5 * contextScale;\n    return estimatedUsage <= freeGb ? 
\"lora\" : \"qlora\";\n  } catch {\n    return null;\n  }\n}\n\nfunction emptyManualMapping(): TrainingConfigState[\"datasetManualMapping\"] {\n  return {};\n}\n\nconst initialState: TrainingConfigState = {\n  currentStep: MIN_STEP,\n  modelType: null,\n  selectedModel: null,\n  trainingMethod: \"qlora\",\n  hfToken: \"\",\n  datasetSource: \"huggingface\",\n  datasetFormat: \"auto\",\n  dataset: null,\n  datasetSubset: null,\n  datasetSplit: null,\n  datasetEvalSplit: null,\n  datasetManualMapping: emptyManualMapping(),\n  datasetSystemPrompt: \"\",\n  datasetUserTemplate: \"\",\n  datasetAssistantTemplate: \"\",\n  datasetLabelMapping: {},\n  datasetAdvisorNotification: null,\n  datasetSliceStart: null,\n  datasetSliceEnd: null,\n  uploadedFile: null,\n  uploadedEvalFile: null,\n  isCheckingVision: false,\n  isVisionModel: false,\n  isEmbeddingModel: false,\n  isAudioModel: false,\n  isLoadingModelDefaults: false,\n  modelDefaultsError: null,\n  modelDefaultsAppliedFor: null,\n  isCheckingDataset: false,\n  isDatasetImage: null,\n  isDatasetAudio: false,\n  maxPositionEmbeddings: null,\n  ...DEFAULT_HYPERPARAMS,\n};\n\n// AbortController for in-flight dataset multimodal checks.\nlet _datasetCheckController: AbortController | null = null;\n\n// AbortController for in-flight model default loads.\nlet _modelConfigController: AbortController | null = null;\n\n// Track whether the user has manually toggled trainOnCompletions\n// since the last auto-set (model load or dataset change).\nlet _trainOnCompletionsManuallySet = false;\n\nconst NON_PERSISTED_STATE_KEYS: ReadonlySet<keyof TrainingConfigState> = new Set([\n  \"modelType\",\n  \"isCheckingVision\",\n  \"isEmbeddingModel\",\n  \"isAudioModel\",\n  \"isLoadingModelDefaults\",\n  \"modelDefaultsError\",\n  \"modelDefaultsAppliedFor\",\n  \"isCheckingDataset\",\n  \"isDatasetImage\",\n  \"isDatasetAudio\",\n  \"trainOnCompletions\",\n  \"maxPositionEmbeddings\",\n]);\n\nfunction partializePersistedState(\n  state: TrainingConfigStore,\n): Partial<TrainingConfigStore> {\n  return Object.fromEntries(\n    Object.entries(state).filter(([key]) => {\n      const stateKey = key as keyof TrainingConfigState;\n      return !NON_PERSISTED_STATE_KEYS.has(stateKey);\n    }),\n  ) as Partial<TrainingConfigStore>;\n}\n\nfunction clampStep(step: number): StepNumber {\n  return Math.min(MAX_STEP, Math.max(MIN_STEP, step)) as StepNumber;\n}\n\nfunction canProceedForStep(state: TrainingConfigState): boolean {\n  switch (state.currentStep) {\n    case 1:\n      return state.modelType !== null;\n    case 2:\n      return state.selectedModel !== null;\n    case 3:\n      return state.datasetSource === \"upload\"\n        ? 
state.uploadedFile !== null\n        : state.dataset !== null;\n    case 4:\n    case 5:\n      return true;\n    default:\n      return false;\n  }\n}\n\nexport const useTrainingConfigStore = create<TrainingConfigStore>()(\n  persist(\n    (set, get) => {\n      const loadAndApplyModelDefaults = (modelName: string) => {\n        _modelConfigController?.abort();\n        const controller = new AbortController();\n        _modelConfigController = controller;\n        set({\n          isLoadingModelDefaults: true,\n          isCheckingVision: true,\n          modelDefaultsError: null,\n        });\n\n        void getModelConfig(modelName, controller.signal, get().hfToken || undefined)\n          .then((modelDetails) => {\n            if (controller.signal.aborted) return;\n            if (get().selectedModel !== modelName) return;\n\n            _trainOnCompletionsManuallySet = false;\n            const patch = mapBackendModelConfigToTrainingPatch(modelDetails.config);\n\n            // If vision model + image dataset already known, override\n            // trainOnCompletions to false regardless of backend default.\n            if (modelDetails.is_vision && get().isDatasetImage === true) {\n              patch.trainOnCompletions = false;\n            }\n\n            const isAudio = !!modelDetails.is_audio;\n            // Pure audio model → always uncheck trainOnCompletions.\n            if (isAudio && !modelDetails.is_vision) {\n              patch.trainOnCompletions = false;\n            }\n            // Audio-capable vision model (e.g. gemma3n) + audio dataset → uncheck.\n            if (isAudio && modelDetails.is_vision && get().isDatasetAudio) {\n              patch.trainOnCompletions = false;\n            }\n\n            // Use backend-provided model_type when available, otherwise\n            // infer from capability flags.\n            const isEmbedding = !!modelDetails.is_embedding;\n            const inferredModelType: ModelType = modelDetails.model_type\n              ?? (isEmbedding ? \"embeddings\" : modelDetails.is_vision ? \"vision\" : modelDetails.is_audio ? \"audio\" : \"text\");\n\n            // Auto-select training method based on model size vs GPU memory.\n            // If model_size * 1.5 * context_scale fits in free VRAM, use LoRA 16-bit.\n            // Otherwise use QLoRA 4-bit.\n            const modelSizeBytes = modelDetails.model_size_bytes;\n            if (modelSizeBytes && modelSizeBytes > 0) {\n              void autoSelectTrainingMethod(modelSizeBytes, patch.contextLength ?? get().contextLength)\n                .then((method) => {\n                  if (get().selectedModel !== modelName) return;\n                  if (method) set({ trainingMethod: method });\n                });\n            }\n\n            set({\n              ...patch,\n              modelType: inferredModelType,\n              isVisionModel: modelDetails.is_vision,\n              isEmbeddingModel: isEmbedding,\n              isAudioModel: isAudio,\n              isLoadingModelDefaults: false,\n              isCheckingVision: false,\n              modelDefaultsError: null,\n              modelDefaultsAppliedFor: modelName,\n              maxPositionEmbeddings: modelDetails.max_position_embeddings ?? 
null,\n            });\n          })\n          .catch((error) => {\n            if (controller.signal.aborted) return;\n            if (get().selectedModel !== modelName) return;\n\n            set({\n              isLoadingModelDefaults: false,\n              isEmbeddingModel: false,\n              isAudioModel: false,\n              modelDefaultsError:\n                error instanceof Error\n                  ? error.message\n                  : \"Failed to load model defaults\",\n            });\n\n            // Fallback vision check if config endpoint fails.\n            void checkVisionModel(modelName)\n              .then((isVision) => {\n                if (get().selectedModel !== modelName) return;\n                set({\n                  modelType: isVision ? \"vision\" : \"text\",\n                  isVisionModel: isVision,\n                  isEmbeddingModel: false,\n                  isAudioModel: false,\n                  isCheckingVision: false,\n                });\n              })\n              .catch(() => {\n                if (get().selectedModel !== modelName) return;\n                set({ isCheckingVision: false, isEmbeddingModel: false, isAudioModel: false });\n              });\n          });\n      };\n\n      const runDatasetCheck = (datasetName: string, split: string) => {\n        _datasetCheckController?.abort();\n        const controller = new AbortController();\n        _datasetCheckController = controller;\n        set({ isCheckingDataset: true });\n\n        const state = get();\n        checkDatasetFormat({\n          datasetName,\n          hfToken: state.hfToken.trim() || null,\n          subset: state.datasetSubset,\n          split,\n          isVlm: state.isVisionModel,\n        })\n          .then((res) => {\n            if (controller.signal.aborted) return;\n            const isImage = !!res.is_image;\n            const isAudio = !!res.is_audio;\n            const updates: Record<string, unknown> = {\n              isDatasetImage: isImage,\n              isDatasetAudio: isAudio,\n              isCheckingDataset: false,\n            };\n            if (!_trainOnCompletionsManuallySet) {\n              const { isVisionModel, isAudioModel } = get();\n              if (isVisionModel && isImage) {\n                updates.trainOnCompletions = false;\n              }\n              // Pure audio model → always uncheck regardless of dataset.\n              if (isAudioModel && !isVisionModel) {\n                updates.trainOnCompletions = false;\n              }\n              // Audio-capable vision model (e.g. 
gemma3n) + audio dataset → uncheck.\n              if (isAudioModel && isVisionModel && isAudio) {\n                updates.trainOnCompletions = false;\n              }\n            }\n            set(updates);\n          })\n          .catch(() => {\n            if (controller.signal.aborted) return;\n            set({ isDatasetImage: null, isCheckingDataset: false });\n          });\n      };\n\n      const resetDatasetState = (): Partial<TrainingConfigStore> => ({\n        datasetSubset: null,\n        datasetSplit: null,\n        datasetEvalSplit: null,\n        datasetManualMapping: emptyManualMapping(),\n        datasetSystemPrompt: \"\",\n        datasetUserTemplate: \"\",\n        datasetAssistantTemplate: \"\",\n        datasetLabelMapping: {},\n        datasetAdvisorNotification: null,\n        datasetSliceStart: null,\n        datasetSliceEnd: null,\n        uploadedEvalFile: null,\n        isDatasetImage: null,\n        isDatasetAudio: false,\n        isCheckingDataset: false,\n      });\n\n      return {\n        ...initialState,\n        setStep: (step) => set({ currentStep: step }),\n        nextStep: () => set({ currentStep: clampStep(get().currentStep + 1) }),\n        prevStep: () => set({ currentStep: clampStep(get().currentStep - 1) }),\n        setModelType: (modelType) => {\n          _modelConfigController?.abort();\n          _modelConfigController = null;\n\n          set({\n            modelType,\n            selectedModel: null,\n            isCheckingVision: false,\n            isVisionModel: false,\n            isEmbeddingModel: false,\n            isAudioModel: false,\n            isDatasetAudio: false,\n            isLoadingModelDefaults: false,\n            modelDefaultsError: null,\n            modelDefaultsAppliedFor: null,\n          });\n        },\n        setSelectedModel: (selectedModel) => {\n          const previousModel = get().selectedModel;\n          set({ selectedModel, modelDefaultsError: null });\n\n          if (!selectedModel) {\n            _modelConfigController?.abort();\n            _modelConfigController = null;\n            set({\n              isCheckingVision: false,\n              isVisionModel: false,\n              isEmbeddingModel: false,\n              isAudioModel: false,\n              isDatasetAudio: false,\n              isLoadingModelDefaults: false,\n              modelDefaultsError: null,\n              modelDefaultsAppliedFor: null,\n            });\n            return;\n          }\n\n          const shouldLoadDefaults =\n            selectedModel !== previousModel ||\n            get().modelDefaultsAppliedFor !== selectedModel;\n          if (shouldLoadDefaults) {\n            void loadAndApplyModelDefaults(selectedModel);\n          }\n        },\n        ensureModelDefaultsLoaded: () => {\n          const state = get();\n          if (!state.selectedModel) return;\n          if (state.isLoadingModelDefaults) return;\n          if (state.modelDefaultsAppliedFor === state.selectedModel) return;\n          void loadAndApplyModelDefaults(state.selectedModel);\n        },\n        setTrainingMethod: (trainingMethod) => set({ trainingMethod }),\n        setHfToken: (hfToken) =>\n          set({ hfToken: hfToken.trim().replace(/^[\"']+|[\"']+$/g, \"\") }),\n        setDatasetSource: (datasetSource) => set({ datasetSource }),\n        selectHfDataset: (dataset) => {\n          _datasetCheckController?.abort();\n          _datasetCheckController = null;\n          _trainOnCompletionsManuallySet = false;\n          set({\n    
        datasetSource: \"huggingface\",\n            dataset,\n            uploadedFile: null,\n            ...resetDatasetState(),\n          });\n        },\n        selectLocalDataset: (uploadedFile) => {\n          _datasetCheckController?.abort();\n          _datasetCheckController = null;\n          _trainOnCompletionsManuallySet = false;\n          set({\n            datasetSource: \"upload\",\n            dataset: null,\n            uploadedFile,\n            ...resetDatasetState(),\n          });\n          if (uploadedFile) {\n            runDatasetCheck(uploadedFile, \"train\");\n          }\n        },\n        setDatasetFormat: (datasetFormat) => set({ datasetFormat }),\n        setDataset: (dataset) => {\n          _datasetCheckController?.abort();\n          _datasetCheckController = null;\n          _trainOnCompletionsManuallySet = false;\n          set({\n            dataset,\n            datasetSubset: null,\n            datasetSplit: null,\n            datasetEvalSplit: null,\n            datasetManualMapping: emptyManualMapping(),\n            datasetSliceStart: null,\n            datasetSliceEnd: null,\n            isDatasetImage: null,\n            isDatasetAudio: false,\n            isCheckingDataset: false,\n          });\n        },\n        setDatasetSubset: (datasetSubset) => {\n          _datasetCheckController?.abort();\n          _datasetCheckController = null;\n          _trainOnCompletionsManuallySet = false;\n          set({\n            datasetSubset,\n            datasetSplit: null,\n            datasetEvalSplit: null,\n            datasetManualMapping: emptyManualMapping(),\n            isDatasetImage: null,\n            isDatasetAudio: false,\n            isCheckingDataset: false,\n          });\n        },\n        setDatasetSplit: (datasetSplit) => {\n          set({\n            datasetSplit,\n            datasetManualMapping: emptyManualMapping(),\n            isDatasetImage: null,\n            isDatasetAudio: false,\n            isCheckingDataset: false,\n          });\n\n          const state = get();\n          const datasetName =\n            state.datasetSource === \"huggingface\"\n              ? state.dataset\n              : state.uploadedFile;\n          if (!datasetName) return;\n\n          runDatasetCheck(datasetName, datasetSplit || \"train\");\n        },\n        ensureDatasetChecked: () => {\n          const state = get();\n          if (state.isCheckingDataset) return;\n          if (state.isDatasetImage !== null) return;\n\n          const datasetName =\n            state.datasetSource === \"huggingface\"\n              ? state.dataset\n              : state.uploadedFile;\n          if (!datasetName) return;\n\n          const split = state.datasetSplit || \"train\";\n          runDatasetCheck(datasetName, split);\n        },\n        setDatasetEvalSplit: (datasetEvalSplit) => {\n          set({\n            datasetEvalSplit,\n            evalSteps: datasetEvalSplit ? 0.1 : 0,\n          });\n        },\n        setDatasetManualMapping: (datasetManualMapping) =>\n          set({ datasetManualMapping }),\n        setDatasetAdvisorFields: (fields) =>\n          set({\n            datasetSystemPrompt: fields.systemPrompt ?? get().datasetSystemPrompt,\n            datasetUserTemplate: \"\",  // templates no longer used\n            datasetAssistantTemplate: \"\",  // templates no longer used\n            datasetLabelMapping: fields.labelMapping ?? 
get().datasetLabelMapping,\n            datasetAdvisorNotification: fields.notification !== undefined ? fields.notification : get().datasetAdvisorNotification,\n          }),\n        clearDatasetAdvisorFields: () =>\n          set({\n            datasetSystemPrompt: \"\",\n            datasetUserTemplate: \"\",\n            datasetAssistantTemplate: \"\",\n            datasetLabelMapping: {},\n            datasetAdvisorNotification: null,\n          }),\n        setDatasetSliceStart: (datasetSliceStart) => set({ datasetSliceStart }),\n        setDatasetSliceEnd: (datasetSliceEnd) => set({ datasetSliceEnd }),\n        setUploadedFile: (uploadedFile) => {\n          _datasetCheckController?.abort();\n          _datasetCheckController = null;\n          _trainOnCompletionsManuallySet = false;\n          set({\n            uploadedFile,\n            datasetSubset: null,\n            datasetSplit: null,\n            datasetEvalSplit: null,\n            datasetManualMapping: emptyManualMapping(),\n            datasetSliceStart: null,\n            datasetSliceEnd: null,\n            uploadedEvalFile: null,\n            isDatasetImage: null,\n            isDatasetAudio: false,\n            isCheckingDataset: false,\n          });\n        },\n        setUploadedEvalFile: (uploadedEvalFile) => set({\n          uploadedEvalFile,\n          evalSteps: uploadedEvalFile ? 0.1 : 0,\n        }),\n        setEpochs: (epochs) => set({ epochs }),\n        setContextLength: (contextLength) => set({ contextLength }),\n        setLearningRate: (learningRate) => set({ learningRate }),\n        setOptimizerType: (optimizerType) => set({ optimizerType }),\n        setLrSchedulerType: (lrSchedulerType) => set({ lrSchedulerType }),\n        setLoraRank: (loraRank) => set({ loraRank }),\n        setLoraAlpha: (loraAlpha) => set({ loraAlpha }),\n        setLoraDropout: (loraDropout) => set({ loraDropout }),\n        setLoraVariant: (loraVariant) => set({ loraVariant }),\n        setBatchSize: (batchSize) => set({ batchSize }),\n        setGradientAccumulation: (gradientAccumulation) =>\n          set({ gradientAccumulation }),\n        setWeightDecay: (weightDecay) => set({ weightDecay }),\n        setWarmupSteps: (warmupSteps) => set({ warmupSteps }),\n        setMaxSteps: (maxSteps) => set({ maxSteps }),\n        setSaveSteps: (saveSteps) => set({ saveSteps }),\n        setEvalSteps: (evalSteps) => set({ evalSteps }),\n        setPacking: (packing) => set({ packing }),\n        setTrainOnCompletions: (trainOnCompletions) => {\n          _trainOnCompletionsManuallySet = true;\n          set({ trainOnCompletions });\n        },\n        setGradientCheckpointing: (gradientCheckpointing) =>\n          set({ gradientCheckpointing }),\n        setRandomSeed: (randomSeed) => set({ randomSeed }),\n        setEnableWandb: (enableWandb) => set({ enableWandb }),\n        setWandbToken: (wandbToken) => set({ wandbToken }),\n        setWandbProject: (wandbProject) => set({ wandbProject }),\n        setEnableTensorboard: (enableTensorboard) => set({ enableTensorboard }),\n        setTensorboardDir: (tensorboardDir) => set({ tensorboardDir }),\n        setLogFrequency: (logFrequency) => set({ logFrequency }),\n        setFinetuneVisionLayers: (finetuneVisionLayers) =>\n          set({ finetuneVisionLayers }),\n        setFinetuneLanguageLayers: (finetuneLanguageLayers) =>\n          set({ finetuneLanguageLayers }),\n        setFinetuneAttentionModules: (finetuneAttentionModules) =>\n          set({ finetuneAttentionModules }),\n 
       setFinetuneMLPModules: (finetuneMLPModules) =>\n          set({ finetuneMLPModules }),\n        setTargetModules: (targetModules) => set({ targetModules }),\n        canProceed: () => canProceedForStep(get()),\n        reset: () => set(initialState),\n        resetToModelDefaults: () => {\n          const { selectedModel } = get();\n          if (!selectedModel) return;\n          set({ modelDefaultsAppliedFor: null });\n          loadAndApplyModelDefaults(selectedModel);\n        },\n        applyConfigPatch: (config: BackendModelConfig) => {\n          const patch = mapBackendModelConfigToTrainingPatch(config);\n          set(patch);\n        },\n      };\n    },\n    {\n      name: \"unsloth_training_config_v1\",\n      version: 8,\n      migrate: (persisted, version) => {\n        const s = persisted as Record<string, unknown>;\n        if (version < 2 && s.datasetSubset == null && s.datasetConfig != null) {\n          s.datasetSubset = s.datasetConfig;\n        }\n        delete s.datasetConfig;\n        if (version < 3 && s.modelDefaultsAppliedFor == null) {\n          s.modelDefaultsAppliedFor = null;\n        }\n        if (version < 4 && s.optimizerType == null) {\n          s.optimizerType = DEFAULT_HYPERPARAMS.optimizerType;\n        }\n        if (version < 5 && s.lrSchedulerType == null) {\n          s.lrSchedulerType = DEFAULT_HYPERPARAMS.lrSchedulerType;\n        }\n        if (version < 6 && s.datasetEvalSplit == null) {\n          s.datasetEvalSplit = null;\n        }\n        if (version < 7) {\n          s.datasetSliceStart ??= null;\n          s.datasetSliceEnd ??= null;\n        }\n        if (version < 8) {\n          s.datasetSystemPrompt ??= \"\";\n          s.datasetUserTemplate ??= \"\";\n          s.datasetAssistantTemplate ??= \"\";\n          s.datasetLabelMapping ??= {};\n          s.datasetAdvisorNotification ??= null;\n        }\n        return s as unknown as TrainingConfigStore;\n      },\n      partialize: partializePersistedState,\n    },\n  ),\n);\n"
  },
  {
    "path": "studio/frontend/src/features/training/stores/training-runtime-store.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { create } from \"zustand\";\nimport type {\n  TrainingMetricsResponse,\n  TrainingProgressPayload,\n  TrainingRuntimeState,\n  TrainingRuntimeStore,\n  TrainingSeriesPoint,\n  TrainingStatusResponse,\n} from \"../types/runtime\";\n\nconst initialState: TrainingRuntimeState = {\n  jobId: null,\n  phase: \"idle\",\n  isTrainingRunning: false,\n  evalEnabled: false,\n  message: \"Ready to train\",\n  error: null,\n  isHydrating: false,\n  hasHydrated: false,\n  isStarting: false,\n  startError: null,\n  sseConnected: false,\n  firstStepReceived: false,\n  lastEventId: null,\n  currentStep: 0,\n  totalSteps: 0,\n  currentEpoch: 0,\n  currentLoss: 0,\n  currentLearningRate: 0,\n  progressPercent: 0,\n  elapsedSeconds: null,\n  etaSeconds: null,\n  currentGradNorm: null,\n  currentNumTokens: null,\n  lossHistory: [],\n  lrHistory: [],\n  gradNormHistory: [],\n  evalLossHistory: [],\n  resetGeneration: 0,\n  stopRequested: false,\n};\n\nfunction sortSeries(points: TrainingSeriesPoint[]): TrainingSeriesPoint[] {\n  return [...points].sort((a, b) => a.step - b.step);\n}\n\nfunction toSeries(steps: number[], values: number[]): TrainingSeriesPoint[] {\n  const points: TrainingSeriesPoint[] = [];\n  for (let i = 0; i < steps.length; i += 1) {\n    const step = steps[i];\n    const value = values[i];\n    if (!Number.isFinite(step) || !Number.isFinite(value)) {\n      continue;\n    }\n    points.push({ step, value });\n  }\n  return sortSeries(points);\n}\n\nfunction toFiniteNumber(value: unknown): number | null {\n  if (typeof value !== \"number\") return null;\n  return Number.isFinite(value) ? value : null;\n}\n\nfunction upsertPoint(\n  points: TrainingSeriesPoint[],\n  step: number,\n  value: number,\n): TrainingSeriesPoint[] {\n  const next = points.slice();\n  const index = next.findIndex((point) => point.step === step);\n  if (index >= 0) {\n    next[index] = { step, value };\n    return next;\n  }\n  next.push({ step, value });\n  return sortSeries(next);\n}\n\nfunction applyMetricHistoryFromStatus(payload: TrainingStatusResponse): {\n  lossHistory: TrainingSeriesPoint[] | null;\n  lrHistory: TrainingSeriesPoint[] | null;\n  gradNormHistory: TrainingSeriesPoint[] | null;\n  evalLossHistory: TrainingSeriesPoint[] | null;\n} {\n  const history = payload.metric_history;\n  if (!history || !history.steps?.length) {\n    return {\n      lossHistory: null,\n      lrHistory: null,\n      gradNormHistory: null,\n      evalLossHistory: null,\n    };\n  }\n\n  const steps = history.steps;\n  const lossHistory = history.loss ? toSeries(steps, history.loss) : null;\n  const lrHistory = history.lr ? toSeries(steps, history.lr) : null;\n  const gradNormHistory =\n    history.grad_norm && history.grad_norm_steps\n      ? toSeries(history.grad_norm_steps, history.grad_norm)\n      : null;\n  const evalLossHistory =\n    history.eval_loss && history.eval_steps\n      ? 
toSeries(history.eval_steps, history.eval_loss)\n      : null;\n\n  return { lossHistory, lrHistory, gradNormHistory, evalLossHistory };\n}\n\nexport const useTrainingRuntimeStore = create<TrainingRuntimeStore>()((set) => ({\n  ...initialState,\n\n  setStopRequested: (value) => set({ stopRequested: value }),\n  setHydrating: (value) => set({ isHydrating: value }),\n  setHasHydrated: (value) => set({ hasHydrated: value }),\n  setStarting: (value) => set({ isStarting: value }),\n  setStartError: (value) => set({ startError: value }),\n  setSseConnected: (value) => set({ sseConnected: value }),\n  setLastEventId: (value) => set({ lastEventId: value }),\n\n  resetRuntime: () =>\n    set((state) => ({\n      ...initialState,\n      lossHistory: [],\n      lrHistory: [],\n      gradNormHistory: [],\n      evalLossHistory: [],\n      resetGeneration: state.resetGeneration + 1,\n    })),\n\n  setStartQueued: (jobId, message) =>\n    set((state) => ({\n      ...state,\n      jobId,\n      message,\n      error: null,\n      startError: null,\n      phase: \"configuring\",\n      isStarting: false,\n      sseConnected: false,\n      firstStepReceived: false,\n      lastEventId: null,\n      currentStep: 0,\n      totalSteps: 0,\n      currentEpoch: 0,\n      currentLoss: 0,\n      currentLearningRate: 0,\n      progressPercent: 0,\n      elapsedSeconds: null,\n      etaSeconds: null,\n      currentGradNorm: null,\n      currentNumTokens: null,\n      lossHistory: [],\n      lrHistory: [],\n      gradNormHistory: [],\n      evalLossHistory: [],\n      resetGeneration: state.resetGeneration + 1,\n    })),\n\n  setRuntimeError: (message) =>\n    set({\n      error: message,\n      phase: \"error\",\n      isStarting: false,\n      startError: null,\n      sseConnected: false,\n    }),\n\n  applyStatus: (payload) =>\n    set((state) => {\n      const metricHistory = applyMetricHistoryFromStatus(payload);\n      const detailStep = payload.details?.step;\n      const detailTotal = payload.details?.total_steps;\n      const detailLoss = payload.details?.loss;\n      const detailLr = payload.details?.learning_rate;\n      const detailEpoch = payload.details?.epoch;\n      const stopRequested =\n        payload.is_training_running ? state.stopRequested : false;\n\n      return {\n        ...state,\n        jobId: payload.job_id || state.jobId,\n        phase: payload.phase,\n        isTrainingRunning: payload.is_training_running,\n        stopRequested,\n        evalEnabled: payload.eval_enabled ?? state.evalEnabled,\n        message: payload.message,\n        error: payload.error,\n        currentStep:\n          typeof detailStep === \"number\" ? Math.max(detailStep, 0) : state.currentStep,\n        totalSteps:\n          typeof detailTotal === \"number\"\n            ? Math.max(detailTotal, 0)\n            : state.totalSteps,\n        currentLoss:\n          typeof detailLoss === \"number\" ? detailLoss : state.currentLoss,\n        currentLearningRate:\n          typeof detailLr === \"number\" ? detailLr : state.currentLearningRate,\n        currentEpoch:\n          typeof detailEpoch === \"number\" ? detailEpoch : state.currentEpoch,\n        lossHistory: metricHistory.lossHistory ?? state.lossHistory,\n        lrHistory: metricHistory.lrHistory ?? state.lrHistory,\n        gradNormHistory: metricHistory.gradNormHistory ?? state.gradNormHistory,\n        evalLossHistory: metricHistory.evalLossHistory ?? 
state.evalLossHistory,\n      };\n    }),\n\n  applyMetrics: (payload: TrainingMetricsResponse) =>\n    set((state) => {\n      const lossHistory = toSeries(payload.step_history, payload.loss_history);\n      const lrHistory = toSeries(payload.step_history, payload.lr_history);\n      const gradNormHistory = toSeries(\n        payload.grad_norm_step_history,\n        payload.grad_norm_history,\n      );\n      const latestStep =\n        payload.current_step ??\n        (payload.step_history.length > 0\n          ? payload.step_history[payload.step_history.length - 1]\n          : null);\n\n      return {\n        ...state,\n        lossHistory: lossHistory.length > 0 ? lossHistory : state.lossHistory,\n        lrHistory: lrHistory.length > 0 ? lrHistory : state.lrHistory,\n        gradNormHistory:\n          gradNormHistory.length > 0 ? gradNormHistory : state.gradNormHistory,\n        currentStep:\n          typeof latestStep === \"number\"\n            ? Math.max(latestStep, state.currentStep)\n            : state.currentStep,\n        currentLoss:\n          typeof payload.current_loss === \"number\"\n            ? payload.current_loss\n            : state.currentLoss,\n        currentLearningRate:\n          typeof payload.current_lr === \"number\"\n            ? payload.current_lr\n            : state.currentLearningRate,\n      };\n    }),\n\n  applyProgress: (payload: TrainingProgressPayload, eventId?: number) =>\n    set((state) => {\n      const step = Math.max(payload.step, 0);\n      const currentLoss = toFiniteNumber(payload.loss);\n      const currentLearningRate = toFiniteNumber(payload.learning_rate);\n      const currentGradNorm = toFiniteNumber(payload.grad_norm);\n      const evalLoss = toFiniteNumber(payload.eval_loss);\n\n      return {\n        ...state,\n        jobId: payload.job_id || state.jobId,\n        currentStep: step,\n        totalSteps: Math.max(payload.total_steps, state.totalSteps),\n        currentLoss: currentLoss ?? state.currentLoss,\n        currentLearningRate: currentLearningRate ?? state.currentLearningRate,\n        progressPercent: payload.progress_percent,\n        currentEpoch: payload.epoch ?? state.currentEpoch,\n        elapsedSeconds: payload.elapsed_seconds,\n        etaSeconds: payload.eta_seconds,\n        currentGradNorm,\n        currentNumTokens: payload.num_tokens,\n        firstStepReceived: state.firstStepReceived || step > 0,\n        lastEventId: typeof eventId === \"number\" ? eventId : state.lastEventId,\n        lossHistory:\n          step > 0 && currentLoss !== null\n            ? upsertPoint(state.lossHistory, step, currentLoss)\n            : state.lossHistory,\n        lrHistory:\n          step > 0 && currentLearningRate !== null\n            ? upsertPoint(state.lrHistory, step, currentLearningRate)\n            : state.lrHistory,\n        gradNormHistory:\n          step > 0 && currentGradNorm !== null\n            ? upsertPoint(state.gradNormHistory, step, currentGradNorm)\n            : state.gradNormHistory,\n        evalLossHistory:\n          step > 0 && evalLoss !== null\n            ? upsertPoint(state.evalLossHistory, step, evalLoss)\n            : state.evalLossHistory,\n      };\n    }),\n}));\n\nexport function shouldShowTrainingView(state: TrainingRuntimeStore): boolean {\n  return (\n    state.phase !== \"idle\" ||\n    state.isTrainingRunning ||\n    state.isStarting ||\n    state.lossHistory.length > 0 ||\n    state.currentStep > 0\n  );\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/types/api.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport interface TrainingStartRequest {\n  model_name: string;\n  training_type: string;\n  hf_token: string | null;\n  load_in_4bit: boolean;\n  max_seq_length: number;\n  /** Allow loading models with custom code. Only enable for repos you trust. */\n  trust_remote_code?: boolean;\n  hf_dataset: string | null;\n  subset: string | null;\n  train_split: string | null;\n  eval_split: string | null;\n  dataset_slice_start: number | null;\n  dataset_slice_end: number | null;\n  local_datasets: string[];\n  local_eval_datasets: string[];\n  format_type: string;\n  custom_format_mapping?: Record<string, unknown> | null;\n  num_epochs: number;\n  learning_rate: string;\n  batch_size: number;\n  gradient_accumulation_steps: number;\n  warmup_steps: number | null;\n  warmup_ratio: number | null;\n  max_steps: number | null;\n  save_steps: number;\n  eval_steps: number;\n  weight_decay: number;\n  random_seed: number;\n  packing: boolean;\n  optim: string;\n  lr_scheduler_type: string;\n  use_lora: boolean;\n  lora_r: number;\n  lora_alpha: number;\n  lora_dropout: number;\n  target_modules: string[];\n  gradient_checkpointing: string;\n  use_rslora: boolean;\n  use_loftq: boolean;\n  train_on_completions: boolean;\n  finetune_vision_layers: boolean;\n  finetune_language_layers: boolean;\n  finetune_attention_modules: boolean;\n  finetune_mlp_modules: boolean;\n  is_dataset_image: boolean;\n  is_dataset_audio: boolean;\n  is_embedding: boolean;\n  enable_wandb: boolean;\n  wandb_token: string | null;\n  wandb_project: string | null;\n  enable_tensorboard: boolean;\n  tensorboard_dir: string | null;\n}\n\nexport interface TrainingStartResponse {\n  job_id: string;\n  status: \"queued\" | \"error\";\n  message: string;\n  error: string | null;\n}\n\nexport interface TrainingStopResponse {\n  status: \"stopped\" | \"idle\";\n  message: string;\n}\n"
  },
  {
    "path": "studio/frontend/src/features/training/types/config.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type {\n  DatasetFormat,\n  DatasetSource,\n  GradientCheckpointing,\n  ModelType,\n  StepNumber,\n  TrainingMethod,\n} from \"@/types/training\";\nimport type { BackendModelConfig } from \"../api/models-api\";\n\nexport type LoraVariant = \"lora\" | \"rslora\" | \"loftq\";\n\n/** Column-to-role mapping, e.g. { \"problem\": \"user\", \"solution\": \"assistant\", \"context\": \"system\" } */\nexport type DatasetManualMapping = Record<string, string>;\n\nexport interface TrainingConfigState {\n  currentStep: StepNumber;\n  modelType: ModelType | null;\n  selectedModel: string | null;\n  trainingMethod: TrainingMethod;\n  hfToken: string;\n  datasetSource: DatasetSource;\n  datasetFormat: DatasetFormat;\n  dataset: string | null;\n  datasetSubset: string | null;\n  datasetSplit: string | null;\n  datasetEvalSplit: string | null;\n  datasetManualMapping: DatasetManualMapping;\n  datasetSystemPrompt: string;\n  datasetUserTemplate: string;\n  datasetAssistantTemplate: string;\n  datasetLabelMapping: Record<string, Record<string, string>>;\n  datasetAdvisorNotification: string | null;\n  datasetSliceStart: string | null;\n  datasetSliceEnd: string | null;\n  uploadedFile: string | null;\n  uploadedEvalFile: string | null;\n  epochs: number;\n  contextLength: number;\n  learningRate: number;\n  optimizerType: string;\n  lrSchedulerType: string;\n  loraRank: number;\n  loraAlpha: number;\n  loraDropout: number;\n  loraVariant: LoraVariant;\n  batchSize: number;\n  gradientAccumulation: number;\n  weightDecay: number;\n  warmupSteps: number;\n  maxSteps: number;\n  saveSteps: number;\n  evalSteps: number;\n  packing: boolean;\n  trainOnCompletions: boolean;\n  gradientCheckpointing: GradientCheckpointing;\n  randomSeed: number;\n  enableWandb: boolean;\n  wandbToken: string;\n  wandbProject: string;\n  enableTensorboard: boolean;\n  tensorboardDir: string;\n  logFrequency: number;\n  isCheckingVision: boolean;\n  isVisionModel: boolean;\n  isEmbeddingModel: boolean;\n  isAudioModel: boolean;\n  isLoadingModelDefaults: boolean;\n  modelDefaultsError: string | null;\n  modelDefaultsAppliedFor: string | null;\n  isCheckingDataset: boolean;\n  isDatasetImage: boolean | null;\n  isDatasetAudio: boolean;\n  trustRemoteCode: boolean;\n  finetuneVisionLayers: boolean;\n  finetuneLanguageLayers: boolean;\n  finetuneAttentionModules: boolean;\n  finetuneMLPModules: boolean;\n  targetModules: string[];\n  maxPositionEmbeddings: number | null;\n}\n\nexport interface TrainingConfigActions {\n  setStep: (step: StepNumber) => void;\n  nextStep: () => void;\n  prevStep: () => void;\n  setModelType: (type: ModelType) => void;\n  setSelectedModel: (model: string | null) => void;\n  ensureModelDefaultsLoaded: () => void;\n  ensureDatasetChecked: () => void;\n  setTrainingMethod: (method: TrainingMethod) => void;\n  setHfToken: (token: string) => void;\n  setDatasetSource: (source: DatasetSource) => void;\n  selectHfDataset: (dataset: string | null) => void;\n  selectLocalDataset: (file: string | null) => void;\n  setDatasetFormat: (format: DatasetFormat) => void;\n  setDataset: (dataset: string | null) => void;\n  setDatasetSubset: (subset: string | null) => void;\n  setDatasetSplit: (split: string | null) => void;\n  setDatasetEvalSplit: (split: string | null) => void;\n  setDatasetManualMapping: (mapping: DatasetManualMapping) => void;\n  
setDatasetAdvisorFields: (fields: {\n    systemPrompt?: string;\n    labelMapping?: Record<string, Record<string, string>>;\n    notification?: string | null;\n  }) => void;\n  clearDatasetAdvisorFields: () => void;\n  setDatasetSliceStart: (value: string | null) => void;\n  setDatasetSliceEnd: (value: string | null) => void;\n  setUploadedFile: (file: string | null) => void;\n  setUploadedEvalFile: (file: string | null) => void;\n  setEpochs: (epochs: number) => void;\n  setContextLength: (length: number) => void;\n  setLearningRate: (rate: number) => void;\n  setOptimizerType: (value: string) => void;\n  setLrSchedulerType: (value: string) => void;\n  setLoraRank: (rank: number) => void;\n  setLoraAlpha: (alpha: number) => void;\n  setLoraDropout: (dropout: number) => void;\n  setLoraVariant: (variant: LoraVariant) => void;\n  setBatchSize: (value: number) => void;\n  setGradientAccumulation: (value: number) => void;\n  setWeightDecay: (value: number) => void;\n  setWarmupSteps: (value: number) => void;\n  setMaxSteps: (value: number) => void;\n  setSaveSteps: (value: number) => void;\n  setEvalSteps: (value: number) => void;\n  setPacking: (value: boolean) => void;\n  setTrainOnCompletions: (value: boolean) => void;\n  setGradientCheckpointing: (value: GradientCheckpointing) => void;\n  setRandomSeed: (value: number) => void;\n  setEnableWandb: (value: boolean) => void;\n  setWandbToken: (value: string) => void;\n  setWandbProject: (value: string) => void;\n  setEnableTensorboard: (value: boolean) => void;\n  setTensorboardDir: (value: string) => void;\n  setLogFrequency: (value: number) => void;\n  setFinetuneVisionLayers: (value: boolean) => void;\n  setFinetuneLanguageLayers: (value: boolean) => void;\n  setFinetuneAttentionModules: (value: boolean) => void;\n  setFinetuneMLPModules: (value: boolean) => void;\n  setTargetModules: (value: string[]) => void;\n  canProceed: () => boolean;\n  reset: () => void;\n  resetToModelDefaults: () => void;\n  applyConfigPatch: (config: BackendModelConfig) => void;\n}\n\nexport type TrainingConfigStore = TrainingConfigState & TrainingConfigActions;\n"
  },
  {
    "path": "studio/frontend/src/features/training/types/datasets.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type CheckFormatResponse = {\n  requires_manual_mapping: boolean;\n  detected_format: string;\n  columns: string[];\n  suggested_mapping?: Record<string, string> | null;\n  detected_image_column?: string | null;\n  detected_audio_column?: string | null;\n  detected_text_column?: string | null;\n  detected_speaker_column?: string | null;\n  preview_samples?: Record<string, unknown>[] | null;\n  total_rows?: number | null;\n  is_image?: boolean;\n  is_audio?: boolean;\n  multimodal_columns?: string[] | null;\n  warning?: string | null;\n};\n\nexport type UploadDatasetResponse = {\n  filename: string;\n  stored_path: string;\n};\n\nexport type LocalDatasetInfo = {\n  metadata?: {\n    actual_num_records?: number | null;\n    target_num_records?: number | null;\n    total_num_batches?: number | null;\n    num_completed_batches?: number | null;\n    columns?: string[] | null;\n  } | null;\n  id: string;\n  label: string;\n  path: string;\n  rows?: number | null;\n  updated_at?: number | null;\n};\n\nexport type LocalDatasetsResponse = {\n  datasets: LocalDatasetInfo[];\n};\n"
  },
  {
    "path": "studio/frontend/src/features/training/types/runtime.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type TrainingPhase =\n  | \"idle\"\n  | \"downloading_model\"\n  | \"downloading_dataset\"\n  | \"loading_model\"\n  | \"loading_dataset\"\n  | \"configuring\"\n  | \"training\"\n  | \"completed\"\n  | \"error\"\n  | \"stopped\";\n\nexport interface TrainingStatusResponse {\n  job_id: string;\n  phase: TrainingPhase;\n  is_training_running: boolean;\n  eval_enabled: boolean;\n  message: string;\n  error: string | null;\n  details?: {\n    epoch?: number;\n    step?: number;\n    total_steps?: number;\n    loss?: number;\n    learning_rate?: number;\n  } | null;\n  metric_history?: {\n    steps?: number[];\n    loss?: number[];\n    lr?: number[];\n    grad_norm?: number[];\n    grad_norm_steps?: number[];\n    eval_loss?: number[];\n    eval_steps?: number[];\n  } | null;\n}\n\nexport interface TrainingMetricsResponse {\n  loss_history: number[];\n  lr_history: number[];\n  step_history: number[];\n  grad_norm_history: number[];\n  grad_norm_step_history: number[];\n  current_loss: number | null;\n  current_lr: number | null;\n  current_step: number | null;\n}\n\nexport interface TrainingProgressPayload {\n  job_id: string;\n  step: number;\n  total_steps: number;\n  loss: number;\n  learning_rate: number;\n  progress_percent: number;\n  epoch: number | null;\n  elapsed_seconds: number | null;\n  eta_seconds: number | null;\n  grad_norm: number | null;\n  num_tokens: number | null;\n  eval_loss: number | null;\n}\n\nexport interface TrainingSeriesPoint {\n  step: number;\n  value: number;\n}\n\nexport interface TrainingRuntimeState {\n  jobId: string | null;\n  phase: TrainingPhase;\n  isTrainingRunning: boolean;\n  evalEnabled: boolean;\n  message: string;\n  error: string | null;\n  isHydrating: boolean;\n  hasHydrated: boolean;\n  isStarting: boolean;\n  startError: string | null;\n  sseConnected: boolean;\n  firstStepReceived: boolean;\n  lastEventId: number | null;\n  currentStep: number;\n  totalSteps: number;\n  currentEpoch: number;\n  currentLoss: number;\n  currentLearningRate: number;\n  progressPercent: number;\n  elapsedSeconds: number | null;\n  etaSeconds: number | null;\n  currentGradNorm: number | null;\n  currentNumTokens: number | null;\n  lossHistory: TrainingSeriesPoint[];\n  lrHistory: TrainingSeriesPoint[];\n  gradNormHistory: TrainingSeriesPoint[];\n  evalLossHistory: TrainingSeriesPoint[];\n  resetGeneration: number;\n  stopRequested: boolean;\n}\n\nexport interface TrainingRuntimeActions {\n  setStopRequested: (value: boolean) => void;\n  setHydrating: (value: boolean) => void;\n  setHasHydrated: (value: boolean) => void;\n  setStarting: (value: boolean) => void;\n  setStartError: (value: string | null) => void;\n  setSseConnected: (value: boolean) => void;\n  setLastEventId: (value: number | null) => void;\n  resetRuntime: () => void;\n  applyStatus: (payload: TrainingStatusResponse) => void;\n  applyMetrics: (payload: TrainingMetricsResponse) => void;\n  applyProgress: (payload: TrainingProgressPayload, eventId?: number) => void;\n  setStartQueued: (jobId: string, message: string) => void;\n  setRuntimeError: (message: string) => void;\n}\n\nexport type TrainingRuntimeStore = TrainingRuntimeState & TrainingRuntimeActions;\n"
  },
  {
    "path": "studio/frontend/src/hooks/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport { useDebouncedValue } from \"./use-debounced-value\";\nexport { useGpuInfo } from \"./use-gpu-info\";\nexport { useGpuUtilization } from \"./use-gpu-utilization\";\nexport { useHardwareInfo } from \"./use-hardware-info\";\nexport { useHfModelSearch } from \"./use-hf-model-search\";\nexport { useRecommendedModelVram } from \"./use-recommended-model-vram\";\nexport { useHfDatasetSearch } from \"./use-hf-dataset-search\";\nexport { useHfDatasetSplits } from \"./use-hf-dataset-splits\";\nexport { useHfTokenValidation } from \"./use-hf-token-validation\";\nexport { useInfiniteScroll } from \"./use-infinite-scroll\";\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-debounced-value.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useState } from \"react\";\n\nexport function useDebouncedValue<T>(value: T, delayMs = 300): T {\n  const [debounced, setDebounced] = useState(value);\n  useEffect(() => {\n    const id = setTimeout(() => setDebounced(value), delayMs);\n    return () => clearTimeout(id);\n  }, [value, delayMs]);\n  return debounced;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-gpu-info.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useState } from \"react\";\n\nexport interface GpuInfo {\n  available: boolean;\n  name: string;\n  memoryTotalGb: number;\n  systemRamAvailableGb: number;\n}\n\nconst DEFAULT_GPU: GpuInfo = {\n  available: false,\n  name: \"Unknown\",\n  memoryTotalGb: 0,\n  systemRamAvailableGb: 0,\n};\n\n// Module-level cache so multiple components share one fetch.\nlet cachedGpu: GpuInfo | null = null;\nlet fetchPromise: Promise<GpuInfo> | null = null;\n\nasync function fetchGpuOnce(): Promise<GpuInfo> {\n  if (cachedGpu) return cachedGpu;\n  if (fetchPromise) return fetchPromise;\n\n  fetchPromise = (async () => {\n    try {\n      const res = await fetch(\"/api/system\");\n      if (!res.ok) throw new Error(`HTTP ${res.status}`);\n      const data = await res.json();\n      const gpuData = data?.gpu;\n      if (!gpuData?.available || !gpuData.devices?.length) return DEFAULT_GPU;\n      const devices = gpuData.devices as Array<{ name?: string; memory_total_gb?: number }>;\n      const totalGb = devices.reduce((sum, d) => sum + (d.memory_total_gb ?? 0), 0);\n      const info: GpuInfo = {\n        available: true,\n        name: devices[0]?.name ?? \"Unknown\",\n        memoryTotalGb: totalGb,\n        systemRamAvailableGb: data?.memory?.available_gb ?? 0,\n      };\n      cachedGpu = info;\n      return info;\n    } catch {\n      // Reset promise so subsequent calls retry (e.g. backend wasn't ready)\n      fetchPromise = null;\n      return DEFAULT_GPU;\n    }\n  })();\n\n  return fetchPromise;\n}\n\n/**\n * Fetch GPU info from the backend /api/system endpoint.\n *\n * The result is cached at module level -- only one network request is made\n * regardless of how many components call this hook.\n */\nexport function useGpuInfo(): GpuInfo {\n  const [gpu, setGpu] = useState<GpuInfo>(cachedGpu ?? DEFAULT_GPU);\n\n  useEffect(() => {\n    if (cachedGpu) return;\n\n    let cancelled = false;\n    fetchGpuOnce().then((info) => {\n      if (!cancelled) setGpu(info);\n    });\n    return () => { cancelled = true; };\n  }, []);\n\n  return gpu;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-gpu-utilization.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\nimport { useEffect, useRef, useState } from \"react\";\n\nexport interface GpuUtilization {\n    available: boolean;\n    backend: string | null;\n    gpu_utilization_pct: number | null;\n    temperature_c: number | null;\n    vram_used_gb: number | null;\n    vram_total_gb: number | null;\n    vram_utilization_pct: number | null;\n    power_draw_w: number | null;\n    power_limit_w: number | null;\n    power_utilization_pct: number | null;\n}\n\nconst DEFAULT: GpuUtilization = {\n    available: false,\n    backend: null,\n    gpu_utilization_pct: null,\n    temperature_c: null,\n    vram_used_gb: null,\n    vram_total_gb: null,\n    vram_utilization_pct: null,\n    power_draw_w: null,\n    power_limit_w: null,\n    power_utilization_pct: null,\n};\n\n/**\n * Poll `GET /api/train/hardware` for live GPU utilization stats.\n *\n * Only polls while `enabled` is true (i.e. training is running).\n * Polling interval defaults to 10 000 ms.\n */\nexport function useGpuUtilization(\n    enabled: boolean,\n    intervalMs = 10_000,\n): GpuUtilization {\n    const [data, setData] = useState<GpuUtilization>(DEFAULT);\n    const timerRef = useRef<ReturnType<typeof setInterval> | null>(null);\n\n    useEffect(() => {\n        if (!enabled) {\n            // Reset when training stops so the cards show \"--\" again\n            setData(DEFAULT);\n            return;\n        }\n\n        let cancelled = false;\n\n        async function poll() {\n            try {\n                const res = await authFetch(\"/api/train/hardware\");\n                if (!res.ok || cancelled) return;\n                const json = (await res.json()) as GpuUtilization;\n                if (!cancelled) setData(json);\n            } catch {\n                // Silently ignore — next poll will retry\n            }\n        }\n\n        // Fetch immediately, then set up interval\n        void poll();\n        timerRef.current = setInterval(() => void poll(), intervalMs);\n\n        return () => {\n            cancelled = true;\n            if (timerRef.current) clearInterval(timerRef.current);\n        };\n    }, [enabled, intervalMs]);\n\n    return data;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hardware-info.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { authFetch } from \"@/features/auth\";\nimport { useEffect, useState } from \"react\";\n\nexport interface HardwareInfo {\n    gpuName: string | null;\n    vramTotalGb: number | null;\n    vramFreeGb: number | null;\n    torch: string | null;\n    cuda: string | null;\n    transformers: string | null;\n    unsloth: string | null;\n}\n\nconst DEFAULT: HardwareInfo = {\n    gpuName: null,\n    vramTotalGb: null,\n    vramFreeGb: null,\n    torch: null,\n    cuda: null,\n    transformers: null,\n    unsloth: null,\n};\n\n// Module-level cache so multiple components share one fetch.\nlet cached: HardwareInfo | null = null;\nlet fetchPromise: Promise<HardwareInfo> | null = null;\n\nasync function fetchOnce(): Promise<HardwareInfo> {\n    if (cached) return cached;\n    if (fetchPromise) return fetchPromise;\n\n    fetchPromise = (async () => {\n        try {\n            const res = await authFetch(\"/api/system/hardware\");\n            if (!res.ok) throw new Error(`HTTP ${res.status}`);\n            const data = await res.json();\n            const info: HardwareInfo = {\n                gpuName: data?.gpu?.gpu_name ?? null,\n                vramTotalGb: data?.gpu?.vram_total_gb ?? null,\n                vramFreeGb: data?.gpu?.vram_free_gb ?? null,\n                torch: data?.versions?.torch ?? null,\n                cuda: data?.versions?.cuda ?? null,\n                transformers: data?.versions?.transformers ?? null,\n                unsloth: data?.versions?.unsloth ?? null,\n            };\n            cached = info;\n            return info;\n        } catch {\n            // Reset promise so subsequent calls retry (e.g. backend wasn't ready)\n            fetchPromise = null;\n            return DEFAULT;\n        }\n    })();\n\n    return fetchPromise;\n}\n\n/**\n * Fetch hardware info from `GET /api/system/hardware`.\n *\n * The result is cached at module level — only one network request is made\n * regardless of how many components call this hook.\n */\nexport function useHardwareInfo(): HardwareInfo {\n    const [info, setInfo] = useState<HardwareInfo>(cached ?? DEFAULT);\n\n    useEffect(() => {\n        if (cached) return;\n\n        let cancelled = false;\n        fetchOnce().then((hw) => {\n            if (!cancelled) setInfo(hw);\n        });\n        return () => { cancelled = true; };\n    }, []);\n\n    return info;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hf-dataset-search.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { listDatasets } from \"@huggingface/hub\";\nimport { useCallback, useMemo } from \"react\";\nimport type { ModelType } from \"@/types/training\";\nimport { useHfPaginatedSearch } from \"./use-hf-paginated-search\";\n\ninterface DatasetInfoSplit {\n  name: string;\n  // biome-ignore lint/style/useNamingConvention: external schema\n  num_bytes: number;\n  // biome-ignore lint/style/useNamingConvention: external schema\n  num_examples: number;\n}\n\ninterface CardDataWithInfo {\n  size_categories?: string[];\n  pretty_name?: string;\n  dataset_info?:\n    | {\n        splits?: DatasetInfoSplit[];\n        download_size?: number;\n        dataset_size?: number;\n      }\n    | Array<{ splits?: DatasetInfoSplit[] }>;\n}\n\nfunction extractTotalExamples(\n  cardData: CardDataWithInfo | undefined,\n): number | undefined {\n  if (!cardData?.dataset_info) {\n    return undefined;\n  }\n\n  const infos = Array.isArray(cardData.dataset_info)\n    ? cardData.dataset_info\n    : [cardData.dataset_info];\n\n  const examples = infos\n    .flatMap((info) => info.splits ?? [])\n    .filter((s) => typeof s.num_examples === \"number\")\n    .map((s) => s.num_examples);\n\n  return examples.length > 0 ? examples.reduce((a, b) => a + b, 0) : undefined;\n}\n\nexport interface HfDatasetResult {\n  id: string;\n  downloads: number;\n  likes: number;\n  totalExamples?: number;\n  sizeCategory?: string;\n  taskCategories: string[];\n  plainTags: string[];\n}\n\nfunction mapDataset(raw: unknown): HfDatasetResult {\n  const ds = raw as {\n    name: string;\n    downloads: number;\n    likes: number;\n    tags?: string[];\n    cardData?: unknown;\n  };\n  const card = ds.cardData as CardDataWithInfo | undefined;\n  const tags = ds.tags ?? [];\n  const taskCategories = tags\n    .filter((t) => t.startsWith(\"task_categories:\"))\n    .map((t) => t.slice(\"task_categories:\".length));\n  const plainTags = tags.filter((t) => !t.includes(\":\"));\n  return {\n    id: ds.name,\n    downloads: ds.downloads,\n    likes: ds.likes,\n    totalExamples: extractTotalExamples(card),\n    sizeCategory: card?.size_categories?.[0],\n    taskCategories,\n    plainTags,\n  };\n}\n\nfunction withTrendingSort(\n  input: Parameters<typeof fetch>[0],\n  init?: Parameters<typeof fetch>[1],\n): ReturnType<typeof fetch> {\n  const rawUrl =\n    typeof input === \"string\"\n      ? input\n      : input instanceof URL\n        ? 
input.toString()\n        : input.url;\n  const url = new URL(rawUrl);\n\n  if (!url.searchParams.has(\"sort\")) {\n    url.searchParams.set(\"sort\", \"trendingScore\");\n  }\n  if (!url.searchParams.has(\"direction\")) {\n    url.searchParams.set(\"direction\", \"-1\");\n  }\n\n  return fetch(url, init);\n}\n\ntype DatasetRelevance = \"incompatible\" | \"neutral\" | \"boosted\";\n\nconst BOOSTED_TASK_CATEGORIES: Record<ModelType, Set<string>> = {\n  text: new Set([\n    \"text-generation\",\n    \"text2text-generation\",\n    \"question-answering\",\n    \"summarization\",\n    \"conversational\",\n  ]),\n  vision: new Set([\n    \"image-text-to-text\",\n    \"visual-question-answering\",\n    \"image-to-text\",\n    \"image-captioning\",\n  ]),\n  audio: new Set([\n    \"text-to-speech\",\n    \"text-to-audio\",\n    \"automatic-speech-recognition\",\n  ]),\n  embeddings: new Set([\n    \"feature-extraction\",\n    \"sentence-similarity\",\n    \"text-retrieval\",\n  ]),\n};\n\nconst INCOMPATIBLE_TASKS_ALL_MODELS = new Set([\n  \"text-to-3d\",\n  \"image-to-3d\",\n  \"robotics\",\n  \"reinforcement-learning\",\n  \"tabular-classification\",\n  \"tabular-regression\",\n  \"time-series-forecasting\",\n]);\n\nconst PRETRAINING_PLAIN_TAGS = new Set([\"pretraining\", \"pre-training\"]);\nconst OCR_PLAIN_TAGS = new Set([\"ocr\", \"document-ocr\"]);\n\nconst PRETRAINING_SIZE_CATEGORIES = new Set([\n  \"5M<n<10M\",\n  \"10M<n<100M\",\n  \"100M<n<1B\",\n  \"1B<n<10B\",\n  \"10B<n<100B\",\n  \"100B<n<1T\",\n  \"n>1T\",\n]);\n\nconst OCR_OR_VISION_TEXT_TASKS = new Set([\n  \"image-to-text\",\n  \"image-captioning\",\n  \"visual-question-answering\",\n  \"document-question-answering\",\n]);\n\nconst CURATED_EMPTY_QUERY_DATASET_IDS: Partial<Record<ModelType, string[]>> = {\n  text: [\n    \"unsloth/alpaca-cleaned\",\n    \"unsloth/OpenMathReasoning-mini\",\n    \"mlabonne/FineTome-100k\",\n    \"openai/gsm8k\",\n    \"philschmid/guanaco-sharegpt-style\",\n    \"open-r1/DAPO-Math-17k-Processed\",\n    \"HuggingFaceH4/Multilingual-Thinking\",\n    \"HuggingFaceH4/ultrafeedback_binarized\",\n    \"reciperesearch/dolphin-sft-v0.1-preference\",\n    \"roneneldan/TinyStories\",\n    \"FreedomIntelligence/alpaca-gpt4-korean\",\n    \"Goedel-LM/SFT_dataset_v2\",\n    \"allenai/tulu-3-sft-mixture\",\n    \"HuggingFaceH4/no_robots\",\n    \"Magpie-Align/Magpie-Air-300K-Filtered\",\n    \"teknium/OpenHermes-2.5\",\n    \"databricks/databricks-dolly-15k\",\n    \"tatsu-lab/alpaca\",\n    \"garage-bAInd/Open-Platypus\",\n    \"microsoft/orca-math-word-problems-200k\",\n    \"Open-Orca/OpenOrca\",\n    \"openbmb/UltraInteract_sft\",\n  ],\n  vision: [\n    \"unsloth/LaTeX_OCR\",\n    \"unsloth/llava-instruct-mix-vsft-mini\",\n    \"unsloth/Radiology_mini\",\n    \"AI4Math/MathVista\",\n    \"AI4Math/MathVerse\",\n    \"ChongyanChen/VQAonline\",\n    \"lmms-lab/VQAv2\",\n    \"hezarai/parsynth-ocr-200k\",\n  ],\n  audio: [\n    \"MrDragonFox/Elise\",\n    \"keithito/lj_speech\",\n    \"parler-tts/mls_eng_10k\",\n    \"parler-tts/libritts-r-filtered-speaker-descriptions\",\n    \"openslr/librispeech_asr\",\n    \"MikhailT/hifi-tts\",\n    \"mozilla-foundation/common_voice_17_0\",\n    \"facebook/voxpopuli\",\n    \"speechcolab/gigaspeech\",\n    \"kth-tmh/vctk\",\n    \"Wenetspeech4TTS/WenetSpeech4TTS\",\n  ],\n  embeddings: [\n    \"electroglyph/technical\",\n  ],\n};\n\nconst INCOMPATIBLE_TASKS_BY_MODEL: Record<ModelType, Set<string>> = {\n  text: new Set([\n    \"text-to-image\",\n    \"image-to-image\",\n    
\"image-to-video\",\n    \"text-to-video\",\n    \"image-classification\",\n    \"image-feature-extraction\",\n    \"image-text-to-image\",\n    \"zero-shot-image-classification\",\n    \"keypoint-detection\",\n    \"object-detection\",\n    \"image-segmentation\",\n    \"depth-estimation\",\n    \"text-to-speech\",\n    \"text-to-audio\",\n    \"audio-classification\",\n    \"audio-to-audio\",\n    \"automatic-speech-recognition\",\n    \"video-classification\",\n    \"visual-document-retrieval\",\n  ]),\n  vision: new Set([\n    \"text-to-speech\",\n    \"text-to-audio\",\n    \"audio-classification\",\n    \"audio-to-audio\",\n    \"automatic-speech-recognition\",\n  ]),\n  audio: new Set([\n    \"text-to-image\",\n    \"image-to-image\",\n    \"image-to-video\",\n    \"text-to-video\",\n    \"image-classification\",\n    \"image-feature-extraction\",\n    \"image-text-to-image\",\n    \"zero-shot-image-classification\",\n    \"keypoint-detection\",\n    \"object-detection\",\n    \"image-segmentation\",\n    \"depth-estimation\",\n    \"video-classification\",\n    \"visual-document-retrieval\",\n  ]),\n  embeddings: new Set([\n    \"text-to-image\",\n    \"image-to-image\",\n    \"image-to-video\",\n    \"text-to-video\",\n    \"image-classification\",\n    \"image-feature-extraction\",\n    \"image-text-to-image\",\n    \"zero-shot-image-classification\",\n    \"keypoint-detection\",\n    \"object-detection\",\n    \"image-segmentation\",\n    \"depth-estimation\",\n    \"text-to-speech\",\n    \"text-to-audio\",\n    \"audio-classification\",\n    \"audio-to-audio\",\n    \"automatic-speech-recognition\",\n    \"video-classification\",\n    \"visual-document-retrieval\",\n  ]),\n};\n\nfunction isPretrainingDataset(dataset: HfDatasetResult): boolean {\n  if (dataset.plainTags.some((t) => PRETRAINING_PLAIN_TAGS.has(t.toLowerCase())))\n    return true;\n  if (\n    dataset.sizeCategory &&\n    PRETRAINING_SIZE_CATEGORIES.has(dataset.sizeCategory)\n  )\n    return true;\n  return false;\n}\n\nfunction rankDatasetRelevance(\n  dataset: HfDatasetResult,\n  modelType: ModelType,\n): DatasetRelevance {\n  if (isPretrainingDataset(dataset)) return \"incompatible\";\n\n  // Keep OCR / vision-text corpora out of non-vision defaults.\n  if (modelType !== \"vision\") {\n    if (\n      dataset.plainTags.some((t) => OCR_PLAIN_TAGS.has(t.toLowerCase())) ||\n      dataset.taskCategories.some((t) => OCR_OR_VISION_TEXT_TASKS.has(t))\n    ) {\n      return \"incompatible\";\n    }\n  }\n\n  const { taskCategories } = dataset;\n  if (taskCategories.length === 0) return \"neutral\";\n\n  const boosted = BOOSTED_TASK_CATEGORIES[modelType];\n  const modelIncompat = INCOMPATIBLE_TASKS_BY_MODEL[modelType];\n\n  if (taskCategories.some((t) => boosted.has(t))) return \"boosted\";\n  if (\n    taskCategories.every(\n      (t) => INCOMPATIBLE_TASKS_ALL_MODELS.has(t) || modelIncompat.has(t),\n    )\n  )\n    return \"incompatible\";\n  return \"neutral\";\n}\n\nfunction isOcrOrVisionTextDataset(dataset: HfDatasetResult): boolean {\n  return (\n    dataset.plainTags.some((t) => OCR_PLAIN_TAGS.has(t.toLowerCase())) ||\n    dataset.taskCategories.some((t) => OCR_OR_VISION_TEXT_TASKS.has(t))\n  );\n}\n\nfunction toCuratedDatasetResult(id: string): HfDatasetResult {\n  // Curated defaults are id-only. 
This adapter satisfies the shared result shape\n  // used by downstream combobox/ranking code without making extra HF requests.\n  return {\n    id,\n    downloads: 0,\n    likes: 0,\n    taskCategories: [],\n    plainTags: [],\n  };\n}\n\nexport function useHfDatasetSearch(\n  query: string,\n  options?: { modelType?: ModelType | null; accessToken?: string; enabled?: boolean },\n) {\n  const { modelType, accessToken, enabled = true } = options ?? {};\n  const hasQuery = query.trim().length > 0;\n  const useCuratedOnly = !hasQuery && !!modelType;\n  const createIter = useCallback(\n    () => {\n      // Use curated defaults for typed model flows only.\n      if (useCuratedOnly) {\n        return (async function* empty() {})() as AsyncGenerator<unknown>;\n      }\n      return listDatasets({\n        search: hasQuery ? { query } : {},\n        additionalFields: [\"cardData\", \"tags\"],\n        fetch: withTrendingSort,\n        ...(accessToken ? { credentials: { accessToken } } : {}),\n      }) as AsyncGenerator<unknown>;\n    },\n    [useCuratedOnly, hasQuery, query, accessToken],\n  );\n\n  const search = useHfPaginatedSearch(createIter, mapDataset, { enabled });\n\n  const results = useMemo(() => {\n    if (!enabled) return [];\n    const hideOcr = modelType !== \"vision\";\n    const baseResults = hideOcr\n      ? search.results.filter((ds) => !isOcrOrVisionTextDataset(ds))\n      : search.results;\n\n    if (!hasQuery && modelType) {\n      const curatedIds = CURATED_EMPTY_QUERY_DATASET_IDS[modelType] ?? [];\n      return curatedIds.map(toCuratedDatasetResult);\n    }\n\n    if (!modelType) return baseResults;\n\n    const boosted: HfDatasetResult[] = [];\n    const neutral: HfDatasetResult[] = [];\n\n    for (const ds of baseResults) {\n      const relevance = rankDatasetRelevance(ds, modelType);\n      if (relevance === \"boosted\") boosted.push(ds);\n      else if (relevance !== \"incompatible\") neutral.push(ds);\n    }\n\n    return [...boosted, ...neutral];\n  }, [enabled, search.results, modelType, query]);\n\n  return { ...search, results };\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hf-dataset-splits.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useState } from \"react\";\n\n// ---------------------------------------------------------------------------\n// Types\n// ---------------------------------------------------------------------------\n\nexport interface HfSplitEntry {\n  dataset: string;\n  config: string;\n  split: string;\n}\n\nexport interface HfSplitsResponse {\n  splits: HfSplitEntry[];\n  pending: unknown[];\n  failed: unknown[];\n}\n\nexport interface HfDatasetSplitsResult {\n  /** All unique subset names found in the dataset */\n  subsets: string[];\n  /** All split names available for the currently selected subset */\n  splits: string[];\n  /** Raw split entries from the API */\n  entries: HfSplitEntry[];\n  /** Whether the dataset has more than one subset */\n  hasMultipleSubsets: boolean;\n  /** Whether the selected subset has more than one split */\n  hasMultipleSplits: boolean;\n  /** True while the request is in-flight */\n  isLoading: boolean;\n  /** Error message if the fetch failed */\n  error: string | null;\n}\n\nconst HF_SPLITS_API = \"https://datasets-server.huggingface.co/splits\";\n\nfunction normalizeDatasetSplitsError(message: string): string {\n  const normalized = message.toLowerCase();\n\n  // datasets-server returns technical script/runtime details for legacy datasets.\n  if (\n    normalized.includes(\"dataset scripts are no longer supported\") ||\n    normalized.includes(\"runs arbitrary python code\")\n  ) {\n    return \"We can’t load subset/split options for this Hub dataset because it relies on a legacy custom script.\";\n  }\n\n  if (\n    normalized.includes(\"unauthorized\") ||\n    normalized.includes(\"forbidden\") ||\n    normalized.includes(\"access token\") ||\n    normalized.includes(\"private\") ||\n    normalized.includes(\"gated\") ||\n    normalized.includes(\"401\") ||\n    normalized.includes(\"403\")\n  ) {\n    return \"Unable to load dataset splits. This dataset may be private or gated. Add a Hugging Face token with access and try again.\";\n  }\n\n  if (normalized.includes(\"not found\") || normalized.includes(\"404\")) {\n    return \"Dataset not found. Check the dataset name and try again.\";\n  }\n\n  return \"Unable to load dataset split options for this dataset.\";\n}\n\n// ---------------------------------------------------------------------------\n// Hook\n// ---------------------------------------------------------------------------\n\n/**\n * Fetches the available configs (subsets) and splits for a HuggingFace dataset\n * using the datasets-server API.\n *\n * @param datasetName - HF dataset id (e.g. 
\"ibm/duorc\"), or null to skip.\n * @param selectedSubset - Currently selected subset, used to filter splits.\n * @param options.accessToken - Optional HF access token for gated datasets.\n */\nexport function useHfDatasetSplits(\n  datasetName: string | null,\n  selectedSubset: string | null,\n  options?: { accessToken?: string },\n): HfDatasetSplitsResult {\n  const [entries, setEntries] = useState<HfSplitEntry[]>([]);\n  const [isLoading, setIsLoading] = useState(false);\n  const [error, setError] = useState<string | null>(null);\n\n  \n  const [prevDatasetName, setPrevDatasetName] = useState(datasetName);\n  if (datasetName !== prevDatasetName) {\n    setPrevDatasetName(datasetName);\n    setEntries([]);\n    setError(null);\n  }\n\n  const accessToken = options?.accessToken;\n\n  const fetchSplits = useCallback(\n    async (dataset: string, signal: AbortSignal) => {\n      const url = `${HF_SPLITS_API}?dataset=${encodeURIComponent(dataset)}`;\n      const headers: Record<string, string> = {};\n      if (accessToken) {\n        headers.Authorization = `Bearer ${accessToken}`;\n      }\n\n      const res = await fetch(url, { headers, signal });\n      if (!res.ok) {\n        const body = await res.json().catch(() => null);\n        throw new Error(\n          body?.error || `Failed to fetch splits (${res.status})`,\n        );\n      }\n\n      const data: HfSplitsResponse = await res.json();\n      return data.splits ?? [];\n    },\n    [accessToken],\n  );\n\n  useEffect(() => {\n    if (!datasetName) {\n      setEntries([]);\n      setError(null);\n      setIsLoading(false);\n      return;\n    }\n\n    const controller = new AbortController();\n    setIsLoading(true);\n    setError(null);\n\n    fetchSplits(datasetName, controller.signal)\n      .then((splits) => {\n        if (!controller.signal.aborted) {\n          setEntries(splits);\n          setError(null);\n        }\n      })\n      .catch((err) => {\n        if (!controller.signal.aborted) {\n          const rawErrorMessage =\n            err instanceof Error\n              ? err.message\n              : typeof err === \"string\"\n                ? err\n                : \"Failed to fetch dataset splits\";\n          console.warn(\"[useHfDatasetSplits] Failed to fetch dataset splits\", {\n            datasetName,\n            message: rawErrorMessage,\n            error: err,\n          });\n          setError(normalizeDatasetSplitsError(rawErrorMessage));\n          setEntries([]);\n        }\n      })\n      .finally(() => {\n        if (!controller.signal.aborted) {\n          setIsLoading(false);\n        }\n      });\n\n    return () => controller.abort();\n  }, [datasetName, fetchSplits]);\n\n  // Derive unique subsets\n  const subsets = Array.from(new Set(entries.map((e) => e.config)));\n\n  // Derive splits for the active subset.\n  // If dataset has >1 subset and none is selected yet, return no splits so UI\n  // doesn't auto-pick/show a split before subset is chosen.\n  const activeSubset =\n    selectedSubset ?? (subsets.length === 1 ? subsets[0] : null);\n  const filteredEntries = activeSubset\n    ? entries.filter((e) => e.config === activeSubset)\n    : [];\n  const splits = Array.from(new Set(filteredEntries.map((e) => e.split)));\n\n  return {\n    subsets,\n    splits,\n    entries,\n    hasMultipleSubsets: subsets.length > 1,\n    hasMultipleSplits: activeSubset ? splits.length > 1 : false,\n    isLoading,\n    error,\n  };\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hf-model-search.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport type { PipelineType } from \"@huggingface/hub\";\nimport { listModels, modelInfo } from \"@huggingface/hub\";\nimport { useCallback, useMemo } from \"react\";\nimport { useHfPaginatedSearch } from \"./use-hf-paginated-search\";\n\nexport interface HfModelResult {\n  id: string;\n  downloads: number;\n  likes: number;\n  totalParams?: number;\n  estimatedSizeBytes?: number;\n}\n\nconst EXCLUDED_TAGS = new Set([\n  \"gptq\",\n  \"awq\",\n  \"exl2\",\n  \"mlx\",\n  \"onnx\",\n  \"openvino\",\n  \"coreml\",\n  \"tflite\",\n  \"ctranslate2\",\n]);\n\n// Embedding / sentence-transformer models ship with onnx/openvino as additional\n// export formats — they should not be excluded by the tag check above.\nconst EMBEDDING_TAGS = new Set([\n  \"sentence-transformers\",\n  \"feature-extraction\",\n]);\n\nfunction withPopularitySort(\n  input: Parameters<typeof fetch>[0],\n  init?: Parameters<typeof fetch>[1],\n): ReturnType<typeof fetch> {\n  const rawUrl =\n    typeof input === \"string\"\n      ? input\n      : input instanceof URL\n        ? input.toString()\n        : input.url;\n  const url = new URL(rawUrl);\n\n  if (!url.searchParams.has(\"sort\")) {\n    url.searchParams.set(\"sort\", \"downloads\");\n  }\n  if (!url.searchParams.has(\"direction\")) {\n    url.searchParams.set(\"direction\", \"-1\");\n  }\n\n  return fetch(url, init);\n}\n\n/** Bytes per parameter for each dtype. */\nconst DTYPE_BYTES: Record<string, number> = {\n  F64: 8, F32: 4, F16: 2, BF16: 2,\n  I64: 8, I32: 4, I16: 2, I8: 1, U8: 1,\n  // Quantized types (4-bit)\n  NF4: 0.5, FP4: 0.5, INT4: 0.5, GPTQ: 0.5,\n};\n\nfunction estimateSizeFromDtypes(\n  params: Record<string, number> | undefined,\n): number | undefined {\n  if (!params) return undefined;\n  let total = 0;\n  for (const [dtype, count] of Object.entries(params)) {\n    const bpp = DTYPE_BYTES[dtype.toUpperCase()] ?? 2; // default BF16\n    total += count * bpp;\n  }\n  return total > 0 ? total : undefined;\n}\n\nfunction makeMapModel(excludeGguf: boolean) {\n  return (raw: unknown): HfModelResult | null => {\n    const m = raw as {\n      name: string;\n      downloads: number;\n      likes: number;\n      safetensors?: { total: number; parameters?: Record<string, number> };\n      tags?: string[];\n    };\n    const isEmbedding = m.tags?.some((t) => EMBEDDING_TAGS.has(t));\n    if (!isEmbedding && m.tags?.some((t) => EXCLUDED_TAGS.has(t))) {\n      return null;\n    }\n    if (excludeGguf && m.tags?.includes(\"gguf\")) {\n      return null;\n    }\n    return {\n      id: m.name,\n      downloads: m.downloads,\n      likes: m.likes,\n      totalParams: m.safetensors?.total,\n      estimatedSizeBytes: estimateSizeFromDtypes(m.safetensors?.parameters),\n    };\n  };\n}\n\n/** Number of unsloth results to pull up-front before yielding general results. */\nconst UNSLOTH_PREFETCH = 20;\n\n/**\n * Creates a merged async generator that yields unsloth-owned models first,\n * then general results (with deduplication).\n */\nasync function* mergedModelIterator(\n  query: string,\n  task?: PipelineType,\n  accessToken?: string,\n): AsyncGenerator<unknown> {\n  const common = {\n    additionalFields: [\"safetensors\", \"tags\"] as (\"safetensors\" | \"tags\")[],\n    fetch: withPopularitySort,\n    ...(accessToken ? 
{ credentials: { accessToken } } : {}),\n  };\n\n  // Fire both iterators immediately (parallel network requests on first pull)\n  const unslothIter = listModels({\n    search: { query, owner: \"unsloth\", ...(task ? { task } : {}) },\n    ...common,\n  });\n  const generalIter = listModels({\n    search: { query, ...(task ? { task } : {}) },\n    ...common,\n  });\n\n  // Phase 1: pull & yield unsloth models first\n  const seen = new Set<string>();\n  let count = 0;\n  for await (const model of unslothIter) {\n    const m = model as { name?: string };\n    if (m.name) seen.add(m.name);\n    yield model;\n    count++;\n    if (count >= UNSLOTH_PREFETCH) break;\n  }\n\n  // Phase 2: yield general results, skipping already-seen unsloth models\n  for await (const model of generalIter) {\n    const m = model as { name?: string };\n    if (m.name && seen.has(m.name)) continue;\n    yield model;\n  }\n}\n\n/**\n * Creates an async generator that yields priority models (fetched individually\n * via modelInfo for full metadata), then the general unsloth listing.\n */\nasync function* priorityThenListingIterator(\n  priorityIds: readonly string[],\n  task?: PipelineType,\n  accessToken?: string,\n): AsyncGenerator<unknown> {\n  const common = {\n    additionalFields: [\"safetensors\", \"tags\"] as (\"safetensors\" | \"tags\")[],\n    fetch: withPopularitySort,\n    ...(accessToken ? { credentials: { accessToken } } : {}),\n  };\n\n  // Phase 1: fetch priority models in parallel via modelInfo\n  const seen = new Set<string>();\n  const settled = await Promise.allSettled(\n    priorityIds.map((id) =>\n      modelInfo({\n        name: id,\n        additionalFields: [\"safetensors\", \"tags\"],\n        ...(accessToken ? { credentials: { accessToken } } : {}),\n      }),\n    ),\n  );\n  for (const result of settled) {\n    if (result.status === \"fulfilled\") {\n      const m = result.value as { name?: string; pipeline_tag?: string };\n      // Skip models that don't match the selected task filter\n      if (task && m.pipeline_tag && m.pipeline_tag !== task) continue;\n      if (m.name) seen.add(m.name);\n      yield result.value;\n    }\n  }\n\n  // Phase 2: yield general unsloth listing, skipping already-seen\n  const generalIter = listModels({\n    search: { owner: \"unsloth\", ...(task ? { task } : {}) },\n    ...common,\n  });\n  for await (const model of generalIter) {\n    const m = model as { name?: string };\n    if (m.name && seen.has(m.name)) continue;\n    yield model;\n  }\n}\n\nexport function useHfModelSearch(\n  query: string,\n  options?: {\n    task?: PipelineType;\n    accessToken?: string;\n    excludeGguf?: boolean;\n    priorityIds?: readonly string[];\n  },\n) {\n  const { task, accessToken, excludeGguf = false, priorityIds } = options ?? {};\n\n  const createIter = useCallback(\n    () => {\n      const trimmed = query.trim();\n      if (!trimmed) {\n        // No query → show priority models first (with full metadata), then general unsloth listing\n        if (priorityIds && priorityIds.length > 0) {\n          return priorityThenListingIterator(priorityIds, task, accessToken) as AsyncGenerator<unknown>;\n        }\n        return listModels({\n          search: { owner: \"unsloth\", ...(task ? { task } : {}) },\n          additionalFields: [\"safetensors\", \"tags\"],\n          fetch: withPopularitySort,\n          ...(accessToken ? 
{ credentials: { accessToken } } : {}),\n        }) as AsyncGenerator<unknown>;\n      }\n      // Typed query: disable task filter so explicitly searched models still appear even if HF task metadata is wrong/missing.\n      return mergedModelIterator(trimmed, undefined, accessToken) as AsyncGenerator<unknown>;\n    },\n    [query, task, accessToken, priorityIds],\n  );\n\n  const mapModel = useMemo(() => makeMapModel(excludeGguf), [excludeGguf]);\n  const search = useHfPaginatedSearch(createIter, mapModel);\n\n  // Secondary sort guarantee: unsloth models always float to the top\n  const results = useMemo(\n    () =>\n      [...search.results].sort((a, b) => {\n        const aFirst = a.id.startsWith(\"unsloth/\") ? 0 : 1;\n        const bFirst = b.id.startsWith(\"unsloth/\") ? 0 : 1;\n        return aFirst - bFirst;\n      }),\n    [search.results],\n  );\n\n  return { ...search, results };\n}\n\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hf-paginated-search.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useCallback, useEffect, useRef, useState } from \"react\";\n\ninterface HfPaginatedState<T> {\n  results: T[];\n  isLoading: boolean;\n  isLoadingMore: boolean;\n  hasMore: boolean;\n  error: string | null;\n}\n\nconst INITIAL: HfPaginatedState<never> = {\n  results: [],\n  isLoading: false,\n  isLoadingMore: false,\n  hasMore: false,\n  error: null,\n};\nconst BATCH = 20;\n\nasync function pullBatch<T>(\n  iter: AsyncGenerator<unknown>,\n  mapItem: (raw: unknown) => T | null,\n  size: number,\n) {\n  const items: T[] = [];\n  while (items.length < size) {\n    const result = await iter.next();\n    if (result.done) {\n      return { items, done: true };\n    }\n    const mapped = mapItem(result.value);\n    if (mapped !== null) {\n      items.push(mapped);\n    }\n  }\n  return { items, done: false };\n}\n\nexport function useHfPaginatedSearch<T>(\n  createIter: () => AsyncGenerator<unknown>,\n  mapItem: (raw: unknown) => T | null,\n  options?: { enabled?: boolean },\n): HfPaginatedState<T> & { fetchMore: () => void } {\n  const enabled = options?.enabled ?? true;\n  const [state, setState] = useState<HfPaginatedState<T>>(\n    INITIAL as HfPaginatedState<T>,\n  );\n  const stateRef = useRef(state);\n  useEffect(() => {\n    stateRef.current = state;\n  }, [state]);\n\n  const iterRef = useRef<AsyncGenerator<unknown> | null>(null);\n  const versionRef = useRef(0);\n\n  useEffect(() => {\n    const v = ++versionRef.current;\n    iterRef.current = null;\n\n    if (!enabled) {\n      setState(INITIAL as HfPaginatedState<T>);\n      return;\n    }\n\n    setState({\n      ...(INITIAL as HfPaginatedState<T>),\n      isLoading: true,\n    });\n\n    const iter = createIter();\n    iterRef.current = iter;\n\n    pullBatch(iter, mapItem, BATCH)\n      .then(({ items, done }) => {\n        if (versionRef.current !== v) {\n          return;\n        }\n        setState({\n          results: items,\n          isLoading: false,\n          isLoadingMore: false,\n          hasMore: !done,\n          error: null,\n        });\n      })\n      .catch((err) => {\n        if (versionRef.current !== v) {\n          return;\n        }\n        setState({\n          results: [],\n          isLoading: false,\n          isLoadingMore: false,\n          hasMore: false,\n          error: err instanceof Error ? err.message : \"Search failed\",\n        });\n      });\n  }, [createIter, mapItem, enabled]);\n\n  const fetchMore = useCallback(() => {\n    const iter = iterRef.current;\n    const { isLoading, isLoadingMore, hasMore } = stateRef.current;\n    if (!iter || isLoading || isLoadingMore || !hasMore) {\n      return;\n    }\n\n    const v = versionRef.current;\n    setState((prev) => ({ ...prev, isLoadingMore: true }));\n\n    pullBatch(iter, mapItem, BATCH)\n      .then(({ items, done }) => {\n        if (versionRef.current !== v) {\n          return;\n        }\n        setState((prev) => ({\n          ...prev,\n          results: [...prev.results, ...items],\n          isLoadingMore: false,\n          hasMore: !done,\n        }));\n      })\n      .catch(() => {\n        if (versionRef.current !== v) {\n          return;\n        }\n        setState((prev) => ({ ...prev, isLoadingMore: false, hasMore: false }));\n      });\n  }, [mapItem]);\n\n  return { ...state, fetchMore };\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-hf-token-validation.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { whoAmI } from \"@huggingface/hub\";\nimport { useCallback, useEffect, useRef, useState } from \"react\";\nimport { useDebouncedValue } from \"./use-debounced-value\";\n\nexport interface HfTokenValidationState {\n  isValid: boolean | null;\n  error: string | null;\n  isChecking: boolean;\n}\n\nconst INITIAL: HfTokenValidationState = {\n  isValid: null,\n  error: null,\n  isChecking: false,\n};\n\n/**\n * Validates the Hugging Face token by calling the whoami-v2 API.\n * Debounces the token to avoid excessive requests while typing.\n * Returns validation state: isValid (null = not checked), error message, and isChecking.\n */\nexport function useHfTokenValidation(token: string): HfTokenValidationState {\n  const debouncedToken = useDebouncedValue(\n    token.trim().replace(/^[\"']+|[\"']+$/g, \"\"),\n    500,\n  );\n  const [state, setState] = useState<HfTokenValidationState>(INITIAL);\n  const versionRef = useRef(0);\n\n  const runCheck = useCallback(async (t: string) => {\n    if (!t) {\n      setState({ isValid: null, error: null, isChecking: false });\n      return;\n    }\n\n    const v = ++versionRef.current;\n    setState((prev) => ({ ...prev, isChecking: true, error: null }));\n\n    try {\n      await whoAmI({ accessToken: t });\n      if (versionRef.current !== v) return;\n      setState({ isValid: true, error: null, isChecking: false });\n    } catch {\n      if (versionRef.current !== v) return;\n      setState({\n        isValid: false,\n        error: \"invalid or expired token\",\n        isChecking: false,\n      });\n    }\n  }, []);\n\n  useEffect(() => {\n    if (!debouncedToken) {\n      setState(INITIAL);\n      return;\n    }\n    runCheck(debouncedToken);\n  }, [debouncedToken, runCheck]);\n\n  return state;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-infinite-scroll.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useRef } from \"react\";\n\nexport function useInfiniteScroll(fetchMore: () => void, _itemCount: number) {\n  const scrollRef = useRef<HTMLDivElement>(null);\n  const sentinelRef = useRef<HTMLDivElement>(null);\n\n  useEffect(() => {\n    const el = sentinelRef.current;\n    if (!el) {\n      return;\n    }\n    const obs = new IntersectionObserver(\n      ([e]) => {\n        if (e.isIntersecting) {\n          fetchMore();\n        }\n      },\n      { threshold: 0, root: scrollRef.current },\n    );\n    obs.observe(el);\n    return () => obs.disconnect();\n  }, [fetchMore]);\n\n  return { scrollRef, sentinelRef };\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-mobile.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useEffect, useState } from \"react\";\n\nconst MOBILE_BREAKPOINT = 768;\n\nexport function useIsMobile() {\n  const [isMobile, setIsMobile] = useState<boolean | undefined>(undefined);\n\n  useEffect(() => {\n    const mql = window.matchMedia(`(max-width: ${MOBILE_BREAKPOINT - 1}px)`);\n    const onChange = () => {\n      setIsMobile(window.innerWidth < MOBILE_BREAKPOINT);\n    };\n    mql.addEventListener(\"change\", onChange);\n    setIsMobile(window.innerWidth < MOBILE_BREAKPOINT);\n    return () => mql.removeEventListener(\"change\", onChange);\n  }, []);\n\n  return !!isMobile;\n}\n"
  },
  {
    "path": "studio/frontend/src/hooks/use-recommended-model-vram.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { modelInfo } from \"@huggingface/hub\";\nimport { useEffect, useState } from \"react\";\n\n/**\n * Fetches Hugging Face model info (safetensors total param count) for a list of\n * model IDs. Used to show VRAM fit (FIT / TIGHT / OOM) for recommended/default\n * models in the chat model dropdown.\n */\nexport function useRecommendedModelVram(ids: string[]) {\n  const [paramCountById, setParamCountById] = useState<\n    Map<string, number>\n  >(new Map());\n  const [isLoading, setIsLoading] = useState(false);\n\n  const stableKey = [...ids].filter(Boolean).sort().join(\",\");\n\n  useEffect(() => {\n    const stableIds = stableKey ? stableKey.split(\",\") : [];\n    if (stableIds.length === 0) {\n      setParamCountById(new Map());\n      setIsLoading(false);\n      return;\n    }\n    let canceled = false;\n    void (async () => {\n      setIsLoading(true);\n      const next = new Map<string, number>();\n      await Promise.all(\n        stableIds.map(async (id) => {\n          if (canceled) return;\n          try {\n            const info = await modelInfo({\n              name: id,\n              additionalFields: [\"safetensors\"],\n            });\n            const raw = info as { safetensors?: { total?: number } };\n            const total = raw.safetensors?.total;\n            if (typeof total === \"number\" && total > 0) {\n              next.set(id, total);\n            }\n          } catch {\n            // Model not on HF or no safetensors; skip\n          }\n        }),\n      );\n      if (!canceled) {\n        setParamCountById(next);\n        setIsLoading(false);\n      }\n    })();\n    return () => {\n      canceled = true;\n    };\n  }, [stableKey]);\n\n  return { paramCountById, isLoading };\n}\n"
  },
  {
    "path": "studio/frontend/src/index.css",
    "content": "/* SPDX-License-Identifier: AGPL-3.0-only */\n/* Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 */\n\n@import \"tailwindcss\";\n@import \"tw-animate-css\";\n@import \"shadcn/tailwind.css\";\n@import \"streamdown/styles.css\";\n@import \"@fontsource-variable/figtree\";\n@import \"@fontsource-variable/space-grotesk\";\n@import \"@fontsource-variable/inter\";\n@import \"tw-shimmer\";\n@plugin \"@toolwind/corner-shape\";\n@source \"../node_modules/streamdown/dist/*.js\";\n\n@font-face {\n  font-family: \"Hellix\";\n  src: url(\"/fonts/Hellix-SemiBold.woff2\") format(\"woff2\"),\n    url(\"/fonts/Hellix-SemiBold.woff\") format(\"woff\");\n  font-weight: 600;\n  font-style: normal;\n  font-display: swap;\n}\n\n@custom-variant dark (&:is(.dark *));\n\n:root {\n  /* Animation timing */\n  --duration-micro: 100ms;\n  --duration-fast: 150ms;\n  --duration-normal: 200ms;\n\n  /* Easing curves (Emil Kowalski) */\n  --ease-out-quart: cubic-bezier(0.165, 0.84, 0.44, 1);\n  --ease-out-cubic: cubic-bezier(0.215, 0.61, 0.355, 1);\n\n  --background: oklch(1 0 0);\n  --foreground: oklch(0.2686 0 0);\n  --card: oklch(1 0 0);\n  --card-foreground: oklch(0.1281 0.0179 169.2764);\n  --popover: oklch(1 0 0);\n  --popover-foreground: oklch(0.1281 0.0179 169.2764);\n  --primary: oklch(0.6929 0.1396 166.5513);\n  --primary-foreground: oklch(1 0 0);\n  --secondary: oklch(0.9596 0.0275 167.8295);\n  --secondary-foreground: oklch(0.2868 0.0649 159.9823);\n  --muted: oklch(0.9702 0 0);\n  --muted-foreground: oklch(0.5486 0 0);\n  --accent: oklch(0.9596 0.0275 167.8295);\n  --accent-foreground: oklch(0.2868 0.0649 159.9823);\n  --destructive: oklch(0.6368 0.2078 25.3313);\n  --border: oklch(0.9208 0.0101 164.8536);\n  --input: oklch(0.9208 0.0101 164.8536);\n  --ring: oklch(0.6929 0.1396 166.5513);\n  --chart-1: oklch(0.6929 0.1396 166.5513);\n  --chart-2: oklch(0.694 0.1395 136.6059);\n  --chart-3: oklch(0.7014 0.1193 197.5897);\n  --chart-4: oklch(0.6926 0.1112 346.5775);\n  --chart-5: oklch(0.7497 0.1003 85.0057);\n  --radius: 1.2rem;\n  --sidebar: oklch(0.975 0 0);\n  --sidebar-foreground: oklch(0.1281 0.0179 169.2764);\n  --sidebar-primary: oklch(0.6929 0.1396 166.5513);\n  --sidebar-primary-foreground: oklch(1 0 0);\n  --sidebar-accent: oklch(0.96 0.0279 166.55);\n  --sidebar-accent-foreground: oklch(0.2868 0.0649 159.9823);\n  --sidebar-border: oklch(0.9208 0.0101 164.8536);\n  --sidebar-ring: oklch(0.6929 0.1396 166.5513);\n  --destructive-foreground: oklch(1 0 0);\n  --font-sans: \"Inter Variable\", ui-sans-serif, sans-serif, system-ui;\n  --font-heading: \"Hellix\", \"Space Grotesk Variable\", ui-sans-serif, sans-serif;\n  --font-serif: Source Serif 4, serif;\n  --font-mono: JetBrains Mono, monospace;\n  --shadow-color: hsl(0 0% 0%);\n  --shadow-opacity: 0;\n  --shadow-blur: 0px;\n  --shadow-spread: 0px;\n  --shadow-offset-x: 0px;\n  --shadow-offset-y: 0px;\n  --letter-spacing: 0em;\n  --spacing: 0.25rem;\n  /*--shadow-2xs: 0px 0px 0px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow-xs: 0px 0px 0px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow-sm:*/\n  /*    0px 0px 0px 0px hsl(0 0% 0% / 0),*/\n  /*    0px 1px 2px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow:*/\n  /*    0px 0px 0px 0px hsl(0 0% 0% / 0),*/\n  /*    0px 1px 2px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow-md:*/\n  /*    0px 0px 0px 0px hsl(0 0% 0% / 0),*/\n  /*    0px 2px 4px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow-lg:*/\n  /*    0px 0px 0px 0px hsl(0 0% 0% / 0),*/\n  /*    0px 4px 6px 0px hsl(0 0% 0% / 0);*/\n  
/*--shadow-xl:*/\n  /*    0px 0px 0px 0px hsl(0 0% 0% / 0),*/\n  /*    0px 8px 10px 0px hsl(0 0% 0% / 0);*/\n  /*--shadow-2xl: 0px 0px 0px 0px hsl(0 0% 0% / 0);*/\n  --tracking-normal: 0em;\n}\n.dark {\n  --background: oklch(0.24 0 0);\n  --foreground: oklch(0.98 0 0);\n  --card: oklch(0.28 0 0);\n  --card-foreground: oklch(0.98 0 0);\n  --popover: oklch(0.28 0 0);\n  --popover-foreground: oklch(0.98 0 0);\n  --primary: oklch(0.6929 0.1396 166.5513);\n  --primary-foreground: oklch(1 0 0);\n  --secondary: oklch(0.33 0 0);\n  --secondary-foreground: oklch(0.98 0 0);\n  --muted: oklch(0.33 0 0);\n  --muted-foreground: oklch(0.70 0 0);\n  --accent: oklch(0.33 0 0);\n  --accent-foreground: oklch(0.98 0 0);\n  --destructive: oklch(0.6368 0.2078 25.3313);\n  --border: oklch(0.38 0 0);\n  --input: oklch(0.38 0 0);\n  --ring: oklch(0.6929 0.1396 166.5513);\n  --chart-1: oklch(0.7511 0.1407 166.2284);\n  --chart-2: oklch(0.75 0.14 136.5572);\n  --chart-3: oklch(0.7554 0.1285 197.339);\n  --chart-4: oklch(0.7503 0.1199 346.7805);\n  --chart-5: oklch(0.799 0.1196 84.6633);\n  --sidebar: oklch(0.24 0 0);\n  --sidebar-foreground: oklch(0.98 0 0);\n  --sidebar-primary: oklch(0.6929 0.1396 166.5513);\n  --sidebar-primary-foreground: oklch(1 0 0);\n  --sidebar-accent: oklch(0.33 0 0);\n  --sidebar-accent-foreground: oklch(0.98 0 0);\n  --sidebar-border: oklch(0.38 0 0);\n  --sidebar-ring: oklch(0.6929 0.1396 166.5513);\n  --destructive-foreground: oklch(1 0 0);\n  --radius: 1.2rem;\n  --font-sans: Geist, ui-sans-serif, sans-serif, system-ui;\n  --font-serif: Source Serif 4, serif;\n  --font-mono: JetBrains Mono, monospace;\n  --shadow-color: hsl(0 0% 0%);\n  --shadow-opacity: 0;\n  --shadow-blur: 0px;\n  --shadow-spread: 0px;\n  --shadow-offset-x: 0px;\n  --shadow-offset-y: 0px;\n  --letter-spacing: 0em;\n  --spacing: 0.25rem;\n  --shadow-2xs: 0px 0px 0px 0px hsl(0 0% 0% / 0);\n  --shadow-xs: 0px 0px 0px 0px hsl(0 0% 0% / 0);\n  --shadow-sm: 0px 0px 0px 0px hsl(0 0% 0% / 0), 0px 1px 2px 0px\n  hsl(0 0% 0% / 0);\n  --shadow: 0px 0px 0px 0px hsl(0 0% 0% / 0), 0px 1px 2px 0px hsl(0 0% 0% / 0);\n  --shadow-md: 0px 0px 0px 0px hsl(0 0% 0% / 0), 0px 2px 4px 0px\n  hsl(0 0% 0% / 0);\n  --shadow-lg: 0px 0px 0px 0px hsl(0 0% 0% / 0), 0px 4px 6px 0px\n  hsl(0 0% 0% / 0);\n  --shadow-xl: 0px 0px 0px 0px hsl(0 0% 0% / 0), 0px 8px 10px 0px\n  hsl(0 0% 0% / 0);\n  --shadow-2xl: 0px 0px 0px 0px hsl(0 0% 0% / 0);\n}\n\n@theme inline {\n  --font-sans: \"Inter Variable\", ui-sans-serif, sans-serif, system-ui;\n  --font-heading: \"Hellix\", \"Space Grotesk Variable\", ui-sans-serif, sans-serif;\n  --color-sidebar-ring: var(--sidebar-ring);\n  --color-sidebar-border: var(--sidebar-border);\n  --color-sidebar-accent-foreground: var(--sidebar-accent-foreground);\n  --color-sidebar-accent: var(--sidebar-accent);\n  --color-sidebar-primary-foreground: var(--sidebar-primary-foreground);\n  --color-sidebar-primary: var(--sidebar-primary);\n  --color-sidebar-foreground: var(--sidebar-foreground);\n  --color-sidebar: var(--sidebar);\n  --color-chart-5: var(--chart-5);\n  --color-chart-4: var(--chart-4);\n  --color-chart-3: var(--chart-3);\n  --color-chart-2: var(--chart-2);\n  --color-chart-1: var(--chart-1);\n  --color-ring: var(--ring);\n  --color-input: var(--input);\n  --color-border: var(--border);\n  --color-destructive: var(--destructive);\n  --color-accent-foreground: var(--accent-foreground);\n  --color-accent: var(--accent);\n  --color-muted-foreground: var(--muted-foreground);\n  --color-muted: var(--muted);\n  
--color-secondary-foreground: var(--secondary-foreground);\n  --color-secondary: var(--secondary);\n  --color-primary-foreground: var(--primary-foreground);\n  --color-primary: var(--primary);\n  --color-popover-foreground: var(--popover-foreground);\n  --color-popover: var(--popover);\n  --color-card-foreground: var(--card-foreground);\n  --color-card: var(--card);\n  --color-foreground: var(--foreground);\n  --color-background: var(--background);\n  --radius-sm: calc(var(--radius) - 4px);\n  --radius-md: calc(var(--radius) - 2px);\n  --radius-lg: var(--radius);\n  --radius-xl: calc(var(--radius) + 4px);\n  --radius-2xl: calc(var(--radius) + 8px);\n  --radius-3xl: calc(var(--radius) + 12px);\n  --radius-4xl: calc(var(--radius) + 16px);\n  --font-mono: JetBrains Mono, monospace;\n  --font-serif: Source Serif 4, serif;\n  --radius: 1.2rem;\n  --tracking-tighter: calc(var(--tracking-normal) - 0.05em);\n  --tracking-tight: calc(var(--tracking-normal) - 0.025em);\n  --tracking-wide: calc(var(--tracking-normal) + 0.025em);\n  --tracking-wider: calc(var(--tracking-normal) + 0.05em);\n  --tracking-widest: calc(var(--tracking-normal) + 0.1em);\n  --tracking-normal: var(--tracking-normal);\n  /*--shadow-2xl: var(--shadow-2xl);*/\n  /*--shadow-xl: var(--shadow-xl);*/\n  /*--shadow-lg: var(--shadow-lg);*/\n  /*--shadow-md: var(--shadow-md);*/\n  /*--shadow: var(--shadow);*/\n  /*--shadow-sm: var(--shadow-sm);*/\n  /*--shadow-xs: var(--shadow-xs);*/\n  /*--shadow-2xs: var(--shadow-2xs);*/\n  /*--spacing: var(--spacing);*/\n  /*--letter-spacing: var(--letter-spacing);*/\n  /*--shadow-offset-y: var(--shadow-offset-y);*/\n  /*--shadow-offset-x: var(--shadow-offset-x);*/\n  /*--shadow-spread: var(--shadow-spread);*/\n  /*--shadow-blur: var(--shadow-blur);*/\n  /*--shadow-opacity: var(--shadow-opacity);*/\n  /*--color-shadow-color: var(--shadow-color);*/\n  --color-destructive-foreground: var(--destructive-foreground);\n  --animate-pulse: pulse var(--duration) ease-out infinite;\n  @keyframes pulse {\n    0%,\n    100% {\n      box-shadow: 0 0 0 0 var(--pulse-color);\n    }\n    50% {\n      box-shadow: 0 0 0 8px var(--pulse-color);\n    }\n  }\n  --animate-shiny-text: shiny-text 8s infinite;\n  @keyframes shiny-text {\n    0%,\n    90%,\n    100% {\n      background-position: calc(-100% - var(--shiny-width)) 0;\n    }\n    30%,\n    60% {\n      background-position: calc(100% + var(--shiny-width)) 0;\n    }\n  }\n  --animate-shine: shine var(--duration) infinite linear\n;\n  @keyframes shine {\n  0% {\n    background-position: 0% 0%;\n    }\n  50% {\n    background-position: 100% 100%;\n    }\n  to {\n    background-position: 0% 0%;\n    }\n  }}\n\n@layer base {\n  * {\n    @apply border-border outline-ring/50;\n  }\n  body {\n    @apply font-sans bg-background text-foreground;\n    letter-spacing: var(--tracking-normal);\n  }\n  html {\n    @apply font-sans;\n    scrollbar-gutter: stable;\n  }\n  body[data-scroll-locked] {\n    margin-right: 0 !important;\n  }\n  h1,\n  h2,\n  h3,\n  h4,\n  h5,\n  h6 {\n    font-family: var(--font-heading);\n  }\n  .font-medium,\n  .font-semibold,\n  .font-bold {\n    font-family: var(--font-heading);\n  }\n}\n\n@layer utilities {\n  /* Heading font utility */\n  .font-heading {\n    font-family: var(--font-heading);\n  }\n\n  /* Elevated surface shadow (use ring-* for borders) */\n  .shadow-border {\n    --tw-shadow: 0 4px 16px rgba(0, 0, 0, 0.1);\n    --tw-shadow-colored: 0 4px 16px var(--tw-shadow-color);\n    box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000),\n     
 var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow);\n  }\n  .dark .shadow-border {\n    --tw-shadow: 0 4px 16px rgba(0, 0, 0, 0.3);\n  }\n\n  [data-streamdown=\"unordered-list\"] {\n    list-style-type: disc;\n    list-style-position: outside;\n    padding-left: 1.25rem;\n    margin-block: 0.5rem;\n  }\n\n  [data-streamdown=\"ordered-list\"] {\n    list-style-type: decimal;\n    list-style-position: outside;\n    padding-left: 1.25rem;\n    margin-block: 0.5rem;\n  }\n\n  [data-streamdown=\"list-item\"] {\n    display: list-item;\n  }\n\n  /* Flatten code blocks: single border, language label, then code directly */\n  [data-streamdown=\"code-block-body\"] {\n    border: none !important;\n    border-radius: 0 !important;\n    background: transparent !important;\n    padding: 0 !important;\n  }\n  [data-streamdown=\"code-block\"] {\n    gap: 0;\n    padding: 0.5rem;\n  }\n  [data-streamdown=\"code-block-header\"] {\n    padding-left: 0.75rem;\n  }\n}\n\n/* Minimal scrollbar — thumb only, no track */\n* {\n  scrollbar-width: thin;\n  scrollbar-color: transparent transparent;\n}\n*:hover {\n  scrollbar-color: oklch(0.6 0 0 / 0.3) transparent;\n}\n.dark *:hover {\n  scrollbar-color: oklch(0.5 0 0 / 0.35) transparent;\n}\n\n/* Webkit (Chrome, Safari, Edge) */\n::-webkit-scrollbar {\n  width: 6px;\n  height: 6px;\n}\n::-webkit-scrollbar-track {\n  background: transparent;\n}\n::-webkit-scrollbar-thumb {\n  background: transparent;\n  border-radius: 9999px;\n}\n*:hover::-webkit-scrollbar-thumb {\n  background: oklch(0.6 0 0 / 0.3);\n}\n.dark *:hover::-webkit-scrollbar-thumb {\n  background: oklch(0.5 0 0 / 0.35);\n}\n\n/*---break---*/\n\n@layer base {\n  * {\n    @apply border-border outline-ring/50;\n  }\n  body {\n    @apply bg-background text-foreground;\n  }\n}\n\n::view-transition-old(root), ::view-transition-new(root) {\n    animation: none;\n    mix-blend-mode: normal;\n}"
  },
  {
    "path": "studio/frontend/src/main.tsx",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { StrictMode } from \"react\";\nimport { createRoot } from \"react-dom/client\";\n\nimport \"./index.css\";\nimport { fetchDeviceType } from \"./config/env\";\nimport { App } from \"./app/app\";\n\nconst globalCrypto = globalThis.crypto as Crypto | undefined;\n\nif (globalCrypto && typeof globalCrypto.randomUUID !== \"function\") {\n  // Some envs ship `crypto` but no `randomUUID()` (or a non-function stub).\n  // Provide a best-effort v4 UUID using `getRandomValues` when available.\n  const cryptoRef = globalCrypto;\n\n  function getRandomByte(): number {\n    if (typeof cryptoRef.getRandomValues === \"function\") {\n      return cryptoRef.getRandomValues(new Uint8Array(1))[0];\n    }\n    return Math.floor(Math.random() * 256);\n  }\n\n  cryptoRef.randomUUID = (() =>\n    \"10000000-1000-4000-8000-100000000000\".replace(/[018]/g, (c) =>\n      (+c ^ (getRandomByte() & (15 >> (+c / 4)))).toString(16),\n    )) as Crypto[\"randomUUID\"];\n}\n\nconst rootElement = document.getElementById(\"root\");\nif (!rootElement) {\n  throw new Error(\"Root element not found\");\n}\n\nfetchDeviceType().then(() => {\n  createRoot(rootElement).render(\n    <StrictMode>\n      <App />\n    </StrictMode>,\n  );\n});\n"
  },
  {
    "path": "studio/frontend/src/shared/toast.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { toast } from \"sonner\";\n\nexport function toastSuccess(message: string): void {\n  toast.success(message);\n}\n\nexport function toastError(message: string, description?: string): void {\n  toast.error(message, {\n    description,\n  });\n}\n"
  },
  {
    "path": "studio/frontend/src/speech-recognition.d.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n/**\n * Minimal Web Speech API (Speech Recognition) types for browsers that support it.\n * Full types: @types/dom-speech-recognition\n */\ninterface SpeechRecognitionResultList {\n  readonly length: number;\n  item(index: number): SpeechRecognitionResult;\n  [index: number]: SpeechRecognitionResult;\n}\n\ninterface SpeechRecognitionResult {\n  readonly length: number;\n  readonly isFinal: boolean;\n  item(index: number): SpeechRecognitionAlternative;\n  [index: number]: SpeechRecognitionAlternative;\n}\n\ninterface SpeechRecognitionAlternative {\n  readonly transcript: string;\n  readonly confidence: number;\n}\n\ninterface SpeechRecognitionEvent extends Event {\n  readonly resultIndex: number;\n  readonly results: SpeechRecognitionResultList;\n}\n\ninterface SpeechRecognition extends EventTarget {\n  continuous: boolean;\n  interimResults: boolean;\n  lang: string;\n  onresult: ((event: SpeechRecognitionEvent) => void) | null;\n  onerror: ((event: Event) => void) | null;\n  onend: (() => void) | null;\n  start(): void;\n  stop(): void;\n  abort(): void;\n}\n\ninterface SpeechRecognitionConstructor {\n  new (): SpeechRecognition;\n}\n\ninterface Window {\n  SpeechRecognition?: SpeechRecognitionConstructor;\n  webkitSpeechRecognition?: SpeechRecognitionConstructor;\n}\n\ndeclare var SpeechRecognition: SpeechRecognitionConstructor | undefined;\n"
  },
  {
    "path": "studio/frontend/src/stores/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n// Global stores\nexport {};\n"
  },
  {
    "path": "studio/frontend/src/stores/training.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport { useTrainingConfigStore } from \"@/features/training\";\n\nexport const useWizardStore = useTrainingConfigStore;\nexport { useTrainingConfigStore };\n"
  },
  {
    "path": "studio/frontend/src/types/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n// Shared types\nexport type {};\n"
  },
  {
    "path": "studio/frontend/src/types/training.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport type ModelType = \"vision\" | \"audio\" | \"embeddings\" | \"text\";\nexport type TrainingMethod = \"qlora\" | \"lora\" | \"full\";\n\nexport function isAdapterMethod(method: TrainingMethod): boolean {\n  return method === \"lora\" || method === \"qlora\";\n}\nexport type StepNumber = 1 | 2 | 3 | 4 | 5;\nexport type DatasetSource = \"huggingface\" | \"upload\";\nexport type DatasetFormat = \"auto\" | \"alpaca\" | \"chatml\" | \"sharegpt\";\nexport type GradientCheckpointing = \"none\" | \"true\" | \"unsloth\";\n\nexport interface WizardState {\n  currentStep: StepNumber;\n  modelType: ModelType | null;\n  selectedModel: string | null;\n  trainingMethod: TrainingMethod;\n  hfToken: string;\n  datasetSource: DatasetSource;\n  datasetFormat: DatasetFormat;\n  dataset: string | null;\n  datasetSubset: string | null;\n  datasetSplit: string | null;\n  uploadedFile: string | null;\n  epochs: number;\n  contextLength: number;\n  learningRate: number;\n  loraRank: number;\n  loraAlpha: number;\n  loraDropout: number;\n  loraVariant: \"lora\" | \"rslora\" | \"loftq\";\n  batchSize: number;\n  gradientAccumulation: number;\n  weightDecay: number;\n  warmupSteps: number;\n  maxSteps: number;\n  saveSteps: number;\n  packing: boolean;\n  trainOnCompletions: boolean;\n  gradientCheckpointing: GradientCheckpointing;\n  randomSeed: number;\n  enableWandb: boolean;\n  wandbToken: string;\n  wandbProject: string;\n  enableTensorboard: boolean;\n  tensorboardDir: string;\n  logFrequency: number;\n  finetuneVisionLayers: boolean;\n  finetuneLanguageLayers: boolean;\n  finetuneAttentionModules: boolean;\n  finetuneMLPModules: boolean;\n  targetModules: string[];\n}\n\nexport interface WizardActions {\n  setStep: (step: StepNumber) => void;\n  nextStep: () => void;\n  prevStep: () => void;\n  setModelType: (type: ModelType) => void;\n  setSelectedModel: (model: string | null) => void;\n  setTrainingMethod: (method: TrainingMethod) => void;\n  setHfToken: (token: string) => void;\n  setDatasetSource: (source: DatasetSource) => void;\n  setDatasetFormat: (format: DatasetFormat) => void;\n  setDataset: (dataset: string | null) => void;\n  setDatasetSubset: (subset: string | null) => void;\n  setDatasetSplit: (split: string | null) => void;\n  setUploadedFile: (file: string | null) => void;\n  setEpochs: (epochs: number) => void;\n  setContextLength: (length: number) => void;\n  setLearningRate: (rate: number) => void;\n  setLoraRank: (rank: number) => void;\n  setLoraAlpha: (alpha: number) => void;\n  setLoraDropout: (dropout: number) => void;\n  setLoraVariant: (v: \"lora\" | \"rslora\" | \"loftq\") => void;\n  setBatchSize: (v: number) => void;\n  setGradientAccumulation: (v: number) => void;\n  setWeightDecay: (v: number) => void;\n  setWarmupSteps: (v: number) => void;\n  setMaxSteps: (v: number) => void;\n  setSaveSteps: (v: number) => void;\n  setPacking: (v: boolean) => void;\n  setTrainOnCompletions: (v: boolean) => void;\n  setGradientCheckpointing: (v: GradientCheckpointing) => void;\n  setRandomSeed: (v: number) => void;\n  setEnableWandb: (v: boolean) => void;\n  setWandbToken: (v: string) => void;\n  setWandbProject: (v: string) => void;\n  setEnableTensorboard: (v: boolean) => void;\n  setTensorboardDir: (v: string) => void;\n  setLogFrequency: (v: number) => void;\n  setFinetuneVisionLayers: (v: boolean) => void;\n  
setFinetuneLanguageLayers: (v: boolean) => void;\n  setFinetuneAttentionModules: (v: boolean) => void;\n  setFinetuneMLPModules: (v: boolean) => void;\n  setTargetModules: (v: string[]) => void;\n  canProceed: () => boolean;\n  reset: () => void;\n}\n\nexport interface StepConfig {\n  number: StepNumber;\n  title: string;\n  subtitle: string;\n  description: string;\n}\n"
  },
  {
    "path": "studio/frontend/src/utils/index.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n// Utility functions\nexport { normalizeNonEmptyName } from \"./strings\";\n"
  },
  {
    "path": "studio/frontend/src/utils/strings.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nexport function normalizeNonEmptyName(\n  value: string,\n  fallback = \"Unnamed\",\n): string {\n  const trimmed = value.trim();\n  return trimmed.length > 0 ? trimmed : fallback;\n}\n\n"
  },
  {
    "path": "studio/frontend/tsconfig.app.json",
    "content": "{\r\n  \"compilerOptions\": {\r\n    \"tsBuildInfoFile\": \"./node_modules/.tmp/tsconfig.app.tsbuildinfo\",\r\n    \"target\": \"ES2022\",\r\n    \"useDefineForClassFields\": true,\r\n    \"lib\": [\"ES2022\", \"DOM\", \"DOM.Iterable\"],\r\n    \"module\": \"ESNext\",\r\n    \"types\": [\"vite/client\", \"react\", \"react-dom\"],\r\n    \"skipLibCheck\": true,\r\n\r\n    /* Bundler mode */\r\n    \"moduleResolution\": \"bundler\",\r\n    \"allowImportingTsExtensions\": true,\r\n    \"verbatimModuleSyntax\": true,\r\n    \"moduleDetection\": \"force\",\r\n    \"noEmit\": true,\r\n    \"jsx\": \"react-jsx\",\r\n\r\n    /* Linting */\r\n    \"strict\": true,\r\n    \"noUnusedLocals\": false,\n    \"noUnusedParameters\": true,\r\n    \"erasableSyntaxOnly\": true,\r\n    \"noFallthroughCasesInSwitch\": true,\r\n    \"noUncheckedSideEffectImports\": true,\r\n    \"baseUrl\": \".\",\r\n    \"paths\": {\r\n      \"@/*\": [\"./src/*\"]\r\n    }\r\n  },\r\n  \"include\": [\"src\"]\r\n}\r\n"
  },
  {
    "path": "studio/frontend/tsconfig.json",
    "content": "{\r\n  \"files\": [],\r\n  \"references\": [\r\n    { \"path\": \"./tsconfig.app.json\" },\r\n    { \"path\": \"./tsconfig.node.json\" }\r\n  ],\r\n  \"compilerOptions\": {\r\n    \"baseUrl\": \".\",\r\n    \"paths\": {\r\n      \"@/*\": [\"./src/*\"]\r\n    }\r\n  }\r\n}\r\n"
  },
  {
    "path": "studio/frontend/tsconfig.node.json",
    "content": "{\r\n  \"compilerOptions\": {\r\n    \"tsBuildInfoFile\": \"./node_modules/.tmp/tsconfig.node.tsbuildinfo\",\r\n    \"target\": \"ES2023\",\r\n    \"lib\": [\"ES2023\"],\r\n    \"module\": \"ESNext\",\r\n    \"types\": [\"node\"],\r\n    \"skipLibCheck\": true,\r\n\r\n    /* Bundler mode */\r\n    \"moduleResolution\": \"bundler\",\r\n    \"allowImportingTsExtensions\": true,\r\n    \"verbatimModuleSyntax\": true,\r\n    \"moduleDetection\": \"force\",\r\n    \"noEmit\": true,\r\n\r\n    /* Linting */\r\n    \"strict\": true,\r\n    \"noUnusedLocals\": true,\r\n    \"noUnusedParameters\": true,\r\n    \"erasableSyntaxOnly\": true,\r\n    \"noFallthroughCasesInSwitch\": true,\r\n    \"noUncheckedSideEffectImports\": true\r\n  },\r\n  \"include\": [\"vite.config.ts\"]\r\n}\r\n"
  },
  {
    "path": "studio/frontend/vite.config.ts",
    "content": "// SPDX-License-Identifier: AGPL-3.0-only\n// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport path from \"node:path\";\nimport tailwindcss from \"@tailwindcss/vite\";\nimport react from \"@vitejs/plugin-react\";\nimport { defineConfig } from \"vite\";\n\n// https://vite.dev/config/\nexport default defineConfig({\n  plugins: [react(), tailwindcss()],\n  optimizeDeps: {\n    include: [\"@dagrejs/dagre\", \"@dagrejs/graphlib\"],\n  },\n  server: {\n    host: \"0.0.0.0\",\n    allowedHosts: true,\n    proxy: {\n      \"/api\": {\n        target: \"http://127.0.0.1:8888\",\n        changeOrigin: true,\n      },\n      \"/v1\": {\n        target: \"http://127.0.0.1:8888\",\n        changeOrigin: true,\n      },\n      \"/seed/inspect\": {\n        target: \"http://127.0.0.1:8004\",\n        changeOrigin: true,\n      },\n      \"/seed/preview\": {\n        target: \"http://127.0.0.1:8004\",\n        changeOrigin: true,\n      },\n      \"/preview\": {\n        target: \"http://127.0.0.1:8004\",\n        changeOrigin: true,\n      },\n      \"/validate\": {\n        target: \"http://127.0.0.1:8004\",\n        changeOrigin: true,\n      },\n      \"/tools\": {\n        target: \"http://127.0.0.1:8004\",\n        changeOrigin: true,\n      },\n    },\n  },\n  resolve: {\n    alias: {\n      \"@\": path.resolve(__dirname, \"./src\"),\n      \"@dagrejs/dagre\": path.resolve(\n        __dirname,\n        \"./node_modules/@dagrejs/dagre/dist/dagre.cjs.js\",\n      ),\n    },\n  },\n  build: {\n    commonjsOptions: {\n      include: [/node_modules/, /@dagrejs\\/dagre/, /@dagrejs\\/graphlib/],\n    },\n  },\n});\n"
  },
  {
    "path": "studio/install_python_stack.py",
    "content": "#!/usr/bin/env python3\n\n# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Cross-platform Python dependency installer for Unsloth Studio.\n\nCalled by both setup.sh (Linux / WSL) and setup.ps1 (Windows) after the\nvirtual environment is already activated.  Expects `pip` and `python` on\nPATH to point at the venv.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nimport urllib.request\nfrom pathlib import Path\n\nIS_WINDOWS = sys.platform == \"win32\"\n\n# ── Verbosity control ──────────────────────────────────────────────────────────\n# By default the installer shows a minimal progress bar (one line, in-place).\n# Set UNSLOTH_VERBOSE=1 in the environment to restore full per-step output:\n#   Linux/Mac:  UNSLOTH_VERBOSE=1 ./studio/setup.sh\n#   Windows:    $env:UNSLOTH_VERBOSE=\"1\" ; .\\studio\\setup.ps1\nVERBOSE: bool = os.environ.get(\"UNSLOTH_VERBOSE\", \"0\") == \"1\"\n\n# Progress bar state — updated by _progress() as each install step runs.\n# _TOTAL counts: pip-upgrade + 7 shared steps + triton (non-Windows) + local-plugin + finalize\n# Update _TOTAL here if you add or remove install steps in install_python_stack().\n_STEP: int = 0\n_TOTAL: int = 0  # set at runtime in install_python_stack() based on platform\n\n# ── Paths ──────────────────────────────────────────────────────────────\nSCRIPT_DIR = Path(__file__).resolve().parent\nREQ_ROOT = SCRIPT_DIR / \"backend\" / \"requirements\"\nSINGLE_ENV = REQ_ROOT / \"single-env\"\nCONSTRAINTS = SINGLE_ENV / \"constraints.txt\"\nLOCAL_DD_UNSTRUCTURED_PLUGIN = (\n    SCRIPT_DIR / \"backend\" / \"plugins\" / \"data-designer-unstructured-seed\"\n)\n\n# ── Color support ──────────────────────────────────────────────────────\n\n\ndef _enable_colors() -> bool:\n    \"\"\"Try to enable ANSI color support. Returns True if available.\"\"\"\n    if not hasattr(sys.stdout, \"fileno\"):\n        return False\n    try:\n        if not os.isatty(sys.stdout.fileno()):\n            return False\n    except Exception:\n        return False\n    if IS_WINDOWS:\n        try:\n            import ctypes\n\n            kernel32 = ctypes.windll.kernel32\n            # Enable ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x0004) on stdout\n            handle = kernel32.GetStdHandle(-11)  # STD_OUTPUT_HANDLE\n            mode = ctypes.c_ulong()\n            kernel32.GetConsoleMode(handle, ctypes.byref(mode))\n            kernel32.SetConsoleMode(handle, mode.value | 0x0004)\n            return True\n        except Exception:\n            return False\n    return True  # Unix terminals support ANSI by default\n\n\n# Colors disabled — Colab and most CI runners render ANSI fine, but plain output\n# is cleaner in the notebook cell. 
Re-enable by setting _HAS_COLOR = _enable_colors()\n_HAS_COLOR = False\n\n\ndef _green(msg: str) -> str:\n    return f\"\\033[92m{msg}\\033[0m\" if _HAS_COLOR else msg\n\n\ndef _cyan(msg: str) -> str:\n    return f\"\\033[96m{msg}\\033[0m\" if _HAS_COLOR else msg\n\n\ndef _red(msg: str) -> str:\n    return f\"\\033[91m{msg}\\033[0m\" if _HAS_COLOR else msg\n\n\ndef _progress(label: str) -> None:\n    \"\"\"Print an in-place progress bar for the current install step.\n\n    Uses only stdlib (sys.stdout) — no extra packages required.\n    In VERBOSE mode this is a no-op; per-step labels are printed by run() instead.\n    \"\"\"\n    global _STEP\n    _STEP += 1\n    if VERBOSE:\n        return  # verbose mode: run() already printed the label\n    width = 20\n    filled = int(width * _STEP / _TOTAL)\n    bar = \"=\" * filled + \"-\" * (width - filled)\n    end = \"\\n\" if _STEP >= _TOTAL else \"\"  # newline only on the final step\n    sys.stdout.write(f\"\\r[{bar}] {_STEP:2}/{_TOTAL}  {label:<40}{end}\")\n    sys.stdout.flush()\n\n\ndef run(\n    label: str, cmd: list[str], *, quiet: bool = True\n) -> subprocess.CompletedProcess[bytes]:\n    \"\"\"Run a command; on failure print output and exit.\"\"\"\n    if VERBOSE:\n        print(f\"   {label}...\")\n    result = subprocess.run(\n        cmd,\n        stdout = subprocess.PIPE if quiet else None,\n        stderr = subprocess.STDOUT if quiet else None,\n    )\n    if result.returncode != 0:\n        print(_red(f\"❌ {label} failed (exit code {result.returncode}):\"))\n        if result.stdout:\n            print(result.stdout.decode(errors = \"replace\"))\n        sys.exit(result.returncode)\n    return result\n\n\n# Packages to skip on Windows (require special build steps)\nWINDOWS_SKIP_PACKAGES = {\"open_spiel\", \"triton_kernels\"}\n\n# ── uv bootstrap ──────────────────────────────────────────────────────\n\nUSE_UV = False  # Set by _bootstrap_uv() at the start of install_python_stack()\nUV_NEEDS_SYSTEM = False  # Set by _bootstrap_uv() via probe\n\n\ndef _bootstrap_uv() -> bool:\n    \"\"\"Check if uv is available and probe whether --system is needed.\"\"\"\n    global UV_NEEDS_SYSTEM\n    if not shutil.which(\"uv\"):\n        return False\n    # Probe: try a dry-run install targeting the current Python explicitly.\n    # Without --python, uv can ignore the activated venv on some platforms.\n    probe = subprocess.run(\n        [\"uv\", \"pip\", \"install\", \"--dry-run\", \"--python\", sys.executable, \"pip\"],\n        stdout = subprocess.PIPE,\n        stderr = subprocess.STDOUT,\n    )\n    if probe.returncode != 0:\n        # Retry with --system (some envs need it when uv can't find a venv)\n        probe_sys = subprocess.run(\n            [\"uv\", \"pip\", \"install\", \"--dry-run\", \"--system\", \"pip\"],\n            stdout = subprocess.PIPE,\n            stderr = subprocess.STDOUT,\n        )\n        if probe_sys.returncode != 0:\n            return False  # uv is broken, fall back to pip\n        UV_NEEDS_SYSTEM = True\n    return True\n\n\ndef _filter_requirements(req: Path, skip: set[str]) -> Path:\n    \"\"\"Return a temp copy of a requirements file with certain packages removed.\"\"\"\n    lines = req.read_text(encoding = \"utf-8\").splitlines(keepends = True)\n    filtered = [\n        line\n        for line in lines\n        if not any(line.strip().lower().startswith(pkg) for pkg in skip)\n    ]\n    tmp = tempfile.NamedTemporaryFile(\n        mode = \"w\",\n        suffix = \".txt\",\n        delete = False,\n        
encoding = \"utf-8\",\n    )\n    tmp.writelines(filtered)\n    tmp.close()\n    return Path(tmp.name)\n\n\ndef _translate_pip_args_for_uv(args: tuple[str, ...]) -> list[str]:\n    \"\"\"Translate pip flags to their uv equivalents.\"\"\"\n    translated: list[str] = []\n    for arg in args:\n        if arg == \"--no-cache-dir\":\n            continue  # uv cache is fast; drop this flag\n        elif arg == \"--force-reinstall\":\n            translated.append(\"--reinstall\")\n        else:\n            translated.append(arg)\n    return translated\n\n\ndef _build_pip_cmd(args: tuple[str, ...]) -> list[str]:\n    \"\"\"Build a standard pip install command.\"\"\"\n    cmd = [sys.executable, \"-m\", \"pip\", \"install\"]\n    cmd.extend(args)\n    return cmd\n\n\ndef _build_uv_cmd(args: tuple[str, ...]) -> list[str]:\n    \"\"\"Build a uv pip install command with translated flags.\"\"\"\n    cmd = [\"uv\", \"pip\", \"install\"]\n    if UV_NEEDS_SYSTEM:\n        cmd.append(\"--system\")\n    # Always pass --python so uv targets the correct environment.\n    # Without this, uv can ignore an activated venv and install into\n    # the system Python (observed on Colab and similar environments).\n    cmd.extend([\"--python\", sys.executable])\n    cmd.extend(_translate_pip_args_for_uv(args))\n    cmd.append(\"--torch-backend=auto\")\n    return cmd\n\n\ndef pip_install(\n    label: str,\n    *args: str,\n    req: Path | None = None,\n    constrain: bool = True,\n) -> None:\n    \"\"\"Build and run a pip install command (uses uv when available, falls back to pip).\"\"\"\n    constraint_args: list[str] = []\n    if constrain and CONSTRAINTS.is_file():\n        constraint_args = [\"-c\", str(CONSTRAINTS)]\n\n    actual_req = req\n    if req is not None and IS_WINDOWS and WINDOWS_SKIP_PACKAGES:\n        actual_req = _filter_requirements(req, WINDOWS_SKIP_PACKAGES)\n    req_args: list[str] = []\n    if actual_req is not None:\n        req_args = [\"-r\", str(actual_req)]\n\n    try:\n        if USE_UV:\n            uv_cmd = _build_uv_cmd(args) + constraint_args + req_args\n            if VERBOSE:\n                print(f\"   {label}...\")\n            result = subprocess.run(\n                uv_cmd,\n                stdout = subprocess.PIPE,\n                stderr = subprocess.STDOUT,\n            )\n            if result.returncode == 0:\n                return\n            print(_red(f\"   uv failed, falling back to pip...\"))\n            if result.stdout:\n                print(result.stdout.decode(errors = \"replace\"))\n\n        pip_cmd = _build_pip_cmd(args) + constraint_args + req_args\n        run(f\"{label} (pip)\" if USE_UV else label, pip_cmd)\n    finally:\n        if actual_req is not None and actual_req != req:\n            actual_req.unlink(missing_ok = True)\n\n\ndef download_file(url: str, dest: Path) -> None:\n    \"\"\"Download a file using urllib (no curl dependency).\"\"\"\n    urllib.request.urlretrieve(url, dest)\n\n\ndef patch_package_file(package_name: str, relative_path: str, url: str) -> None:\n    \"\"\"Download a file from url and overwrite a file inside an installed package.\"\"\"\n    result = subprocess.run(\n        [sys.executable, \"-m\", \"pip\", \"show\", package_name],\n        capture_output = True,\n        text = True,\n    )\n    if result.returncode != 0:\n        print(_red(f\"   ⚠️  Could not find package {package_name}, skipping patch\"))\n        return\n\n    location = None\n    for line in result.stdout.splitlines():\n        if 
line.lower().startswith(\"location:\"):\n            location = line.split(\":\", 1)[1].strip()\n            break\n\n    if not location:\n        print(_red(f\"   ⚠️  Could not determine location of {package_name}\"))\n        return\n\n    dest = Path(location) / relative_path\n    print(_cyan(f\"   Patching {dest.name} in {package_name}...\"))\n    download_file(url, dest)\n\n\n# ── Main install sequence ─────────────────────────────────────────────\n\n\ndef install_python_stack() -> int:\n    global USE_UV, _STEP, _TOTAL\n    _STEP = 0\n    _TOTAL = 10 if IS_WINDOWS else 11\n\n    # 1. Upgrade pip (needed even with uv as fallback and for bootstrapping)\n    _progress(\"pip upgrade\")\n    run(\"Upgrading pip\", [sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"])\n\n    # Try to use uv for faster installs\n    USE_UV = _bootstrap_uv()\n\n    # 2. Core packages: unsloth-zoo + unsloth\n    _progress(\"base packages\")\n    pip_install(\n        \"Installing base packages\",\n        \"--no-cache-dir\",\n        req = REQ_ROOT / \"base.txt\",\n    )\n\n    # 3. Extra dependencies\n    _progress(\"unsloth extras\")\n    pip_install(\n        \"Installing additional unsloth dependencies\",\n        \"--no-cache-dir\",\n        req = REQ_ROOT / \"extras.txt\",\n    )\n\n    # 3b. Extra dependencies (no-deps) — audio model support etc.\n    _progress(\"extra codecs\")\n    pip_install(\n        \"Installing extras (no-deps)\",\n        \"--no-deps\",\n        \"--no-cache-dir\",\n        req = REQ_ROOT / \"extras-no-deps.txt\",\n    )\n\n    # 4. Overrides (torchao, transformers) — force-reinstall\n    _progress(\"dependency overrides\")\n    pip_install(\n        \"Installing dependency overrides\",\n        \"--force-reinstall\",\n        \"--no-cache-dir\",\n        req = REQ_ROOT / \"overrides.txt\",\n    )\n\n    # 5. Triton kernels (no-deps, from source)\n    if not IS_WINDOWS:\n        _progress(\"triton kernels\")\n        pip_install(\n            \"Installing triton kernels\",\n            \"--no-deps\",\n            \"--no-cache-dir\",\n            req = REQ_ROOT / \"triton-kernels.txt\",\n            constrain = False,\n        )\n\n    # # 6. Patch: override llama_cpp.py with fix from unsloth-zoo  feature/llama-cpp-windows-support branch\n    # patch_package_file(\n    #     \"unsloth-zoo\",\n    #     os.path.join(\"unsloth_zoo\", \"llama_cpp.py\"),\n    #     \"https://raw.githubusercontent.com/unslothai/unsloth-zoo/refs/heads/main/unsloth_zoo/llama_cpp.py\",\n    # )\n\n    # # 7a. Patch: override vision.py with fix from unsloth PR #4091\n    # patch_package_file(\n    #     \"unsloth\",\n    #     os.path.join(\"unsloth\", \"models\", \"vision.py\"),\n    #     \"https://raw.githubusercontent.com/unslothai/unsloth/80e0108a684c882965a02a8ed851e3473c1145ab/unsloth/models/vision.py\",\n    # )\n\n    # # 7b. Patch : override save.py with fix from feature/llama-cpp-windows-support\n    # patch_package_file(\n    #     \"unsloth\",\n    #     os.path.join(\"unsloth\", \"save.py\"),\n    #     \"https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/unsloth/save.py\",\n    # )\n\n    # 8. Studio dependencies\n    _progress(\"studio deps\")\n    pip_install(\n        \"Installing studio dependencies\",\n        \"--no-cache-dir\",\n        req = REQ_ROOT / \"studio.txt\",\n    )\n\n    # 9. 
Data-designer dependencies\n    _progress(\"data designer deps\")\n    pip_install(\n        \"Installing data-designer base dependencies\",\n        \"--no-cache-dir\",\n        req = SINGLE_ENV / \"data-designer-deps.txt\",\n    )\n\n    # 10. Data-designer packages (no-deps to avoid conflicts)\n    _progress(\"data designer\")\n    pip_install(\n        \"Installing data-designer\",\n        \"--no-cache-dir\",\n        \"--no-deps\",\n        req = SINGLE_ENV / \"data-designer.txt\",\n    )\n\n    # 11. Local Data Designer seed plugin\n    if not LOCAL_DD_UNSTRUCTURED_PLUGIN.is_dir():\n        print(\n            _red(\n                f\"❌ Missing local plugin directory: {LOCAL_DD_UNSTRUCTURED_PLUGIN}\",\n            ),\n        )\n        return 1\n    _progress(\"local plugin\")\n    pip_install(\n        \"Installing local data-designer unstructured plugin\",\n        \"--no-cache-dir\",\n        \"--no-deps\",\n        str(LOCAL_DD_UNSTRUCTURED_PLUGIN),\n        constrain = False,\n    )\n\n    # 12. Patch metadata for single-env compatibility\n    _progress(\"finalizing\")\n    run(\n        \"Patching single-env metadata\",\n        [sys.executable, str(SINGLE_ENV / \"patch_metadata.py\")],\n    )\n\n    # 13. Final check (silent; third-party conflicts are expected)\n    subprocess.run(\n        [sys.executable, \"-m\", \"pip\", \"check\"],\n        stdout = subprocess.DEVNULL,\n        stderr = subprocess.DEVNULL,\n    )\n\n    print(_green(\"✅ Python dependencies installed\"))\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(install_python_stack())\n"
  },
  {
    "path": "studio/setup.bat",
    "content": "@echo off\nREM SPDX-License-Identifier: AGPL-3.0-only\nREM Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\npowershell -ExecutionPolicy Bypass -File \"%~dp0setup.ps1\" %*\n"
  },
  {
    "path": "studio/setup.ps1",
    "content": "#Requires -Version 5.1\n# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n<#\n.SYNOPSIS\n    Full environment setup for Unsloth Studio on Windows (bundled version).\n.DESCRIPTION\n    Always installs Node.js if needed. When running from pip install:\n    skips frontend build (already bundled). When running from git repo:\n    full setup including frontend build.\n    Supports NVIDIA GPU (full training + inference) and CPU-only (GGUF chat mode).\n.NOTES\n    Usage: powershell -ExecutionPolicy Bypass -File setup.ps1\n#>\n\n$ErrorActionPreference = \"Stop\"\n$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path\n$PackageDir = Split-Path -Parent $ScriptDir\n\n# Detect if running from pip install (no frontend/ dir in studio)\n$FrontendDir = Join-Path $ScriptDir \"frontend\"\n$OxcValidatorDir = Join-Path $ScriptDir \"backend\\core\\data_recipe\\oxc-validator\"\n$IsPipInstall = -not (Test-Path $FrontendDir)\n\n# ─────────────────────────────────────────────\n# Helper functions\n# ─────────────────────────────────────────────\n\n# Reload ALL environment variables from registry.\n# Picks up changes made by installers (winget, msi, etc.) including\n# Path, CUDA_PATH, CUDA_PATH_V*, and any other vars they set.\nfunction Refresh-Environment {\n    foreach ($level in @('Machine', 'User')) {\n        $vars = [System.Environment]::GetEnvironmentVariables($level)\n        foreach ($key in $vars.Keys) {\n            if ($key -eq 'Path') { continue }\n            Set-Item -Path \"Env:$key\" -Value $vars[$key] -ErrorAction SilentlyContinue\n        }\n    }\n    $machinePath = [System.Environment]::GetEnvironmentVariable('Path', 'Machine')\n    $userPath = [System.Environment]::GetEnvironmentVariable('Path', 'User')\n    $env:Path = \"$machinePath;$userPath\"\n}\n\n# Find nvcc on PATH, CUDA_PATH, or standard toolkit dirs.\n# Returns the path to nvcc.exe, or $null if not found.\nfunction Find-Nvcc {\n    param([string]$MaxVersion = \"\")\n\n    # If MaxVersion is set, we need to find a toolkit <= that version.\n    # CUDA toolkits install side-by-side under C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\vX.Y\\\n\n    $toolkitBase = 'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA'\n\n    if ($MaxVersion -and (Test-Path $toolkitBase)) {\n        $drMajor = [int]$MaxVersion.Split('.')[0]\n        $drMinor = [int]$MaxVersion.Split('.')[1]\n\n        # Get all installed CUDA dirs, sorted descending (highest first)\n        $cudaDirs = Get-ChildItem -Directory $toolkitBase | Where-Object {\n            $_.Name -match '^v(\\d+)\\.(\\d+)'\n        } | Sort-Object { [version]($_.Name -replace '^v','') } -Descending\n\n        foreach ($dir in $cudaDirs) {\n            if ($dir.Name -match '^v(\\d+)\\.(\\d+)') {\n                $tkMajor = [int]$Matches[1]; $tkMinor = [int]$Matches[2]\n                $compatible = ($tkMajor -lt $drMajor) -or ($tkMajor -eq $drMajor -and $tkMinor -le $drMinor)\n                if ($compatible) {\n                    $nvcc = Join-Path $dir.FullName 'bin\\nvcc.exe'\n                    if (Test-Path $nvcc) {\n                        return $nvcc\n                    }\n                }\n            }\n        }\n\n        # No compatible side-by-side version found\n        return $null\n    }\n\n    # Fallback: no version constraint — pick latest or whatever is available\n\n    # 1. 
Check nvcc on PATH\n    $cmd = Get-Command nvcc -ErrorAction SilentlyContinue\n    if ($cmd) { return $cmd.Source }\n\n    # 2. Check CUDA_PATH env var\n    $cudaRoot = [Environment]::GetEnvironmentVariable('CUDA_PATH', 'Process')\n    if (-not $cudaRoot) { $cudaRoot = [Environment]::GetEnvironmentVariable('CUDA_PATH', 'Machine') }\n    if (-not $cudaRoot) { $cudaRoot = [Environment]::GetEnvironmentVariable('CUDA_PATH', 'User') }\n    if ($cudaRoot -and (Test-Path (Join-Path $cudaRoot 'bin\\nvcc.exe'))) {\n        return (Join-Path $cudaRoot 'bin\\nvcc.exe')\n    }\n\n    # 3. Scan standard toolkit directory\n    if (Test-Path $toolkitBase) {\n        $latest = Get-ChildItem -Directory $toolkitBase | Sort-Object Name | Select-Object -Last 1\n        if ($latest -and (Test-Path (Join-Path $latest.FullName 'bin\\nvcc.exe'))) {\n            return (Join-Path $latest.FullName 'bin\\nvcc.exe')\n        }\n    }\n\n    return $null\n}\n\n# Detect CUDA Compute Capability via nvidia-smi.\n# Returns e.g. \"80\" for A100 (8.0), \"89\" for RTX 4090 (8.9), etc.\n# Returns $null if detection fails.\nfunction Get-CudaComputeCapability {\n    # Use the resolved absolute path ($NvidiaSmiExe) to survive Refresh-Environment\n    $smiExe = if ($script:NvidiaSmiExe) { $script:NvidiaSmiExe } else {\n        $cmd = Get-Command nvidia-smi -ErrorAction SilentlyContinue\n        if ($cmd) { $cmd.Source } else { $null }\n    }\n    if (-not $smiExe) { return $null }\n\n    try {\n        $raw = & $smiExe --query-gpu=compute_cap --format=csv,noheader 2>$null\n        if ($LASTEXITCODE -ne 0 -or -not $raw) { return $null }\n\n        # nvidia-smi may return multiple GPUs; take the first one\n        $cap = ($raw -split \"`n\")[0].Trim()\n        if ($cap -match '^(\\d+)\\.(\\d+)$') {\n            $major = $Matches[1]\n            $minor = $Matches[2]\n            return \"$major$minor\"\n        }\n    } catch { }\n\n    return $null\n}\n\n# Check if an nvcc binary supports a given sm_ architecture.\n# Uses `nvcc --list-gpu-code` which outputs sm_* tokens (--list-gpu-arch\n# outputs compute_* tokens instead).  Available since CUDA 11.6.\n# Returns $false if the flag isn't supported (old toolkit) — safer to reject\n# and fall back to scanning/PTX than to assume support and fail later.\nfunction Test-NvccArchSupport {\n    param([string]$NvccExe, [string]$Arch)\n    try {\n        $listCode = & $NvccExe --list-gpu-code 2>&1 | Out-String\n        if ($LASTEXITCODE -ne 0) { return $false }\n        return ($listCode -match \"sm_$Arch\")\n    } catch {\n        return $false\n    }\n}\n\n# Given an nvcc binary, return the highest sm_ architecture it supports.\n# Returns e.g. \"90\" for CUDA 12.4. Returns $null if detection fails.\nfunction Get-NvccMaxArch {\n    param([string]$NvccExe)\n    try {\n        $listCode = & $NvccExe --list-gpu-code 2>&1 | Out-String\n        if ($LASTEXITCODE -ne 0) { return $null }\n        $arches = @()\n        foreach ($line in $listCode -split \"`n\") {\n            if ($line.Trim() -match '^sm_(\\d+)') {\n                $arches += [int]$Matches[1]\n            }\n        }\n        if ($arches.Count -gt 0) {\n            return ($arches | Sort-Object | Select-Object -Last 1).ToString()\n        }\n    } catch { }\n    return $null\n}\n\n# Detect driver's max CUDA version from nvidia-smi and return the highest\n# compatible PyTorch CUDA index tag (e.g. \"cu128\").\n# PyTorch on Windows ships CPU-only by default from PyPI; CUDA wheels live at\n# https://download.pytorch.org/whl/<tag>. 
The tag must not exceed the driver's\n# capability: e.g. driver \"CUDA Version: 12.9\" → cu128 (not cu130).\nfunction Get-PytorchCudaTag {\n    $smiExe = if ($script:NvidiaSmiExe) { $script:NvidiaSmiExe } else {\n        $cmd = Get-Command nvidia-smi -ErrorAction SilentlyContinue\n        if ($cmd) { $cmd.Source } else { $null }\n    }\n    if (-not $smiExe) { return \"cu124\" }\n\n    try {\n        # 2>&1 | Out-String merges stderr into stdout then converts to a single\n        # string.  Plain 2>$null doesn't fully suppress stderr in PS 5.1 --\n        # ErrorRecord objects leak into $output and break the -match.\n        $output = & $smiExe 2>&1 | Out-String\n        if ($output -match 'CUDA Version:\\s+(\\d+)\\.(\\d+)') {\n            $major = [int]$Matches[1]\n            $minor = [int]$Matches[2]\n            # PyTorch 2.10 offers: cu124, cu126, cu128, cu130\n            if ($major -ge 13) { return \"cu130\" }\n            if ($major -eq 12 -and $minor -ge 8) { return \"cu128\" }\n            if ($major -eq 12 -and $minor -ge 6) { return \"cu126\" }\n            return \"cu124\"\n        }\n    } catch { }\n\n    return \"cu124\"\n}\n\n# Find Visual Studio Build Tools for cmake -G flag.\n# Strategy: (1) vswhere, (2) scan filesystem (handles broken vswhere registration).\n# Returns @{ Generator = \"Visual Studio 17 2022\"; InstallPath = \"C:\\...\"; Source = \"...\" } or $null.\nfunction Find-VsBuildTools {\n    $map = @{ '2022' = '17'; '2019' = '16'; '2017' = '15' }\n\n    # --- Try vswhere first (works when VS is properly registered) ---\n    $vsw = \"${env:ProgramFiles(x86)}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"\n    if (Test-Path $vsw) {\n        $info = & $vsw -latest -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property catalog_productLineVersion 2>$null\n        $path = & $vsw -latest -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath 2>$null\n        if ($info -and $path) {\n            $y = $info.Trim()\n            $n = $map[$y]\n            if ($n) {\n                return @{ Generator = \"Visual Studio $n $y\"; InstallPath = $path.Trim(); Source = 'vswhere' }\n            }\n        }\n    }\n\n    # --- Scan filesystem (handles broken vswhere registration after winget cycles) ---\n    $roots = @($env:ProgramFiles, ${env:ProgramFiles(x86)})\n    $editions = @('BuildTools', 'Community', 'Professional', 'Enterprise')\n    $years = @('2022', '2019', '2017')\n\n    foreach ($y in $years) {\n        foreach ($r in $roots) {\n            foreach ($ed in $editions) {\n                $candidate = Join-Path $r \"Microsoft Visual Studio\\$y\\$ed\"\n                if (Test-Path $candidate) {\n                    $vcDir = Join-Path $candidate \"VC\\Tools\\MSVC\"\n                    if (Test-Path $vcDir) {\n                        $cl = Get-ChildItem -Path $vcDir -Filter \"cl.exe\" -Recurse -ErrorAction SilentlyContinue | Select-Object -First 1\n                        if ($cl) {\n                            $n = $map[$y]\n                            if ($n) {\n                                return @{ Generator = \"Visual Studio $n $y\"; InstallPath = $candidate; Source = \"filesystem ($ed)\"; ClExe = $cl.FullName }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    return $null\n}\n\n# ─────────────────────────────────────────────\n# Banner\n# ─────────────────────────────────────────────\nWrite-Host 
\"+==============================================+\" -ForegroundColor Green\nWrite-Host \"|       Unsloth Studio Setup (Windows)         |\" -ForegroundColor Green\nWrite-Host \"+==============================================+\" -ForegroundColor Green\n\n# ==========================================================================\n#  PHASE 1: System-level prerequisites (winget installs, env vars)\n#  All heavy system tool installs happen here BEFORE touching Python.\n# ==========================================================================\n\n# ============================================\n# 1a. GPU detection\n# ============================================\n$HasNvidiaSmi = $false\n$NvidiaSmiExe = $null  # Absolute path -- survives Refresh-Environment\ntry {\n    $nvSmiCmd = Get-Command nvidia-smi -ErrorAction SilentlyContinue\n    if ($nvSmiCmd) {\n        & $nvSmiCmd.Source 2>&1 | Out-Null\n        if ($LASTEXITCODE -eq 0) {\n            $HasNvidiaSmi = $true\n            $NvidiaSmiExe = $nvSmiCmd.Source\n        }\n    }\n} catch {}\n# Fallback: nvidia-smi may not be on PATH even though a GPU + driver exist.\n# Check the default install location and the Windows driver store.\nif (-not $HasNvidiaSmi) {\n    $nvSmiDefaults = @(\n        \"$env:ProgramFiles\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe\",\n        \"$env:SystemRoot\\System32\\nvidia-smi.exe\"\n    )\n    foreach ($p in $nvSmiDefaults) {\n        if (Test-Path $p) {\n            try {\n                & $p 2>&1 | Out-Null\n                if ($LASTEXITCODE -eq 0) {\n                    $HasNvidiaSmi = $true\n                    $NvidiaSmiExe = $p\n                    Write-Host \"   Found nvidia-smi at $(Split-Path $p -Parent)\" -ForegroundColor Gray\n                    break\n                }\n            } catch {}\n        }\n    }\n}\nif (-not $HasNvidiaSmi) {\n    Write-Host \"\"\n    Write-Host \"[WARN] No NVIDIA GPU detected. Studio will run in chat-only (GGUF) mode.\" -ForegroundColor Yellow\n    Write-Host \"       Training and GPU inference require an NVIDIA GPU with drivers installed.\" -ForegroundColor Yellow\n    Write-Host \"       https://www.nvidia.com/Download/index.aspx\" -ForegroundColor Yellow\n    Write-Host \"\"\n} else {\n    Write-Host \"[OK] NVIDIA GPU detected\" -ForegroundColor Green\n}\n\n# ============================================\n# 1a.5. 
Windows Long Paths (required for deep node_modules / Python paths)\n# ============================================\n$LongPathsEnabled = $false\ntry {\n    $regVal = Get-ItemProperty -Path \"HKLM:\\SYSTEM\\CurrentControlSet\\Control\\FileSystem\" -Name \"LongPathsEnabled\" -ErrorAction SilentlyContinue\n    if ($regVal -and $regVal.LongPathsEnabled -eq 1) {\n        $LongPathsEnabled = $true\n    }\n} catch {}\n\nif ($LongPathsEnabled) {\n    Write-Host \"[OK] Windows Long Paths enabled\" -ForegroundColor Green\n} else {\n    Write-Host \"Windows Long Paths not enabled (required for Triton compilation and deep dependency paths).\" -ForegroundColor Yellow\n    Write-Host \"   Requesting admin access to fix...\" -ForegroundColor Yellow\n    try {\n        # Spawn an elevated process to set the registry key (triggers UAC prompt)\n        $proc = Start-Process -FilePath \"reg.exe\" `\n            -ArgumentList 'add \"HKLM\\SYSTEM\\CurrentControlSet\\Control\\FileSystem\" /v LongPathsEnabled /t REG_DWORD /d 1 /f' `\n            -Verb RunAs -Wait -PassThru -ErrorAction Stop\n        if ($proc.ExitCode -eq 0) {\n            $LongPathsEnabled = $true\n            Write-Host \"[OK] Windows Long Paths enabled (via UAC)\" -ForegroundColor Green\n        } else {\n            Write-Host \"[WARN] Failed to enable Long Paths (exit code: $($proc.ExitCode))\" -ForegroundColor Yellow\n        }\n    } catch {\n        Write-Host \"[WARN] Could not enable Long Paths (UAC was declined or not available)\" -ForegroundColor Yellow\n        Write-Host \"       Run this manually in an Admin terminal:\" -ForegroundColor Yellow\n        Write-Host '       reg add \"HKLM\\SYSTEM\\CurrentControlSet\\Control\\FileSystem\" /v LongPathsEnabled /t REG_DWORD /d 1 /f' -ForegroundColor Cyan\n    }\n}\n\n# ============================================\n# 1b. Git (required by pip for git+https:// deps and by npm)\n# ============================================\n$HasGit = $null -ne (Get-Command git -ErrorAction SilentlyContinue)\nif (-not $HasGit) {\n    Write-Host \"Git not found -- installing via winget...\" -ForegroundColor Yellow\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        try {\n            winget install Git.Git --source winget --accept-package-agreements --accept-source-agreements 2>&1 | Out-Null\n            Refresh-Environment\n            $HasGit = $null -ne (Get-Command git -ErrorAction SilentlyContinue)\n        } catch { }\n    }\n    if (-not $HasGit) {\n        Write-Host \"[ERROR] Git is required but could not be installed automatically.\" -ForegroundColor Red\n        Write-Host \"        Install Git from https://git-scm.com/download/win and re-run.\" -ForegroundColor Red\n        exit 1\n    }\n    Write-Host \"[OK] Git installed: $(git --version)\" -ForegroundColor Green\n} else {\n    Write-Host \"[OK] Git found: $(git --version)\" -ForegroundColor Green\n}\n\n# ============================================\n# 1c. 
CMake (required for llama.cpp build)\n# ============================================\n$HasCmake = $null -ne (Get-Command cmake -ErrorAction SilentlyContinue)\nif (-not $HasCmake) {\n    Write-Host \"CMake not found -- installing via winget...\" -ForegroundColor Yellow\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        try {\n            winget install Kitware.CMake --source winget --accept-package-agreements --accept-source-agreements 2>&1 | Out-Null\n            Refresh-Environment\n            $HasCmake = $null -ne (Get-Command cmake -ErrorAction SilentlyContinue)\n        } catch { }\n    }\n    # winget may succeed but cmake isn't on PATH yet (MSI PATH changes need a\n    # new shell). Try the default install location as a fallback.\n    if (-not $HasCmake) {\n        $cmakeDefaults = @(\n            \"$env:ProgramFiles\\CMake\\bin\",\n            \"${env:ProgramFiles(x86)}\\CMake\\bin\",\n            \"$env:LOCALAPPDATA\\CMake\\bin\"\n        )\n        foreach ($d in $cmakeDefaults) {\n            if (Test-Path (Join-Path $d \"cmake.exe\")) {\n                $env:Path = \"$d;$env:Path\"\n                # Persist to user PATH so Refresh-Environment does not drop it later\n                $userPath = [Environment]::GetEnvironmentVariable('Path', 'User')\n                if (-not $userPath -or $userPath -notlike \"*$d*\") {\n                    [Environment]::SetEnvironmentVariable('Path', \"$d;$userPath\", 'User')\n                }\n                $HasCmake = $null -ne (Get-Command cmake -ErrorAction SilentlyContinue)\n                if ($HasCmake) {\n                    Write-Host \"   Found cmake at $d (added to PATH)\" -ForegroundColor Gray\n                    break\n                }\n            }\n        }\n    }\n    if ($HasCmake) {\n        Write-Host \"[OK] CMake installed\" -ForegroundColor Green\n    } else {\n        Write-Host \"[ERROR] CMake is required but could not be installed.\" -ForegroundColor Red\n        Write-Host \"        Install CMake from https://cmake.org/download/ and re-run.\" -ForegroundColor Red\n        exit 1\n    }\n} else {\n    Write-Host \"[OK] CMake found: $(cmake --version | Select-Object -First 1)\" -ForegroundColor Green\n}\n\n# ============================================\n# 1d. 
Visual Studio Build Tools (C++ compiler for llama.cpp)\n# ============================================\n$CmakeGenerator = $null\n$VsInstallPath = $null\n$vsResult = Find-VsBuildTools\n\nif (-not $vsResult) {\n    Write-Host \"Visual Studio Build Tools not found -- installing via winget...\" -ForegroundColor Yellow\n    Write-Host \"   (This is a one-time install, may take several minutes)\" -ForegroundColor Gray\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        $prevEAPTemp = $ErrorActionPreference\n        $ErrorActionPreference = \"Continue\"\n        winget install Microsoft.VisualStudio.2022.BuildTools --source winget --accept-package-agreements --accept-source-agreements --override \"--add Microsoft.VisualStudio.Workload.VCTools --includeRecommended --passive --wait\"\n        $ErrorActionPreference = $prevEAPTemp\n        # Re-scan after install (don't trust vswhere catalog)\n        $vsResult = Find-VsBuildTools\n    }\n}\n\nif ($vsResult) {\n    $CmakeGenerator = $vsResult.Generator\n    $VsInstallPath = $vsResult.InstallPath\n    Write-Host \"[OK] $CmakeGenerator detected via $($vsResult.Source)\" -ForegroundColor Green\n    if ($vsResult.ClExe) { Write-Host \"   cl.exe: $($vsResult.ClExe)\" -ForegroundColor Gray }\n} else {\n    Write-Host \"[ERROR] Visual Studio Build Tools could not be found or installed.\" -ForegroundColor Red\n    Write-Host \"        Manual install:\" -ForegroundColor Red\n    Write-Host '        1. winget install Microsoft.VisualStudio.2022.BuildTools --source winget' -ForegroundColor Yellow\n    Write-Host '        2. Open Visual Studio Installer -> Modify -> check \"Desktop development with C++\"' -ForegroundColor Yellow\n    exit 1\n}\n\n# ============================================\n# 1e. CUDA Toolkit (nvcc for llama.cpp build + env vars)\n# ============================================\nif ($HasNvidiaSmi) {\n# IMPORTANT: The CUDA Toolkit version must be <= the max CUDA version the\n# NVIDIA driver supports.  nvidia-smi reports this as \"CUDA Version: X.Y\".\n# If we install a toolkit newer than the driver supports, llama-server will\n# fail at runtime with \"ggml_cuda_init: failed to initialize CUDA: (null)\".\n\n# -- Detect max CUDA version the driver supports --\n$DriverMaxCuda = $null\ntry {\n    $smiOut = & $NvidiaSmiExe 2>&1 | Out-String\n    if ($smiOut -match \"CUDA Version:\\s+([\\d]+)\\.([\\d]+)\") {\n        $DriverMaxCuda = \"$($Matches[1]).$($Matches[2])\"\n        Write-Host \"   Driver supports up to CUDA $DriverMaxCuda\" -ForegroundColor Gray\n    }\n} catch {}\n\n# Detect compute capability early so we can validate toolkit support\n$CudaArch = Get-CudaComputeCapability\nif ($CudaArch) {\n    Write-Host \"   GPU Compute Capability = $($CudaArch.Insert($CudaArch.Length-1, '.')) (sm_$CudaArch)\" -ForegroundColor Gray\n}\n\n# -- Find a toolkit that's compatible with the driver AND the GPU --\n# Strategy: prefer the toolkit at CUDA_PATH (user's existing setup) if it's\n# compatible with the driver AND supports the GPU architecture.  Only fall back\n# to scanning side-by-side installs if CUDA_PATH is missing, points to an\n# incompatible version, or can't compile for the GPU.  
This avoids\n# header/binary mismatches when multiple toolkits are installed.\n$IncompatibleToolkit = $null\n$NvccPath = $null\n\nif ($DriverMaxCuda) {\n    $drMajorCuda = [int]$DriverMaxCuda.Split('.')[0]\n    $drMinorCuda = [int]$DriverMaxCuda.Split('.')[1]\n\n    # --- Step 1: Check existing CUDA_PATH first ---\n    $existingCudaPath = [Environment]::GetEnvironmentVariable('CUDA_PATH', 'Machine')\n    if (-not $existingCudaPath) {\n        $existingCudaPath = [Environment]::GetEnvironmentVariable('CUDA_PATH', 'User')\n    }\n    if ($existingCudaPath -and (Test-Path (Join-Path $existingCudaPath 'bin\\nvcc.exe'))) {\n        $candidateNvcc = Join-Path $existingCudaPath 'bin\\nvcc.exe'\n        $verOut = & $candidateNvcc --version 2>&1 | Out-String\n        if ($verOut -match 'release\\s+(\\d+)\\.(\\d+)') {\n            $tkMaj = [int]$Matches[1]; $tkMin = [int]$Matches[2]\n            $isCompat = ($tkMaj -lt $drMajorCuda) -or ($tkMaj -eq $drMajorCuda -and $tkMin -le $drMinorCuda)\n            if ($isCompat) {\n                # Also verify the toolkit supports our GPU architecture\n                Write-Host \"   [DEBUG] Checking CUDA compatibility: toolkit=$tkMaj.$tkMin arch=sm_$CudaArch\" -ForegroundColor Magenta\n                $archOk = $true\n                if ($CudaArch) {\n                    $archOk = Test-NvccArchSupport -NvccExe $candidateNvcc -Arch $CudaArch\n                    if (-not $archOk) {\n                        Write-Host \"   [INFO] CUDA_PATH toolkit (CUDA $tkMaj.$tkMin) does not support GPU arch sm_$CudaArch\" -ForegroundColor Yellow\n                        Write-Host \"          Looking for a newer toolkit...\" -ForegroundColor Yellow\n                    }\n                }\n                if ($archOk) {\n                    $NvccPath = $candidateNvcc\n                    Write-Host \"   [OK] Using existing CUDA Toolkit at CUDA_PATH (nvcc: $NvccPath)\" -ForegroundColor Green\n                }\n            } else {\n                Write-Host \"   [INFO] CUDA_PATH ($existingCudaPath) has CUDA $tkMaj.$tkMin which exceeds driver max $DriverMaxCuda\" -ForegroundColor Yellow\n            }\n        }\n    }\n\n    # --- Step 2: Fall back to scanning side-by-side installs ---\n    if (-not $NvccPath) {\n        $NvccPath = Find-Nvcc -MaxVersion $DriverMaxCuda\n        if ($NvccPath) {\n            Write-Host \"   [OK] Found compatible CUDA Toolkit (nvcc: $NvccPath)\" -ForegroundColor Green\n            if ($existingCudaPath) {\n                $selectedRoot = Split-Path (Split-Path $NvccPath -Parent) -Parent\n                if ($existingCudaPath.TrimEnd('\\') -ne $selectedRoot.TrimEnd('\\')) {\n                    Write-Host \"   [INFO] Overriding CUDA_PATH from $existingCudaPath to $selectedRoot\" -ForegroundColor Yellow\n                }\n            }\n        } else {\n            # Check if there's an incompatible (too new) toolkit installed\n            $AnyNvcc = Find-Nvcc\n            if ($AnyNvcc) {\n                $NvccOut = & $AnyNvcc --version 2>&1 | Out-String\n                if ($NvccOut -match \"release\\s+([\\d]+\\.[\\d]+)\") {\n                    $IncompatibleToolkit = $Matches[1]\n                }\n            }\n        }\n    }\n} else {\n    $NvccPath = Find-Nvcc\n}\n\n# -- If incompatible toolkit is blocking, tell user to uninstall it --\nif (-not $NvccPath -and $IncompatibleToolkit) {\n    Write-Host \"\" -ForegroundColor Red\n    Write-Host \"========================================================================\" -ForegroundColor 
Red\n    Write-Host \"[ERROR] CUDA Toolkit $IncompatibleToolkit is installed but INCOMPATIBLE\" -ForegroundColor Red\n    Write-Host \"        with your NVIDIA driver (which supports up to CUDA $DriverMaxCuda).\" -ForegroundColor Red\n    Write-Host \"\" -ForegroundColor Red\n    Write-Host \"  This will cause 'failed to initialize CUDA' errors at runtime.\" -ForegroundColor Red\n    Write-Host \"\" -ForegroundColor Red\n    Write-Host \"  To fix:\" -ForegroundColor Yellow\n    Write-Host \"    1. Open Control Panel -> Programs -> Uninstall a program\" -ForegroundColor Yellow\n    Write-Host \"    2. Uninstall 'NVIDIA CUDA Toolkit $IncompatibleToolkit'\" -ForegroundColor Yellow\n    Write-Host \"    3. Re-run setup.bat (it will install CUDA $DriverMaxCuda automatically)\" -ForegroundColor Yellow\n    Write-Host \"\" -ForegroundColor Yellow\n    Write-Host \"  Alternatively, update your NVIDIA driver to one that supports CUDA $IncompatibleToolkit.\" -ForegroundColor Gray\n    Write-Host \"========================================================================\" -ForegroundColor Red\n    exit 1\n}\n\n# -- No toolkit at all: install via winget --\nif (-not $NvccPath) {\n    Write-Host \"CUDA toolkit (nvcc) not found -- installing via winget...\" -ForegroundColor Yellow\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        if ($DriverMaxCuda) {\n            # Query winget for available CUDA Toolkit versions\n            $drMajor = [int]$DriverMaxCuda.Split('.')[0]\n            $drMinor = [int]$DriverMaxCuda.Split('.')[1]\n            $AvailableVersions = @()\n            try {\n                $rawOutput = winget show Nvidia.CUDA --versions --accept-source-agreements 2>&1 | Out-String\n                # Parse version lines (e.g. \"12.6\", \"12.5\", \"11.8\")\n                foreach ($line in $rawOutput -split \"`n\") {\n                    $line = $line.Trim()\n                    if ($line -match '^\\d+\\.\\d+') {\n                        $AvailableVersions += $line\n                    }\n                }\n            } catch {}\n\n            # Filter to compatible versions (<= driver max) and pick the highest\n            $BestVersion = $null\n            foreach ($ver in $AvailableVersions) {\n                $parts = $ver.Split('.')\n                $vMajor = [int]$parts[0]\n                $vMinor = [int]$parts[1]\n                if ($vMajor -lt $drMajor -or ($vMajor -eq $drMajor -and $vMinor -le $drMinor)) {\n                    $BestVersion = $ver\n                    break  # list is descending, first match is highest compatible\n                }\n            }\n\n            if ($BestVersion) {\n                Write-Host \"   Installing CUDA Toolkit $BestVersion via winget...  
\" -ForegroundColor Cyan\n                $prevEAPCuda = $ErrorActionPreference\n                $ErrorActionPreference = \"Continue\"\n                winget install --id=Nvidia.CUDA --version=$BestVersion -e --source winget --accept-package-agreements --accept-source-agreements 2>&1 | Out-Null\n                $ErrorActionPreference = $prevEAPCuda\n                Refresh-Environment\n                $NvccPath = Find-Nvcc -MaxVersion $DriverMaxCuda\n                if ($NvccPath) {\n                    Write-Host \"   [OK] CUDA Toolkit $BestVersion installed (nvcc: $NvccPath)\" -ForegroundColor Green\n                }\n            } else {\n                Write-Host \"   [WARN] No compatible CUDA Toolkit version found in winget (need <= $DriverMaxCuda)\" -ForegroundColor Yellow\n            }\n        } else {\n            Write-Host \"   Installing CUDA Toolkit (latest) via winget...\" -ForegroundColor Cyan\n            winget install --id=Nvidia.CUDA -e --source winget --accept-package-agreements --accept-source-agreements\n            Refresh-Environment\n            $NvccPath = Find-Nvcc\n            if ($NvccPath) {\n                Write-Host \"   [OK] CUDA Toolkit installed (nvcc: $NvccPath)\" -ForegroundColor Green\n            }\n        }\n    }\n}\n\nif (-not $NvccPath) {\n    Write-Host \"[ERROR] CUDA Toolkit (nvcc) is required but could not be found or installed.\" -ForegroundColor Red\n    if ($DriverMaxCuda) {\n        Write-Host \"        Install CUDA Toolkit $DriverMaxCuda from https://developer.nvidia.com/cuda-toolkit-archive\" -ForegroundColor Yellow\n    } else {\n        Write-Host \"        Install CUDA Toolkit from https://developer.nvidia.com/cuda-downloads\" -ForegroundColor Yellow\n    }\n    exit 1\n}\n\n# -- Set CUDA env vars so cmake AND MSBuild can find the toolkit --\n$CudaToolkitRoot = Split-Path (Split-Path $NvccPath -Parent) -Parent\n# CUDA_PATH: used by cmake's find_package(CUDAToolkit)\n[Environment]::SetEnvironmentVariable('CUDA_PATH', $CudaToolkitRoot, 'Process')\n# CudaToolkitDir: the MSBuild property that CUDA .targets checks directly\n# Trailing backslash required -- the .targets file appends subpaths to it\n[Environment]::SetEnvironmentVariable('CudaToolkitDir', \"$CudaToolkitRoot\\\", 'Process')\n# Always persist CUDA_PATH to User registry so the compatible toolkit is used\n# in future sessions (overwrites any existing value pointing to a newer, incompatible version)\n[Environment]::SetEnvironmentVariable('CUDA_PATH', $CudaToolkitRoot, 'User')\nWrite-Host \"   Persisted CUDA_PATH=$CudaToolkitRoot to user environment\" -ForegroundColor Gray\n# Clear all versioned CUDA_PATH_V* env vars in this process to prevent\n# cmake/MSBuild from discovering a conflicting CUDA installation.\n$cudaPathVars = @([Environment]::GetEnvironmentVariables('Process').Keys | Where-Object { $_ -match '^CUDA_PATH_V' })\nforeach ($v in $cudaPathVars) {\n    [Environment]::SetEnvironmentVariable($v, $null, 'Process')\n}\n# Set only the versioned var matching the selected toolkit (e.g. 
CUDA_PATH_V13_0)\n$tkDirName = Split-Path $CudaToolkitRoot -Leaf\nif ($tkDirName -match '^v(\\d+)\\.(\\d+)') {\n    $cudaPathVerVar = \"CUDA_PATH_V$($Matches[1])_$($Matches[2])\"\n    [Environment]::SetEnvironmentVariable($cudaPathVerVar, $CudaToolkitRoot, 'Process')\n    Write-Host \"   Set $cudaPathVerVar (cleared other CUDA_PATH_V* vars)\" -ForegroundColor Gray\n}\n# Ensure nvcc's bin dir is on PATH for this process\n$nvccBinDir = Split-Path $NvccPath -Parent\nif ($env:PATH -notlike \"*$nvccBinDir*\") {\n    [Environment]::SetEnvironmentVariable('PATH', \"$nvccBinDir;$env:PATH\", 'Process')\n}\n# Persist nvcc bin dir to User PATH so it works in new terminals\n$userPath = [Environment]::GetEnvironmentVariable('Path', 'User')\nif (-not $userPath -or $userPath -notlike \"*$nvccBinDir*\") {\n    if ($userPath) {\n        [Environment]::SetEnvironmentVariable('Path', \"$nvccBinDir;$userPath\", 'User')\n    } else {\n        [Environment]::SetEnvironmentVariable('Path', \"$nvccBinDir\", 'User')\n    }\n    Write-Host \"   Persisted CUDA bin dir to user PATH\" -ForegroundColor Gray\n}\n\n# -- Ensure CUDA ↔ Visual Studio integration files exist --\n# When CUDA is installed before VS Build Tools (or VS is reinstalled after CUDA),\n# the MSBuild .targets/.props files that let VS compile .cu files are missing.\n# cmake fails with \"No CUDA toolset found\". Fix: copy from CUDA extras dir.\nif ($VsInstallPath -and $CudaToolkitRoot) {\n    $vsCustomizations = Join-Path $VsInstallPath \"MSBuild\\Microsoft\\VC\\v170\\BuildCustomizations\"\n    $cudaExtras = Join-Path $CudaToolkitRoot \"extras\\visual_studio_integration\\MSBuildExtensions\"\n    if ((Test-Path $cudaExtras) -and (Test-Path $vsCustomizations)) {\n        $hasTargets = Get-ChildItem $vsCustomizations -Filter \"CUDA *.targets\" -ErrorAction SilentlyContinue\n        if (-not $hasTargets) {\n            Write-Host \"   [INFO] CUDA VS integration missing -- copying .targets files...\" -ForegroundColor Yellow\n            try {\n                Copy-Item \"$cudaExtras\\*\" $vsCustomizations -Force -ErrorAction Stop\n                Write-Host \"   [OK] CUDA VS integration files installed\" -ForegroundColor Green\n            } catch {\n                # Direct copy failed (needs admin). 
Try elevated copy via Start-Process.\n                try {\n                    $copyCmd = \"Copy-Item '$cudaExtras\\*' '$vsCustomizations' -Force\"\n                    Start-Process powershell -ArgumentList \"-NoProfile -Command $copyCmd\" -Verb RunAs -Wait -ErrorAction Stop\n                    $hasTargetsRetry = Get-ChildItem $vsCustomizations -Filter \"CUDA *.targets\" -ErrorAction SilentlyContinue\n                    if ($hasTargetsRetry) {\n                        Write-Host \"   [OK] CUDA VS integration files installed (elevated)\" -ForegroundColor Green\n                    } else {\n                        throw \"Copy did not produce .targets files\"\n                    }\n                } catch {\n                    Write-Host \"   [WARN] Could not copy CUDA VS integration files\" -ForegroundColor Yellow\n                    Write-Host \"          The llama.cpp build may fail with 'No CUDA toolset found'.\" -ForegroundColor Yellow\n                    Write-Host \"          Manual fix: copy contents of\" -ForegroundColor Yellow\n                    Write-Host \"            $cudaExtras\" -ForegroundColor Cyan\n                    Write-Host \"          into:\" -ForegroundColor Yellow\n                    Write-Host \"            $vsCustomizations\" -ForegroundColor Cyan\n                }\n            }\n        }\n    }\n}\n\nWrite-Host \"[OK] CUDA Toolkit: $NvccPath\" -ForegroundColor Green\nWrite-Host \"   CUDA_PATH      = $CudaToolkitRoot\" -ForegroundColor Gray\nWrite-Host \"   CudaToolkitDir = $CudaToolkitRoot\\\" -ForegroundColor Gray\n\n# $CudaArch was detected earlier (before toolkit selection) so it could\n# influence which toolkit we picked.  Just log the final state here.\nif (-not $CudaArch) {\n    Write-Host \"   [WARN] Could not detect compute capability -- cmake will use defaults\" -ForegroundColor Yellow\n}\n} else {\n    Write-Host \"[SKIP] CUDA Toolkit -- no NVIDIA GPU detected\" -ForegroundColor Yellow\n}\n\n# ============================================\n# 1f. Node.js / npm (skip if pip-installed -- only needed for frontend build)\n# ============================================\nif ($IsPipInstall) {\n    Write-Host \"[OK] Running from pip install - frontend already bundled, skipping Node/npm check\" -ForegroundColor Green\n} else {\n    # setup.sh installs Node LTS (v22) via nvm. 
We enforce the same range here:\n    # Node >= 20, npm >= 11.\n    $NeedNode = $true\n    try {\n        $NodeVersion = (node -v 2>$null)\n        $NpmVersion = (npm -v 2>$null)\n        if ($NodeVersion -and $NpmVersion) {\n            $NodeMajor = [int]($NodeVersion -replace 'v','').Split('.')[0]\n            $NpmMajor = [int]$NpmVersion.Split('.')[0]\n\n            if ($NodeMajor -ge 20 -and $NpmMajor -ge 11) {\n                Write-Host \"[OK] Node $NodeVersion and npm $NpmVersion already meet requirements.\" -ForegroundColor Green\n                $NeedNode = $false\n            } else {\n                Write-Host \"[WARN] Node $NodeVersion / npm $NpmVersion too old.\" -ForegroundColor Yellow\n            }\n        }\n    } catch {\n        Write-Host \"[WARN] Node/npm not found.\" -ForegroundColor Yellow\n    }\n\n    if ($NeedNode) {\n        Write-Host \"Installing Node.js LTS via winget...\" -ForegroundColor Cyan\n        try {\n            winget install OpenJS.NodeJS.LTS --source winget --accept-package-agreements --accept-source-agreements\n            Refresh-Environment\n        } catch {\n            Write-Host \"[ERROR] Could not install Node.js automatically.\" -ForegroundColor Red\n            Write-Host \"Please install Node.js >= 20 from https://nodejs.org/\" -ForegroundColor Red\n            exit 1\n        }\n    }\n\n    Write-Host \"[OK] Node $(node -v) | npm $(npm -v)\" -ForegroundColor Green\n}\n\n# ============================================\n# 1g. Python (>= 3.11 and < 3.14, matching setup.sh)\n# ============================================\n$HasPython = $null -ne (Get-Command python -ErrorAction SilentlyContinue)\n$PythonOk = $false\n\nif ($HasPython) {\n    $PyVer = python --version 2>&1\n    if ($PyVer -match \"(\\d+)\\.(\\d+)\") {\n        $PyMajor = [int]$Matches[1]; $PyMinor = [int]$Matches[2]\n        if ($PyMajor -eq 3 -and $PyMinor -ge 11 -and $PyMinor -lt 14) {\n            Write-Host \"[OK] Python $PyVer\" -ForegroundColor Green\n            $PythonOk = $true\n        } else {\n            Write-Host \"[ERROR] Python $PyVer is outside supported range (need >= 3.11 and < 3.14).\" -ForegroundColor Red\n            Write-Host \"        Install Python 3.12 from https://python.org/downloads/\" -ForegroundColor Yellow\n            exit 1\n        }\n    }\n} else {\n    # No Python at all -- install 3.12\n    Write-Host \"Python not found -- installing Python 3.12 via winget...\" -ForegroundColor Yellow\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        winget install -e --id Python.Python.3.12 --source winget --accept-package-agreements --accept-source-agreements\n        Refresh-Environment\n    }\n    $HasPython = $null -ne (Get-Command python -ErrorAction SilentlyContinue)\n    if (-not $HasPython) {\n        Write-Host \"[ERROR] Python could not be installed automatically.\" -ForegroundColor Red\n        Write-Host \"        Install Python 3.12 from https://python.org/downloads/\" -ForegroundColor Yellow\n        exit 1\n    }\n    Write-Host \"[OK] Python $(python --version)\" -ForegroundColor Green\n    $PythonOk = $true\n}\n\n# Ensure Python Scripts dir is on PATH (so 'unsloth' command works in new terminals)\n$ScriptsDir = python -c \"import sysconfig; print(sysconfig.get_path('scripts', 'nt_user') if __import__('os').path.exists(sysconfig.get_path('scripts', 'nt_user')) else sysconfig.get_path('scripts'))\"\nif ($LASTEXITCODE -eq 0 -and $ScriptsDir -and (Test-Path $ScriptsDir)) {\n    
$UserPath = [Environment]::GetEnvironmentVariable('Path', 'User')\n    $UserPathEntries = if ($UserPath) { $UserPath.Split(';') } else { @() }\n    if (-not ($UserPathEntries | Where-Object { $_.TrimEnd('\\') -eq $ScriptsDir })) {\n        $newUserPath = if ($UserPath) { \"$ScriptsDir;$UserPath\" } else { $ScriptsDir }\n        [Environment]::SetEnvironmentVariable('Path', $newUserPath, 'User')\n\n        # Also add to current process so it's available immediately\n        $ProcessPathEntries = $env:PATH.Split(';')\n        if (-not ($ProcessPathEntries | Where-Object { $_.TrimEnd('\\') -eq $ScriptsDir })) {\n            $env:PATH = \"$ScriptsDir;$env:PATH\"\n        }\n        Write-Host \"   Persisted Python Scripts dir to user PATH: $ScriptsDir\" -ForegroundColor Gray\n    }\n}\n\nWrite-Host \"\"\nWrite-Host \"--- System prerequisites ready ---\" -ForegroundColor Green\nWrite-Host \"\"\n\n# ==========================================================================\n#  PHASE 2: Frontend build (skip if pip-installed -- already bundled)\n# ==========================================================================\n$DistDir = Join-Path $FrontendDir \"dist\"\n# Skip build if dist/ exists and no tracked input is newer than dist/.\n# Checks src/, public/, package.json, config files -- not just src/.\n$NeedFrontendBuild = $true\nif ($IsPipInstall) {\n    $NeedFrontendBuild = $false\n    Write-Host \"[OK] Running from pip install - frontend already bundled, skipping build\" -ForegroundColor Green\n} elseif (Test-Path $DistDir) {\n    $DistTime = (Get-Item $DistDir).LastWriteTime\n    $NewerFile = $null\n    # Check src/ and public/ recursively (probe paths directly, not via -Include)\n    foreach ($subDir in @(\"src\", \"public\")) {\n        $subPath = Join-Path $FrontendDir $subDir\n        if (Test-Path $subPath) {\n            $NewerFile = Get-ChildItem -Path $subPath -Recurse -File -ErrorAction SilentlyContinue |\n                Where-Object { $_.LastWriteTime -gt $DistTime } | Select-Object -First 1\n            if ($NewerFile) { break }\n        }\n    }\n    # Also check all top-level files (package.json, bun.lock, vite.config.ts, index.html, etc.)\n    if (-not $NewerFile) {\n        $NewerFile = Get-ChildItem -Path $FrontendDir -File -ErrorAction SilentlyContinue |\n            Where-Object { $_.LastWriteTime -gt $DistTime } |\n            Select-Object -First 1\n    }\n    if (-not $NewerFile) {\n        $NeedFrontendBuild = $false\n        Write-Host \"[OK] Frontend already built and up to date -- skipping build\" -ForegroundColor Green\n    } else {\n        Write-Host \"[INFO] Frontend source changed since last build -- rebuilding...\" -ForegroundColor Yellow\n    }\n}\nif ($NeedFrontendBuild -and -not $IsPipInstall) {\n    Write-Host \"\"\n    Write-Host \"Building frontend...\" -ForegroundColor Cyan\n\n    # ── Tailwind v4 .gitignore workaround ──\n    # Tailwind v4's oxide scanner respects .gitignore in parent directories.\n    # Python venvs create a .gitignore with \"*\" (ignore everything), which\n    # prevents Tailwind from scanning .tsx source files for class names.\n    # Temporarily hide any such .gitignore during the build, then restore it.\n    $HiddenGitignores = @()\n    $WalkDir = (Get-Item $FrontendDir).Parent.FullName\n    while ($WalkDir -and $WalkDir -ne [System.IO.Path]::GetPathRoot($WalkDir)) {\n        $gi = Join-Path $WalkDir \".gitignore\"\n        if (Test-Path $gi) {\n            $content = Get-Content $gi -Raw -ErrorAction SilentlyContinue\n            
if ($content -and ($content.Trim() -match '^\\*$')) {\n                $hidden = \"$gi._twbuild\"\n                Rename-Item -Path $gi -NewName (Split-Path $hidden -Leaf) -Force\n                $HiddenGitignores += $gi\n                Write-Host \"   [INFO] Temporarily hiding $gi (venv .gitignore blocks Tailwind scanner)\" -ForegroundColor DarkGray\n            }\n        }\n        $WalkDir = Split-Path $WalkDir -Parent\n    }\n\n    # npm writes warnings to stderr; lower ErrorActionPreference so PS doesn't\n    # treat them as terminating errors (same pattern as the pip section below).\n    $prevEAP_npm = $ErrorActionPreference\n    $ErrorActionPreference = \"Continue\"\n    Push-Location $FrontendDir\n    npm install 2>&1 | Out-Null\n    if ($LASTEXITCODE -ne 0) {\n        Pop-Location\n        $ErrorActionPreference = $prevEAP_npm\n        foreach ($gi in $HiddenGitignores) { Rename-Item -Path \"$gi._twbuild\" -NewName (Split-Path $gi -Leaf) -Force -ErrorAction SilentlyContinue }\n        Write-Host \"[ERROR] npm install failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        Write-Host \"   Try running 'npm install' manually in frontend/ to see errors\" -ForegroundColor Yellow\n        exit 1\n    }\n    npm run build 2>&1 | Out-Null\n    if ($LASTEXITCODE -ne 0) {\n        Pop-Location\n        $ErrorActionPreference = $prevEAP_npm\n        foreach ($gi in $HiddenGitignores) { Rename-Item -Path \"$gi._twbuild\" -NewName (Split-Path $gi -Leaf) -Force -ErrorAction SilentlyContinue }\n        Write-Host \"[ERROR] npm run build failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        exit 1\n    }\n    Pop-Location\n    $ErrorActionPreference = $prevEAP_npm\n\n    # ── Restore hidden .gitignore files ──\n    foreach ($gi in $HiddenGitignores) {\n        Rename-Item -Path \"$gi._twbuild\" -NewName (Split-Path $gi -Leaf) -Force -ErrorAction SilentlyContinue\n    }\n\n    # ── Validate CSS output ──\n    $CssFiles = Get-ChildItem (Join-Path $DistDir \"assets\") -Filter \"*.css\" -ErrorAction SilentlyContinue\n    $MaxCssSize = ($CssFiles | Measure-Object -Property Length -Maximum).Maximum\n    if ($MaxCssSize -lt 100000) {\n        Write-Host \"[WARN] Largest CSS file is only $([math]::Round($MaxCssSize / 1024))KB -- Tailwind may not have scanned all source files.\" -ForegroundColor Yellow\n        Write-Host \"       Expected >100KB. 
Check for .gitignore files blocking the Tailwind oxide scanner.\" -ForegroundColor Yellow\n    } else {\n        Write-Host \"[OK] Frontend built to frontend/dist (CSS: $([math]::Round($MaxCssSize / 1024))KB)\" -ForegroundColor Green\n    }\n}\n\nif (Test-Path $OxcValidatorDir) {\n    Write-Host \"Installing OXC validator runtime...\" -ForegroundColor Cyan\n    $prevEAP_oxc = $ErrorActionPreference\n    $ErrorActionPreference = \"Continue\"\n    Push-Location $OxcValidatorDir\n    npm install 2>&1 | Out-Null\n    if ($LASTEXITCODE -ne 0) {\n        Pop-Location\n        $ErrorActionPreference = $prevEAP_oxc\n        Write-Host \"[ERROR] OXC validator npm install failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        exit 1\n    }\n    Pop-Location\n    $ErrorActionPreference = $prevEAP_oxc\n    Write-Host \"[OK] OXC validator runtime installed\" -ForegroundColor Green\n}\n\n# ==========================================================================\n#  PHASE 3: Python environment + dependencies\n# ==========================================================================\nWrite-Host \"\"\nWrite-Host \"Setting up Python environment...\" -ForegroundColor Cyan\n\n# Find Python\n$PythonCmd = $null\nforeach ($candidate in @(\"python3.13\", \"python3.12\", \"python3.11\", \"python3\", \"python\")) {\n    try {\n        $ver = & $candidate --version 2>&1\n        if ($ver -match 'Python 3\\.(\\d+)') {\n            $minor = [int]$Matches[1]\n            if ($minor -ge 11 -and $minor -le 13) {\n                $PythonCmd = $candidate\n                break\n            }\n        }\n    } catch { }\n}\n\nif (-not $PythonCmd) {\n    Write-Host \"[ERROR] No Python 3.11-3.13 found.\" -ForegroundColor Red\n    exit 1\n}\n\nWrite-Host \"[OK] Using $PythonCmd ($(& $PythonCmd --version 2>&1))\" -ForegroundColor Green\n\n# Always create a .venv for isolation -- even for pip installs.\n# Created under $env:USERPROFILE\\.unsloth\\studio (shared location, not inside the repo).\n$VenvDir = Join-Path $env:USERPROFILE \".unsloth\\studio\\.venv\"\nif (-not (Test-Path $VenvDir)) {\n    Write-Host \"   Creating virtual environment at $VenvDir...\" -ForegroundColor Cyan\n    & $PythonCmd -m venv $VenvDir\n} else {\n    Write-Host \"   Reusing existing virtual environment at $VenvDir\" -ForegroundColor Green\n}\n\n# pip and python write to stderr even on success (progress bars, warnings).\n# With $ErrorActionPreference = \"Stop\" (set at top of script), PS 5.1\n# converts stderr lines into terminating ErrorRecords, breaking output.\n# Lower to \"Continue\" for the pip/python section.\n$prevEAP = $ErrorActionPreference\n$ErrorActionPreference = \"Continue\"\n\n$ActivateScript = Join-Path $VenvDir \"Scripts\\Activate.ps1\"\n. $ActivateScript\n\n# Try to use uv (much faster than pip), fall back to pip if unavailable\n$UseUv = $false\nif (Get-Command uv -ErrorAction SilentlyContinue) {\n    $UseUv = $true\n} else {\n    Write-Host \"   Installing uv package manager...\" -ForegroundColor Cyan\n    try {\n        powershell -ExecutionPolicy ByPass -c \"irm https://astral.sh/uv/install.ps1 | iex\" 2>&1 | Out-Null\n        Refresh-Environment\n        # Re-activate venv since Refresh-Environment rebuilds PATH from\n        # registry and drops the venv's Scripts directory\n        . 
$ActivateScript\n        if (Get-Command uv -ErrorAction SilentlyContinue) { $UseUv = $true }\n    } catch { }\n}\n\n# Helper: install a package, preferring uv with pip fallback\nfunction Fast-Install {\n    param([Parameter(ValueFromRemainingArguments=$true)]$Args_)\n    if ($UseUv) {\n        $VenvPy = (Get-Command python).Source\n        $result = & uv pip install --python $VenvPy @Args_ 2>&1\n        if ($LASTEXITCODE -eq 0) { return }\n    }\n    & python -m pip install @Args_ 2>&1\n}\n\nFast-Install --upgrade pip | Out-Null\n\n# if (-not $IsPipInstall) {\n#     # Running from repo: copy requirements and do editable install\n#     $RepoRoot = (Resolve-Path (Join-Path $ScriptDir \"..\\..\")).Path\n#     $ReqsSrc = Join-Path $RepoRoot \"backend\\requirements\"\n#     $ReqsDst = Join-Path $PackageDir \"requirements\"\n#     if (-not (Test-Path $ReqsDst)) { New-Item -ItemType Directory -Path $ReqsDst | Out-Null }\n#     Copy-Item (Join-Path $ReqsSrc \"*.txt\") $ReqsDst -Force\n\n#     Write-Host \"   Installing CLI entry point...\" -ForegroundColor Cyan\n#     pip install -e $RepoRoot 2>&1 | Out-Null\n# } else {\n#     # Running from pip install: the package is in system Python but not in\n#     # the fresh .venv. Install it so run_install() can find its modules\n#     # and bundled requirements files.\n#     Write-Host \"   Installing package into venv...\" -ForegroundColor Cyan\n#     pip install unsloth-roland-test 2>&1 | Out-Null\n# }\n\n# Pre-install PyTorch with CUDA support.\n# On Windows, the default PyPI torch wheel is CPU-only.\n# We need PyTorch's CUDA index to get GPU-enabled wheels.\n# PyTorch bundles its own CUDA runtime, so this works regardless\n# of whether the CUDA Toolkit is installed yet.\n# The CUDA tag is chosen based on the driver's max supported CUDA version.\n\n# Windows MAX_PATH (260 chars) causes Triton kernel compilation to fail because\n# the auto-generated filenames are extremely long. 
Use a short cache directory.\n$TorchCacheDir = \"C:\\tc\"\nif (-not (Test-Path $TorchCacheDir)) { New-Item -ItemType Directory -Path $TorchCacheDir -Force | Out-Null }\n$env:TORCHINDUCTOR_CACHE_DIR = $TorchCacheDir\n[Environment]::SetEnvironmentVariable('TORCHINDUCTOR_CACHE_DIR', $TorchCacheDir, 'User')\nWrite-Host \"[OK] TORCHINDUCTOR_CACHE_DIR set to $TorchCacheDir (avoids MAX_PATH issues)\" -ForegroundColor Green\n\nif ($HasNvidiaSmi) {\n    $CuTag = Get-PytorchCudaTag\n    Write-Host \"   Installing PyTorch with CUDA support ($CuTag)...\" -ForegroundColor Cyan\n    Write-Host \"   (This download is ~2.8 GB -- may take a few minutes)\" -ForegroundColor Gray\n    $output = Fast-Install torch torchvision torchaudio --index-url \"https://download.pytorch.org/whl/$CuTag\" | Out-String\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[FAILED] PyTorch CUDA install failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        Write-Host $output -ForegroundColor Red\n        exit 1\n    }\n\n    # Install Triton for Windows (enables torch.compile -- without it training can hang)\n    Write-Host \"   Installing Triton for Windows...\" -ForegroundColor Cyan\n    $output = Fast-Install \"triton-windows<3.7\" | Out-String\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[WARN] Triton install failed -- torch.compile may not work\" -ForegroundColor Yellow\n        Write-Host $output -ForegroundColor Yellow\n    } else {\n        Write-Host \"[OK] Triton for Windows installed (enables torch.compile)\" -ForegroundColor Green\n    }\n} else {\n    Write-Host \"   Installing PyTorch (CPU-only)...\" -ForegroundColor Cyan\n    $output = Fast-Install torch torchvision torchaudio --index-url \"https://download.pytorch.org/whl/cpu\" | Out-String\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[FAILED] PyTorch install failed (exit code $LASTEXITCODE)\" -ForegroundColor Red\n        Write-Host $output -ForegroundColor Red\n        exit 1\n    }\n}\n\n# Ordered heavy dependency installation -- shared cross-platform script\nWrite-Host \"   Running ordered dependency installation...\" -ForegroundColor Cyan\npython \"$PSScriptRoot\\install_python_stack.py\"\n# Restore ErrorActionPreference after pip/python work\n$ErrorActionPreference = $prevEAP\n\n# ── Pre-install transformers 5.x into .venv_t5/ ──\n# Models like GLM-4.7-Flash need transformers>=5.3.0. 
Instead of pip-installing\n# at runtime (slow, ~10-15s), we pre-install into a separate directory.\n# The training subprocess just prepends .venv_t5/ to sys.path -- instant switch.\nWrite-Host \"\"\nWrite-Host \"   Pre-installing transformers 5.x for newer model support...\" -ForegroundColor Cyan\n$VenvT5Dir = Join-Path $env:USERPROFILE \".unsloth\\studio\\.venv_t5\"\nif (Test-Path $VenvT5Dir) { Remove-Item -Recurse -Force $VenvT5Dir }\nNew-Item -ItemType Directory -Path $VenvT5Dir -Force | Out-Null\n$prevEAP_t5 = $ErrorActionPreference\n$ErrorActionPreference = \"Continue\"\nforeach ($pkg in @(\"transformers==5.3.0\", \"huggingface_hub==1.7.1\", \"hf_xet==1.4.2\")) {\n    $output = Fast-Install --target $VenvT5Dir --no-deps $pkg | Out-String\n    if ($LASTEXITCODE -ne 0) {\n        Write-Host \"[FAIL] Could not install $pkg into .venv_t5/\" -ForegroundColor Red\n        Write-Host $output -ForegroundColor Red\n        $ErrorActionPreference = $prevEAP_t5\n        exit 1\n    }\n}\n# tiktoken is needed by Qwen-family tokenizers -- install with deps since\n# regex/requests may be missing on Windows\n$output = Fast-Install --target $VenvT5Dir tiktoken | Out-String\nif ($LASTEXITCODE -ne 0) {\n    Write-Host \"[WARN] Could not install tiktoken into .venv_t5/ -- Qwen tokenizers may fail\" -ForegroundColor Yellow\n}\n$ErrorActionPreference = $prevEAP_t5\nWrite-Host \"[OK] Transformers 5.x pre-installed to .venv_t5/\" -ForegroundColor Green\n\n# ==========================================================================\n#  PHASE 3.5: Install OpenSSL dev (for HTTPS support in llama-server)\n# ==========================================================================\n# llama-server needs OpenSSL to download models from HuggingFace via -hf.\n# ShiningLight.OpenSSL.Dev includes headers + libs that cmake can find.\n$OpenSslAvailable = $false\n\n# Check if OpenSSL dev is already installed (look for include dir)\n$OpenSslRoots = @(\n    'C:\\Program Files\\OpenSSL-Win64',\n    'C:\\Program Files\\OpenSSL',\n    'C:\\OpenSSL-Win64'\n)\n$OpenSslRoot = $null\nforeach ($root in $OpenSslRoots) {\n    if (Test-Path (Join-Path $root 'include\\openssl\\ssl.h')) {\n        $OpenSslRoot = $root\n        break\n    }\n}\n\nif ($OpenSslRoot) {\n    $OpenSslAvailable = $true\n    Write-Host \"[OK] OpenSSL dev found at $OpenSslRoot\" -ForegroundColor Green\n} else {\n    Write-Host \"\" \n    Write-Host \"Installing OpenSSL dev (for HTTPS in llama-server)...\" -ForegroundColor Cyan\n    $HasWinget = $null -ne (Get-Command winget -ErrorAction SilentlyContinue)\n    if ($HasWinget) {\n        winget install -e --id ShiningLight.OpenSSL.Dev --accept-package-agreements --accept-source-agreements\n        # Re-check after install\n        foreach ($root in $OpenSslRoots) {\n            if (Test-Path (Join-Path $root 'include\\openssl\\ssl.h')) {\n                $OpenSslRoot = $root\n                $OpenSslAvailable = $true\n                Write-Host \"[OK] OpenSSL dev installed at $OpenSslRoot\" -ForegroundColor Green\n                break\n            }\n        }\n    }\n    if (-not $OpenSslAvailable) {\n        Write-Host \"[WARN] OpenSSL dev not available -- llama-server will be built without HTTPS\" -ForegroundColor Yellow\n    }\n}\n\n# ==========================================================================\n#  PHASE 4: Build llama.cpp with CUDA for GGUF inference + export\n# ==========================================================================\n# Builds at ~/.unsloth/llama.cpp — a single shared 
location under the user's\n# home directory. This is used by both the inference server and the GGUF\n# export pipeline (unsloth-zoo).\n# We build:\n#   - llama-server:   for GGUF model inference (with HTTPS if OpenSSL available)\n#   - llama-quantize: for GGUF export quantization\n# Prerequisites (git, cmake, VS Build Tools, CUDA Toolkit) already installed in Phase 1.\n$UnslothHome = Join-Path $env:USERPROFILE \".unsloth\"\nif (-not (Test-Path $UnslothHome)) { New-Item -ItemType Directory -Force $UnslothHome | Out-Null }\n$LlamaCppDir = Join-Path $UnslothHome \"llama.cpp\"\n$BuildDir = Join-Path $LlamaCppDir \"build\"\n$LlamaServerBin = Join-Path $BuildDir \"bin\\Release\\llama-server.exe\"\n\n$HasCmakeForBuild = $null -ne (Get-Command cmake -ErrorAction SilentlyContinue)\n\n# Check if existing llama-server matches current GPU mode. A CUDA-built binary\n# on a now-CPU-only machine (or vice versa) needs to be rebuilt.\n$NeedRebuild = $false\nif (Test-Path $LlamaServerBin) {\n    $CmakeCacheFile = Join-Path $BuildDir \"CMakeCache.txt\"\n    if (Test-Path $CmakeCacheFile) {\n        $cachedCuda = Select-String -Path $CmakeCacheFile -Pattern 'GGML_CUDA:BOOL=ON' -Quiet\n        if ($HasNvidiaSmi -and -not $cachedCuda) {\n            Write-Host \"   Existing llama-server is CPU-only but GPU is available -- rebuilding\" -ForegroundColor Yellow\n            $NeedRebuild = $true\n        } elseif (-not $HasNvidiaSmi -and $cachedCuda) {\n            Write-Host \"   Existing llama-server was built with CUDA but no GPU detected -- rebuilding\" -ForegroundColor Yellow\n            $NeedRebuild = $true\n        }\n    }\n}\n\nif ((Test-Path $LlamaServerBin) -and -not $NeedRebuild) {\n    Write-Host \"\"\n    Write-Host \"[OK] llama-server already exists at $LlamaServerBin\" -ForegroundColor Green\n} elseif (-not $HasCmakeForBuild) {\n    Write-Host \"\"\n    if (-not $HasNvidiaSmi) {\n        # CPU-only machines depend entirely on llama-server for GGUF chat -- cmake is required\n        Write-Host \"[ERROR] CMake is required to build llama-server for GGUF chat mode.\" -ForegroundColor Red\n        Write-Host \"        Install CMake from https://cmake.org/download/ and re-run setup.\" -ForegroundColor Yellow\n        exit 1\n    }\n    Write-Host \"[SKIP] llama-server build -- cmake not available\" -ForegroundColor Yellow\n    Write-Host \"       GGUF inference and export will not be available.\" -ForegroundColor Yellow\n    Write-Host \"       Install CMake from https://cmake.org/download/ and re-run setup.\" -ForegroundColor Yellow\n} else {\n    Write-Host \"\"\n    if ($HasNvidiaSmi) {\n        Write-Host \"Building llama.cpp with CUDA support...\" -ForegroundColor Cyan\n    } else {\n        Write-Host \"Building llama.cpp (CPU-only, no NVIDIA GPU detected)...\" -ForegroundColor Cyan\n    }\n    Write-Host \"   This typically takes 5-10 minutes on first build.\" -ForegroundColor Gray\n    Write-Host \"\"\n\n    # Start total build timer\n    $totalSw = [System.Diagnostics.Stopwatch]::StartNew()\n\n    # Native commands (git, cmake) write to stderr even on success.\n    # With $ErrorActionPreference = \"Stop\" (set at top of script), PS 5.1\n    # converts stderr lines into terminating ErrorRecords, breaking output.\n    # Lower to \"Continue\" for the build section.\n    $prevEAP = $ErrorActionPreference\n    $ErrorActionPreference = \"Continue\"\n\n    $BuildOk = $true\n    $FailedStep = \"\"\n\n    # Re-sanitize CUDA_PATH_V* vars — Refresh-Environment (called during\n    # Node/Python installs 
above) may have repopulated conflicting versioned\n    # vars from the Machine registry.\n    if ($HasNvidiaSmi -and $CudaToolkitRoot) {\n        $cudaPathVars2 = @([Environment]::GetEnvironmentVariables('Process').Keys | Where-Object { $_ -match '^CUDA_PATH_V' })\n        foreach ($v2 in $cudaPathVars2) {\n            [Environment]::SetEnvironmentVariable($v2, $null, 'Process')\n        }\n        $tkDirName2 = Split-Path $CudaToolkitRoot -Leaf\n        if ($tkDirName2 -match '^v(\\d+)\\.(\\d+)') {\n            [Environment]::SetEnvironmentVariable(\"CUDA_PATH_V$($Matches[1])_$($Matches[2])\", $CudaToolkitRoot, 'Process')\n        }\n        # Also re-assert CUDA_PATH and CudaToolkitDir in case they were overwritten\n        [Environment]::SetEnvironmentVariable('CUDA_PATH', $CudaToolkitRoot, 'Process')\n        [Environment]::SetEnvironmentVariable('CudaToolkitDir', \"$CudaToolkitRoot\\\", 'Process')\n    }\n\n    # -- Step A: Clone or pull llama.cpp --\n\n    if (Test-Path (Join-Path $LlamaCppDir \".git\")) {\n        Write-Host \"   llama.cpp repo already cloned, pulling latest...\" -ForegroundColor Gray\n        git -C $LlamaCppDir pull 2>&1 | Out-Null\n        if ($LASTEXITCODE -ne 0) {\n            Write-Host \"   [WARN] git pull failed -- using existing source\" -ForegroundColor Yellow\n        }\n    } else {\n        Write-Host \"   Cloning llama.cpp...\" -ForegroundColor Gray\n        if (Test-Path $LlamaCppDir) { Remove-Item -Recurse -Force $LlamaCppDir }\n        git clone --depth 1 https://github.com/ggml-org/llama.cpp.git $LlamaCppDir 2>&1 | Out-Null\n        if ($LASTEXITCODE -ne 0) {\n            $BuildOk = $false\n            $FailedStep = \"git clone\"\n        }\n    }\n\n    # -- Step B: cmake configure --\n    # Clean stale CMake cache to prevent previous CUDA settings from leaking\n    # into a CPU-only rebuild (or vice versa).\n    $CmakeCacheFile = Join-Path $BuildDir \"CMakeCache.txt\"\n    if (Test-Path $CmakeCacheFile) {\n        Remove-Item -Recurse -Force $BuildDir\n    }\n\n    if ($BuildOk) {\n        Write-Host \"\"\n        Write-Host \"--- cmake configure ---\" -ForegroundColor Cyan\n\n        $CmakeArgs = @(\n            '-S', $LlamaCppDir,\n            '-B', $BuildDir,\n            '-G', $CmakeGenerator,\n            '-Wno-dev'\n        )\n        # Tell cmake exactly where VS is (bypasses registry lookup)\n        if ($VsInstallPath) {\n            $CmakeArgs += \"-DCMAKE_GENERATOR_INSTANCE=$VsInstallPath\"\n        }\n        # Common flags\n        $CmakeArgs += '-DBUILD_SHARED_LIBS=OFF'\n        $CmakeArgs += '-DLLAMA_BUILD_TESTS=OFF'\n        $CmakeArgs += '-DLLAMA_BUILD_EXAMPLES=OFF'\n        $CmakeArgs += '-DLLAMA_BUILD_SERVER=ON'\n        $CmakeArgs += '-DGGML_NATIVE=ON'\n        # HTTPS support via OpenSSL\n        if ($OpenSslAvailable -and $OpenSslRoot) {\n            $CmakeArgs += \"-DOPENSSL_ROOT_DIR=$OpenSslRoot\"\n            $CmakeArgs += '-DLLAMA_OPENSSL=ON'\n        } else {\n            $CmakeArgs += '-DLLAMA_CURL=OFF'\n        }\n        $CmakeArgs += '-DCMAKE_EXE_LINKER_FLAGS=/NODEFAULTLIB:LIBCMT'\n        # CUDA flags -- only if GPU available, otherwise explicitly disable\n        if ($HasNvidiaSmi -and $NvccPath) {\n            $CmakeArgs += '-DGGML_CUDA=ON'\n            $CmakeArgs += \"-DCUDAToolkit_ROOT=$CudaToolkitRoot\"\n            $CmakeArgs += \"-DCUDA_TOOLKIT_ROOT_DIR=$CudaToolkitRoot\"\n            $CmakeArgs += \"-DCMAKE_CUDA_COMPILER=$NvccPath\"\n            if ($CudaArch) {\n                # Validate nvcc actually 
supports this architecture\n                if (Test-NvccArchSupport -NvccExe $NvccPath -Arch $CudaArch) {\n                    $CmakeArgs += \"-DCMAKE_CUDA_ARCHITECTURES=$CudaArch\"\n                } else {\n                    # GPU arch too new for this toolkit -- fall back to highest supported.\n                    # PTX forward-compatibility will JIT-compile for the actual GPU at runtime.\n                    $maxArch = Get-NvccMaxArch -NvccExe $NvccPath\n                    if ($maxArch) {\n                        $CmakeArgs += \"-DCMAKE_CUDA_ARCHITECTURES=$maxArch\"\n                        Write-Host \"   [WARN] GPU is sm_$CudaArch but nvcc only supports up to sm_$maxArch\" -ForegroundColor Yellow\n                        Write-Host \"          Building with sm_$maxArch (PTX will JIT for your GPU at runtime)\" -ForegroundColor Yellow\n                    }\n                    # else: omit flag entirely, let cmake pick defaults\n                }\n            }\n        } else {\n            $CmakeArgs += '-DGGML_CUDA=OFF'\n        }\n\n        $cmakeOutput = cmake @CmakeArgs 2>&1 | Out-String\n        if ($LASTEXITCODE -ne 0) {\n            $BuildOk = $false\n            $FailedStep = \"cmake configure\"\n            Write-Host $cmakeOutput -ForegroundColor Red\n            if ($cmakeOutput -match 'No CUDA toolset found|CUDA_TOOLKIT_ROOT_DIR|nvcc') {\n                Write-Host \"\"\n                Write-Host \"   Hint: CUDA VS integration may be missing. Try running as admin:\" -ForegroundColor Yellow\n                Write-Host \"   Copy contents of:\" -ForegroundColor Yellow\n                Write-Host \"     <CUDA_PATH>\\extras\\visual_studio_integration\\MSBuildExtensions\" -ForegroundColor Yellow\n                Write-Host \"   into:\" -ForegroundColor Yellow\n                Write-Host \"     <VS_PATH>\\MSBuild\\Microsoft\\VC\\v170\\BuildCustomizations\" -ForegroundColor Yellow\n            }\n        }\n    }\n\n    # -- Step C: Build llama-server --\n    $NumCpu = [Environment]::ProcessorCount\n    if ($NumCpu -lt 1) { $NumCpu = 4 }\n\n    if ($BuildOk) {\n        Write-Host \"\"\n        Write-Host \"--- cmake build (llama-server) ---\" -ForegroundColor Cyan\n        Write-Host \"   Parallel jobs: $NumCpu\" -ForegroundColor Gray\n        Write-Host \"\"\n\n        $output = cmake --build $BuildDir --config Release --target llama-server -j $NumCpu 2>&1 | Out-String\n        if ($LASTEXITCODE -ne 0) {\n            $BuildOk = $false\n            $FailedStep = \"cmake build (llama-server)\"\n            Write-Host $output -ForegroundColor Red\n        }\n    }\n\n    # -- Step D: Build llama-quantize (optional, best-effort) --\n    if ($BuildOk) {\n        Write-Host \"\"\n        Write-Host \"--- cmake build (llama-quantize) ---\" -ForegroundColor Cyan\n        $output = cmake --build $BuildDir --config Release --target llama-quantize -j $NumCpu 2>&1 | Out-String\n        if ($LASTEXITCODE -ne 0) {\n            Write-Host \"   [WARN] llama-quantize build failed (GGUF export may be unavailable)\" -ForegroundColor Yellow\n            Write-Host $output -ForegroundColor Yellow\n        }\n    }\n\n    # Restore ErrorActionPreference\n    $ErrorActionPreference = $prevEAP\n\n    # Stop timer\n    $totalSw.Stop()\n    $totalMin = [math]::Floor($totalSw.Elapsed.TotalMinutes)\n    $totalSec = [math]::Round($totalSw.Elapsed.TotalSeconds % 60, 1)\n\n    # -- Summary --\n    Write-Host \"\"\n    if ($BuildOk -and (Test-Path $LlamaServerBin)) {\n        Write-Host \"[OK] 
llama-server built at $LlamaServerBin\" -ForegroundColor Green\n        $QuantizeBin = Join-Path $BuildDir \"bin\\Release\\llama-quantize.exe\"\n        if (Test-Path $QuantizeBin) {\n            Write-Host \"[OK] llama-quantize available for GGUF export\" -ForegroundColor Green\n        }\n        Write-Host \"   Build time: ${totalMin}m ${totalSec}s\" -ForegroundColor Cyan\n    } else {\n        # Check alternate paths (some cmake generators don't use Release subdir)\n        $altBin = Join-Path $BuildDir \"bin\\llama-server.exe\"\n        if ($BuildOk -and (Test-Path $altBin)) {\n            Write-Host \"[OK] llama-server built at $altBin\" -ForegroundColor Green\n            Write-Host \"   Build time: ${totalMin}m ${totalSec}s\" -ForegroundColor Cyan\n        } else {\n            Write-Host \"[FAILED] llama.cpp build failed at step: $FailedStep (${totalMin}m ${totalSec}s)\" -ForegroundColor Red\n            Write-Host \"         To retry: delete $LlamaCppDir and re-run setup.\" -ForegroundColor Yellow\n            exit 1\n        }\n    }\n}\n\n# ============================================\n# Done\n# ============================================\nWrite-Host \"\"\nWrite-Host \"+===============================================+\" -ForegroundColor Green\nWrite-Host \"|           Setup Complete!                     |\" -ForegroundColor Green\nWrite-Host \"|                                               |\" -ForegroundColor Green\nWrite-Host \"|  Launch with:                                 |\" -ForegroundColor Green\nWrite-Host \"|    unsloth studio -H 0.0.0.0 -p 8888          |\" -ForegroundColor Green\nWrite-Host \"|                                               |\" -ForegroundColor Green\nWrite-Host \"+===============================================+\" -ForegroundColor Green\n"
  },
  {
    "path": "studio/setup.sh",
    "content": "#!/usr/bin/env bash\n# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nset -euo pipefail\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nREPO_ROOT=\"$(cd \"$SCRIPT_DIR/..\" && pwd)\"\n\n# ── Helper: run command quietly, show output only on failure ──\nrun_quiet() {\n    local label=\"$1\"\n    shift\n    local tmplog\n    tmplog=$(mktemp)\n    if \"$@\" > \"$tmplog\" 2>&1; then\n        rm -f \"$tmplog\"\n    else\n        local exit_code=$?\n        echo \"❌ $label failed (exit code $exit_code):\"\n        cat \"$tmplog\"\n        rm -f \"$tmplog\"\n        exit $exit_code\n    fi\n}\n\necho \"╔══════════════════════════════════════╗\"\necho \"║     Unsloth Studio Setup Script      ║\"\necho \"╚══════════════════════════════════════╝\"\n\n# ── Clean up stale Unsloth compiled caches ──\nrm -rf \"$REPO_ROOT/unsloth_compiled_cache\"\nrm -rf \"$SCRIPT_DIR/backend/unsloth_compiled_cache\"\nrm -rf \"$SCRIPT_DIR/tmp/unsloth_compiled_cache\"\n\n# ── Detect Colab (like unsloth does) ──\nIS_COLAB=false\nkeynames=$'\\n'$(printenv | cut -d= -f1)\nif [[ \"$keynames\" == *$'\\nCOLAB_'* ]]; then\n    IS_COLAB=true\nfi\n\n# ── Detect whether frontend needs building ──\n# Skip if dist/ exists AND no tracked input is newer than dist/.\n# Checks top-level config/entry files and src/, public/ recursively.\n# This handles: PyPI installs (dist/ bundled), repeat runs (no changes),\n# and upgrades/pulls (source newer than dist/ triggers rebuild).\n_NEED_FRONTEND_BUILD=true\nif [ -d \"$SCRIPT_DIR/frontend/dist\" ]; then\n    # Check all top-level files (package.json, bun.lock, vite.config.ts, index.html, etc.)\n    _changed=$(find \"$SCRIPT_DIR/frontend\" -maxdepth 1 -type f \\\n        -newer \"$SCRIPT_DIR/frontend/dist\" -print -quit 2>/dev/null)\n    # Check src/ and public/ recursively (|| true guards against set -e when dirs are missing)\n    if [ -z \"$_changed\" ]; then\n        _changed=$(find \"$SCRIPT_DIR/frontend/src\" \"$SCRIPT_DIR/frontend/public\" \\\n            -type f -newer \"$SCRIPT_DIR/frontend/dist\" -print -quit 2>/dev/null) || true\n    fi\n    if [ -z \"$_changed\" ]; then\n        _NEED_FRONTEND_BUILD=false\n    fi\nfi\nif [ \"$_NEED_FRONTEND_BUILD\" = false ]; then\n    echo \"✅ Frontend already built and up to date -- skipping Node/npm check.\"\nelse\nNEED_NODE=true\nif command -v node &>/dev/null && command -v npm &>/dev/null; then\n    NODE_MAJOR=$(node -v | sed 's/v//' | cut -d. -f1)\n    NPM_MAJOR=$(npm -v | cut -d. -f1)\n    if [ \"$NODE_MAJOR\" -ge 20 ] && [ \"$NPM_MAJOR\" -ge 11 ]; then\n        echo \"✅ Node $(node -v) and npm $(npm -v) already meet requirements. Skipping nvm install.\"\n        NEED_NODE=false\n    else\n        if [ \"$IS_COLAB\" = true ]; then\n            echo \"✅ Node $(node -v) and npm $(npm -v) detected in Colab.\"\n            # In Colab, just upgrade npm directly - nvm doesn't work well\n            if [ \"$NPM_MAJOR\" -lt 11 ]; then\n                echo \"   Upgrading npm to latest...\"\n                npm install -g npm@latest > /dev/null 2>&1\n            fi\n            NEED_NODE=false\n        else\n            echo \"⚠️  Node $(node -v) / npm $(npm -v) too old. Installing via nvm...\"\n        fi\n    fi\nelse\n    echo \"⚠️  Node/npm not found. Installing via nvm...\"\nfi\n\nif [ \"$NEED_NODE\" = true ]; then\n    # ── 2. 
Install nvm ──\n    export NODE_OPTIONS=--dns-result-order=ipv4first # or else fails on colab.\n    echo \"Installing nvm...\"\n    curl -so- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash > /dev/null 2>&1\n\n    # Load nvm (source ~/.bashrc won't work inside a script)\n    export NVM_DIR=\"$HOME/.nvm\"\n    set +u\n    [ -s \"$NVM_DIR/nvm.sh\" ] && \\. \"$NVM_DIR/nvm.sh\"\n\n    # ── Fix npmrc conflict with nvm ──\n    # System npm (apt, conda, etc.) may have written `prefix` or `globalconfig`\n    # to ~/.npmrc, which is incompatible with nvm and causes \"nvm use\" to fail\n    # with: \"has a `globalconfig` and/or a `prefix` setting, which are\n    # incompatible with nvm.\"\n    if [ -f \"$HOME/.npmrc\" ]; then\n        if grep -qE '^\\s*(prefix|globalconfig)\\s*=' \"$HOME/.npmrc\"; then\n            echo \"   Removing incompatible prefix/globalconfig from ~/.npmrc for nvm...\"\n            sed -i.bak '/^\\s*\\(prefix\\|globalconfig\\)\\s*=/d' \"$HOME/.npmrc\"\n        fi\n    fi\n\n    # ── 3. Install Node LTS ──\n    echo \"Installing Node LTS...\"\n    run_quiet \"nvm install\" nvm install --lts\n    nvm use --lts > /dev/null 2>&1\n    set -u\n    # ── 4. Verify versions ──\n    NODE_MAJOR=$(node -v | sed 's/v//' | cut -d. -f1)\n    NPM_MAJOR=$(npm -v | cut -d. -f1)\n\n    if [ \"$NODE_MAJOR\" -lt 20 ]; then\n        echo \"❌ ERROR: Node version must be >= 20 (got $(node -v))\"\n        exit 1\n    fi\n    if [ \"$NPM_MAJOR\" -lt 11 ]; then\n        echo \"⚠️  npm version is $(npm -v), expected >= 11. Updating...\"\n        run_quiet \"npm update\" npm install -g npm@latest\n    fi\nfi\n\necho \"✅ Node $(node -v) | npm $(npm -v)\"\n\n# ── 5. Build frontend ──\ncd \"$SCRIPT_DIR/frontend\"\n\n# Tailwind v4's oxide scanner respects .gitignore in parent directories.\n# Python venvs create a .gitignore with \"*\" (ignore everything), which\n# prevents Tailwind from scanning .tsx source files for class names.\n# Temporarily hide any such .gitignore during the build, then restore it.\n_HIDDEN_GITIGNORES=()\n_dir=\"$(pwd)\"\nwhile [ \"$_dir\" != \"/\" ]; do\n    _dir=\"$(dirname \"$_dir\")\"\n    if [ -f \"$_dir/.gitignore\" ] && grep -qx '\\*' \"$_dir/.gitignore\" 2>/dev/null; then\n        mv \"$_dir/.gitignore\" \"$_dir/.gitignore._twbuild\"\n        _HIDDEN_GITIGNORES+=(\"$_dir/.gitignore\")\n    fi\ndone\n\n_restore_gitignores() {\n    for _gi in \"${_HIDDEN_GITIGNORES[@]+\"${_HIDDEN_GITIGNORES[@]}\"}\"; do\n        mv \"${_gi}._twbuild\" \"$_gi\" 2>/dev/null || true\n    done\n}\ntrap _restore_gitignores EXIT\n\nrun_quiet \"npm install\" npm install\nrun_quiet \"npm run build\" npm run build\n\n_restore_gitignores\ntrap - EXIT\n\n# Validate CSS output -- catch truncated Tailwind builds\n_MAX_CSS=$(find \"$SCRIPT_DIR/frontend/dist/assets\" -name '*.css' -exec wc -c {} + 2>/dev/null | sort -n | tail -1 | awk '{print $1}')\nif [ -z \"$_MAX_CSS\" ]; then\n    echo \"⚠️  WARNING: No CSS files were emitted. The frontend build may have failed.\"\nelif [ \"$_MAX_CSS\" -lt 100000 ]; then\n    echo \"⚠️  WARNING: Largest CSS file is only $((_MAX_CSS / 1024))KB (expected >100KB).\"\n    echo \"   Tailwind may not have scanned all source files. 
Check for .gitignore interference.\"\nfi\n\ncd \"$SCRIPT_DIR\"\necho \"✅ Frontend built to frontend/dist\"\n\nfi  # end frontend build check\n\n# ── oxc-validator runtime (needs npm -- skip if not available) ──\nif [ -d \"$SCRIPT_DIR/backend/core/data_recipe/oxc-validator\" ] && command -v npm &>/dev/null; then\n    cd \"$SCRIPT_DIR/backend/core/data_recipe/oxc-validator\"\n    run_quiet \"npm install (oxc validator runtime)\" npm install\n    cd \"$SCRIPT_DIR\"\nfi\n\n# ── 6. Python venv + deps ──\n\n# ── 6a. Discover best Python >= 3.11 and < 3.14 (i.e. 3.11.x, 3.12.x, or 3.13.x) ──\nMIN_PY_MINOR=11   # minimum minor version (>= 3.11)\nMAX_PY_MINOR=13   # maximum minor version (< 3.14)\nBEST_PY=\"\"\nBEST_MAJOR=0\nBEST_MINOR=0\n\n# Collect candidate python3 binaries (python3, python3.9, python3.10, …)\nfor candidate in $(compgen -c python3 2>/dev/null | grep -E '^python3(\\.[0-9]+)?$' | sort -u); do\n    if ! command -v \"$candidate\" &>/dev/null; then\n        continue\n    fi\n    # Get version string, e.g. \"Python 3.12.5\"\n    ver_str=$(\"$candidate\" --version 2>&1) || continue\n    ver_str=$(echo \"$ver_str\" | awk '{print $2}')\n    py_major=$(echo \"$ver_str\" | cut -d. -f1)\n    py_minor=$(echo \"$ver_str\" | cut -d. -f2)\n\n    # Skip anything that isn't Python 3\n    if [ \"$py_major\" -ne 3 ] 2>/dev/null; then\n        continue\n    fi\n\n    # Skip versions below 3.11 (require >= 3.11)\n    if [ \"$py_minor\" -lt \"$MIN_PY_MINOR\" ] 2>/dev/null; then\n        continue\n    fi\n\n    # Skip versions above 3.13 (require < 3.14)\n    if [ \"$py_minor\" -gt \"$MAX_PY_MINOR\" ] 2>/dev/null; then\n        continue\n    fi\n\n    # Keep the highest qualifying version\n    if [ \"$py_minor\" -gt \"$BEST_MINOR\" ]; then\n        BEST_PY=\"$candidate\"\n        BEST_MAJOR=\"$py_major\"\n        BEST_MINOR=\"$py_minor\"\n    fi\ndone\necho \"finished finding best python\"\nif [ -z \"$BEST_PY\" ]; then\n    echo \"❌ ERROR: No Python version between 3.${MIN_PY_MINOR} and 3.${MAX_PY_MINOR} found on this system.\"\n    echo \"   Detected Python 3 installations:\"\n    for candidate in $(compgen -c python3 2>/dev/null | grep -E '^python3(\\.[0-9]+)?$' | sort -u); do\n        if command -v \"$candidate\" &>/dev/null; then\n            echo \"     - $candidate ($($candidate --version 2>&1))\"\n        fi\n    done\n    echo \"\"\n    echo \"   Please install Python 3.${MIN_PY_MINOR} or 3.${MAX_PY_MINOR}.\"\n    echo \"   For example:  sudo apt install python3.12 python3.12-venv\"\n    exit 1\nfi\n\nBEST_VER=$(\"$BEST_PY\" --version 2>&1 | awk '{print $2}')\necho \"✅ Using $BEST_PY ($BEST_VER) — compatible (3.${MIN_PY_MINOR}.x – 3.${MAX_PY_MINOR}.x)\"\n\nREQ_ROOT=\"$SCRIPT_DIR/backend/requirements\"\nSINGLE_ENV_CONSTRAINTS=\"$REQ_ROOT/single-env/constraints.txt\"\nSINGLE_ENV_DATA_DESIGNER=\"$REQ_ROOT/single-env/data-designer.txt\"\nSINGLE_ENV_DATA_DESIGNER_DEPS=\"$REQ_ROOT/single-env/data-designer-deps.txt\"\nSINGLE_ENV_PATCH=\"$REQ_ROOT/single-env/patch_metadata.py\"\n\ninstall_python_stack() {\n    python \"$SCRIPT_DIR/install_python_stack.py\"\n}\n\n# Create venv under ~/.unsloth/studio/ (shared location, not in repo).\n# All platforms (including Colab) use the same isolated venv so that\n# studio dependencies are never installed into the system Python.\nSTUDIO_HOME=\"$HOME/.unsloth/studio\"\nVENV_DIR=\"$STUDIO_HOME/.venv\"\nVENV_T5_DIR=\"$STUDIO_HOME/.venv_t5\"\nmkdir -p \"$STUDIO_HOME\"\n\n# Clean up legacy in-repo venvs if they exist\n[ -d \"$REPO_ROOT/.venv\" ] && rm -rf 
\"$REPO_ROOT/.venv\"\n[ -d \"$REPO_ROOT/.venv_overlay\" ] && rm -rf \"$REPO_ROOT/.venv_overlay\"\n[ -d \"$REPO_ROOT/.venv_t5\" ] && rm -rf \"$REPO_ROOT/.venv_t5\"\n\nrm -rf \"$VENV_DIR\"\nrm -rf \"$VENV_T5_DIR\"\n# Try creating venv with pip; fall back to --without-pip + bootstrap\n# (some environments like Colab have broken ensurepip)\nif ! \"$BEST_PY\" -m venv \"$VENV_DIR\" 2>/dev/null; then\n    \"$BEST_PY\" -m venv --without-pip \"$VENV_DIR\"\n    source \"$VENV_DIR/bin/activate\"\n    curl -sS https://bootstrap.pypa.io/get-pip.py | python > /dev/null\nelse\n    source \"$VENV_DIR/bin/activate\"\nfi\n\n# ── Ensure uv is available (much faster than pip) ──\nUSE_UV=false\nif command -v uv &>/dev/null; then\n    USE_UV=true\nelif curl -LsSf https://astral.sh/uv/install.sh | sh > /dev/null 2>&1; then\n    export PATH=\"$HOME/.local/bin:$PATH\"\n    command -v uv &>/dev/null && USE_UV=true\nfi\n\n# Helper: install a package, preferring uv with pip fallback\nfast_install() {\n    if [ \"$USE_UV\" = true ]; then\n        uv pip install --python \"$(command -v python)\" \"$@\" && return 0\n    fi\n    python -m pip install \"$@\"\n}\n\ncd \"$SCRIPT_DIR\"\ninstall_python_stack\n\n# ── 6b. Pre-install transformers 5.x into .venv_t5/ ──\n# Models like GLM-4.7-Flash need transformers>=5.3.0. Instead of pip-installing\n# at runtime (slow, ~10-15s), we pre-install into a separate directory.\n# The training subprocess just prepends .venv_t5/ to sys.path -- instant switch.\necho \"\"\necho \"   Pre-installing transformers 5.x for newer model support...\"\nmkdir -p \"$VENV_T5_DIR\"\nrun_quiet \"install transformers 5.x\" fast_install --target \"$VENV_T5_DIR\" --no-deps \"transformers==5.3.0\"\nrun_quiet \"install huggingface_hub for t5\" fast_install --target \"$VENV_T5_DIR\" --no-deps \"huggingface_hub==1.7.1\"\nrun_quiet \"install hf_xet for t5\" fast_install --target \"$VENV_T5_DIR\" --no-deps \"hf_xet==1.4.2\"\n# tiktoken is needed by Qwen-family tokenizers. Install with deps since\n# regex/requests may be missing on Windows.\nrun_quiet \"install tiktoken for t5\" fast_install --target \"$VENV_T5_DIR\" \"tiktoken\"\necho \"✅ Transformers 5.x pre-installed to $VENV_T5_DIR/\"\n\n# ── 7. WSL: pre-install GGUF build dependencies ──\n# On WSL, sudo requires a password and can't be entered during GGUF export\n# (runs in a non-interactive subprocess). 
Install build deps here instead.\nif grep -qi microsoft /proc/version 2>/dev/null; then\n    echo \"\"\n    echo \"⚠️  WSL detected -- installing build dependencies for GGUF export...\"\n    _GGUF_DEPS=\"pciutils build-essential cmake curl git libcurl4-openssl-dev\"\n\n    # Try without sudo first (works when already root)\n    apt-get update -y >/dev/null 2>&1 || true\n    apt-get install -y $_GGUF_DEPS >/dev/null 2>&1 || true\n\n    # Check which packages are still missing\n    _STILL_MISSING=\"\"\n    for _pkg in $_GGUF_DEPS; do\n        case \"$_pkg\" in\n            build-essential) command -v gcc >/dev/null 2>&1 || _STILL_MISSING=\"$_STILL_MISSING $_pkg\" ;;\n            pciutils) command -v lspci >/dev/null 2>&1 || _STILL_MISSING=\"$_STILL_MISSING $_pkg\" ;;\n            libcurl4-openssl-dev) dpkg -s \"$_pkg\" >/dev/null 2>&1 || _STILL_MISSING=\"$_STILL_MISSING $_pkg\" ;;\n            *) command -v \"$_pkg\" >/dev/null 2>&1 || _STILL_MISSING=\"$_STILL_MISSING $_pkg\" ;;\n        esac\n    done\n    _STILL_MISSING=$(echo \"$_STILL_MISSING\" | sed 's/^ *//')\n\n    if [ -z \"$_STILL_MISSING\" ]; then\n        echo \"✅ GGUF build dependencies installed\"\n    elif command -v sudo >/dev/null 2>&1; then\n        echo \"\"\n        echo \"   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\"\n        echo \"   WARNING: We require sudo elevated permissions to install:\"\n        echo \"   $_STILL_MISSING\"\n        echo \"   If you accept, we'll run sudo now, and it'll prompt your password.\"\n        echo \"   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\"\n        echo \"\"\n        printf \"   Accept? [Y/n] \"\n        if [ -r /dev/tty ]; then\n            read -r REPLY </dev/tty || REPLY=\"y\"\n        else\n            REPLY=\"y\"\n        fi\n        case \"$REPLY\" in\n            [nN]*)\n                echo \"\"\n                echo \"   Please install these packages first, then re-run Unsloth Studio setup:\"\n                echo \"   sudo apt-get update -y && sudo apt-get install -y $_STILL_MISSING\"\n                _SKIP_GGUF_BUILD=true\n                ;;\n            *)\n                sudo apt-get update -y\n                sudo apt-get install -y $_STILL_MISSING\n                echo \"✅ GGUF build dependencies installed\"\n                ;;\n        esac\n    else\n        echo \"   sudo is not available on this system.\"\n        echo \"   Please install as root, then re-run setup:\"\n        echo \"   apt-get install -y $_STILL_MISSING\"\n        _SKIP_GGUF_BUILD=true\n    fi\nfi\n\n# ── 8. Build llama.cpp binaries for GGUF inference + export ──\n# Builds at ~/.unsloth/llama.cpp — a single shared location under the user's\n# home directory. This is used by both the inference server and the GGUF\n# export pipeline (unsloth-zoo).\n#   - llama-server: for GGUF model inference\n#   - llama-quantize: for GGUF export quantization (symlinked to root for check_llama_cpp())\nUNSLOTH_HOME=\"$HOME/.unsloth\"\nmkdir -p \"$UNSLOTH_HOME\"\nLLAMA_CPP_DIR=\"$UNSLOTH_HOME/llama.cpp\"\nLLAMA_SERVER_BIN=\"$LLAMA_CPP_DIR/build/bin/llama-server\"\nif [ \"${_SKIP_GGUF_BUILD:-}\" = true ]; then\n    echo \"\"\n    echo \"Skipping llama-server build (missing dependencies)\"\n    echo \"   Install the missing packages and re-run setup to enable GGUF inference.\"\nelse\nrm -rf \"$LLAMA_CPP_DIR\"\n{\n    # Check prerequisites\n    if ! 
command -v cmake &>/dev/null; then\n        echo \"\"\n        echo \"⚠️  cmake not found — skipping llama-server build (GGUF inference won't be available)\"\n        echo \"   Install cmake and re-run setup.sh to enable GGUF inference.\"\n    elif ! command -v git &>/dev/null; then\n        echo \"\"\n        echo \"⚠️  git not found — skipping llama-server build (GGUF inference won't be available)\"\n    else\n        echo \"\"\n        echo \"Building llama-server for GGUF inference...\"\n\n        BUILD_OK=true\n        run_quiet \"clone llama.cpp\" git clone --depth 1 https://github.com/ggml-org/llama.cpp.git \"$LLAMA_CPP_DIR\" || BUILD_OK=false\n\n        if [ \"$BUILD_OK\" = true ]; then\n            # Skip tests/examples we don't need (faster build)\n            CMAKE_ARGS=\"-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_NATIVE=ON\"\n\n            # Use ccache if available (dramatically faster rebuilds)\n            if command -v ccache &>/dev/null; then\n                CMAKE_ARGS=\"$CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache\"\n                echo \"   Using ccache for faster compilation\"\n            fi\n\n            # Detect CUDA: check nvcc on PATH, then common install locations\n            NVCC_PATH=\"\"\n            if command -v nvcc &>/dev/null; then\n                NVCC_PATH=\"$(command -v nvcc)\"\n            elif [ -x /usr/local/cuda/bin/nvcc ]; then\n                NVCC_PATH=\"/usr/local/cuda/bin/nvcc\"\n                export PATH=\"/usr/local/cuda/bin:$PATH\"\n            elif ls /usr/local/cuda-*/bin/nvcc &>/dev/null 2>&1; then\n                # Pick the newest cuda-XX.X directory\n                NVCC_PATH=\"$(ls -d /usr/local/cuda-*/bin/nvcc 2>/dev/null | sort -V | tail -1)\"\n                export PATH=\"$(dirname \"$NVCC_PATH\"):$PATH\"\n            fi\n\n            if [ -n \"$NVCC_PATH\" ]; then\n                echo \"   Building with CUDA support (nvcc: $NVCC_PATH)...\"\n                CMAKE_ARGS=\"$CMAKE_ARGS -DGGML_CUDA=ON\"\n\n                # Detect GPU compute capability and limit CUDA architectures\n                # Without this, cmake builds for ALL default archs (very slow)\n                CUDA_ARCHS=\"\"\n                if command -v nvidia-smi &>/dev/null; then\n                    # Read all GPUs, deduplicate (handles mixed-GPU hosts)\n                    _raw_caps=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null || true)\n                    while IFS= read -r _cap; do\n                        _cap=$(echo \"$_cap\" | tr -d '[:space:]')\n                        if [[ \"$_cap\" =~ ^([0-9]+)\\.([0-9]+)$ ]]; then\n                            _arch=\"${BASH_REMATCH[1]}${BASH_REMATCH[2]}\"\n                            # Append if not already present\n                            case \";$CUDA_ARCHS;\" in\n                                *\";$_arch;\"*) ;;\n                                *) CUDA_ARCHS=\"${CUDA_ARCHS:+$CUDA_ARCHS;}$_arch\" ;;\n                            esac\n                        fi\n                    done <<< \"$_raw_caps\"\n                fi\n\n                if [ -n \"$CUDA_ARCHS\" ]; then\n                    echo \"   GPU compute capabilities: ${CUDA_ARCHS//;/, } -- limiting build to detected archs\"\n                    CMAKE_ARGS=\"$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS}\"\n                else\n                    echo \"   Could not detect GPU arch -- building 
for all default CUDA architectures (slower)\"\n                fi\n\n                # Multi-threaded nvcc compilation (uses all CPU cores per .cu file)\n                CMAKE_ARGS=\"$CMAKE_ARGS -DCMAKE_CUDA_FLAGS=--threads=0\"\n            elif [ -d /usr/local/cuda ] || nvidia-smi &>/dev/null; then\n                echo \"   CUDA driver detected but nvcc not found — building CPU-only\"\n                echo \"   To enable GPU: install cuda-toolkit or add nvcc to PATH\"\n            else\n                echo \"   Building CPU-only (no CUDA detected)...\"\n            fi\n\n            NCPU=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)\n\n            # Use Ninja if available (faster parallel builds than Make)\n            CMAKE_GENERATOR_ARGS=\"\"\n            if command -v ninja &>/dev/null; then\n                CMAKE_GENERATOR_ARGS=\"-G Ninja\"\n            fi\n\n            run_quiet \"cmake llama.cpp\" cmake $CMAKE_GENERATOR_ARGS -S \"$LLAMA_CPP_DIR\" -B \"$LLAMA_CPP_DIR/build\" $CMAKE_ARGS || BUILD_OK=false\n        fi\n\n        if [ \"$BUILD_OK\" = true ]; then\n            run_quiet \"build llama-server\" cmake --build \"$LLAMA_CPP_DIR/build\" --config Release --target llama-server -j\"$NCPU\" || BUILD_OK=false\n        fi\n\n        # Also build llama-quantize (needed by unsloth-zoo's GGUF export pipeline)\n        if [ \"$BUILD_OK\" = true ]; then\n            run_quiet \"build llama-quantize\" cmake --build \"$LLAMA_CPP_DIR/build\" --config Release --target llama-quantize -j\"$NCPU\" || true\n            # Symlink to llama.cpp root — check_llama_cpp() looks for the binary there\n            QUANTIZE_BIN=\"$LLAMA_CPP_DIR/build/bin/llama-quantize\"\n            if [ -f \"$QUANTIZE_BIN\" ]; then\n                ln -sf build/bin/llama-quantize \"$LLAMA_CPP_DIR/llama-quantize\"\n            fi\n        fi\n\n        if [ \"$BUILD_OK\" = true ]; then\n            if [ -f \"$LLAMA_SERVER_BIN\" ]; then\n                echo \"✅ llama-server built at $LLAMA_SERVER_BIN\"\n            else\n                echo \"⚠️  llama-server binary not found after build — GGUF inference won't be available\"\n            fi\n            if [ -f \"$LLAMA_CPP_DIR/llama-quantize\" ]; then\n                echo \"✅ llama-quantize available for GGUF export\"\n            fi\n        else\n            echo \"⚠️  llama-server build failed — GGUF inference won't be available, but everything else works\"\n        fi\n    fi\n}\nfi  # end _SKIP_GGUF_BUILD check\n\necho \"\"\nif [ \"$IS_COLAB\" = true ]; then\n    echo \"╔══════════════════════════════════════╗\"\n    echo \"║           Setup Complete!            ║\"\n    echo \"╠══════════════════════════════════════╣\"\n    echo \"║ Unsloth Studio is ready to start     ║\"\n    echo \"║ in your Colab notebook!              ║\"\n    echo \"║                                      ║\"\n    echo \"║ from colab import start              ║\"\n    echo \"║ start()                              ║\"\n    echo \"╚══════════════════════════════════════╝\"\nelse\n    echo \"╔══════════════════════════════════════╗\"\n    echo \"║           Setup Complete!            ║\"\n    echo \"╠══════════════════════════════════════╣\"\n    echo \"║ Launch with:                         ║\"\n    echo \"║                                      ║\"\n    echo \"║ unsloth studio -H 0.0.0.0 -p 8888    ║\"\n    echo \"╚══════════════════════════════════════╝\"\nfi\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/qlora/README.md",
    "content": "## QLoRA Train and Merge Tests\n\n### Overview\nTests that performing QLoRA training and merging weights to 16-bits post-training maintains same behavior as trained model.\n\n- `test_unsloth_qlora_train_and_merge.py`: Test Unsloth QLoRA train and merge using `FastLanguageModel.from_pretrained`, `FastLanguageModel.get_peft_model`, and `FastLanguageModel.save_pretrained_merged` apis\n- `test_hf_qlora_train_and_merge.py`: Test Hugging Face QLoRA train and merge using `from_pretrained`, `get_peft_model`, and `merge_and_unload` apis.\n   - Demonstrates that `peft`'s `merge_and_unload` results in loss of accuracy as it requantizes the base layer after merging adapter weights so that the model still contains `Linear4Bit` layers post merging.\n   - I (@jeromeku) implemented a custom merge function that replaces all `LoraLayers` with `Linear` layers whose weights are the dequantized base layer weights with adapter weights merged (compute done in fp32, cast to original dtype after merging), roughly equivalent to `FastLanguageModel.save_pretrained_merged`.\n\n### Usage\nRun unsloth test:\n```bash\npython tests/qlora/test_unsloth_qlora_train_and_merge.py\n```\nRun huggingface test:\n```bash\npython tests/qlora/test_hf_qlora_train_and_merge.py\n```\n\n### Details\nThe tests train a QLoRA model on a single prompt dataset\n```\nQUESTION = \"What day was I born?\"\nANSWER = \"January 1, 2058\"\nUSER_MESSAGE = {\"role\": \"user\", \"content\": QUESTION}\nASSISTANT_MESSAGE = {\"role\": \"assistant\", \"content\": ANSWER}\n```\n\nGiven that the answer is impossible to answer accurately without finetuning, we can only expect the model to answer the question correctly if the model has been trained on the question.\n\nTo check this behavior, we check the model's response to the question before and after training and after merging, checking that the model's response contains the answer after training and merging but not before training.\n\n### Results\n\nFor the unsloth test, the model's behavior is as expected: \n- before training, the model's response does not contain the answer\n- after training, the model's response contains the answer\n- after merging, the model's response contains the answer\n\nFor the huggingface test, the model's behavior is as expected:\n- before training, the model's response does not contain the answer\n- after training, the model's response contains the answer\n- after using peft's `merge_and_unload`, the model's response does not contain the answer\n- after using my custom merge function, the model's response contains the answer\n\nThe scripts should output training params, training logs, as well as model responses before and after training and after merging (only prints model responses if answer is not contained in response)."
  },
  {
    "path": "tests/qlora/test_hf_qlora_train_and_merge.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\nREPO_ROOT = Path(__file__).parents[2]\nsys.path.append(str(REPO_ROOT))\n\nimport itertools\nfrom copy import deepcopy\n\nimport torch\nfrom datasets import Dataset\nfrom trl import SFTConfig\nfrom tests.utils import header_footer_context\nfrom tests.utils.data_utils import (\n    ANSWER,\n    DEFAULT_MESSAGES,\n    USER_MESSAGE,\n    check_responses,\n    create_dataset,\n    describe_peft_weights,\n)\nfrom tests.utils.hf_utils import (\n    convert_lora_to_linear,\n    fix_llama3_tokenizer,\n    get_peft_config,\n    sample_responses,\n    setup_model,\n    setup_tokenizer,\n    setup_trainer,\n)\n\nif __name__ == \"__main__\":\n    model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n    dtype = torch.bfloat16\n    max_steps = 100\n    num_examples = 1000\n    lora_rank = 64\n    output_dir = \"sft_test\"\n    seed = 42\n    batch_size = 5\n    num_generations = 5\n    tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])\n    temperature = 0.8\n    max_new_tokens = 20\n\n    peft_config = get_peft_config(lora_rank = lora_rank, target_modules = \"all-linear\")\n    model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)\n\n    prompt = tokenizer.apply_chat_template(\n        [USER_MESSAGE], tokenize = False, add_generation_prompt = True\n    )\n    with header_footer_context(\"Test Prompt and Answer\"):\n        print(f\"Test Prompt:\\n{prompt}\\nExpected Answer:\\n{ANSWER}\")\n\n    dataset: Dataset = create_dataset(\n        tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES\n    )\n    with header_footer_context(\"Dataset\"):\n        print(f\"Dataset: {next(iter(dataset))}\")\n\n    training_args = SFTConfig(\n        output_dir = output_dir,\n        max_steps = max_steps,\n        per_device_train_batch_size = batch_size,\n        log_level = \"info\",\n        report_to = \"none\",\n        num_train_epochs = 1,\n        logging_steps = 1,\n        seed = seed,\n        bf16 = dtype == torch.bfloat16,\n        fp16 = dtype == torch.float16,\n        save_strategy = \"no\",\n    )\n\n    with header_footer_context(\"Train Args\"):\n        print(training_args)\n        print(peft_config)\n\n    trainer = setup_trainer(\n        model, tokenizer, dataset, training_args, peft_config = peft_config\n    )\n\n    with header_footer_context(\"Model\"):\n        print(type(model.model))\n\n    generation_args = {\n        \"num_generations\": num_generations,\n        \"max_new_tokens\": max_new_tokens,\n        \"temperature\": temperature,\n        \"skip_special_tokens\": False,\n        \"dtype\": dtype,\n    }\n    responses = sample_responses(\n        model,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses before training\"):\n     
   check_responses(responses, answer = ANSWER, prompt = prompt)\n\n    with header_footer_context(\"Peft Weights before training\"):\n        for name, stats in itertools.islice(describe_peft_weights(model), 2):\n            print(f\"{name}:\\n{stats}\")\n\n    output = trainer.train()\n    with header_footer_context(\"Peft Weights after training\"):\n        for name, stats in itertools.islice(describe_peft_weights(model), 2):\n            print(f\"{name}:\\n{stats}\")\n\n    with header_footer_context(\"Trainer Output\"):\n        print(output)\n\n    responses = sample_responses(\n        model,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses after training\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n\n    model_copy = deepcopy(model)\n\n    merged_model = convert_lora_to_linear(model)\n\n    responses = sample_responses(\n        merged_model,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses after custom merging to 16bit\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n\n    merged_model_peft = model_copy.merge_and_unload()\n    responses = sample_responses(\n        merged_model_peft,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses after peft merge_and_unload\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n"
  },
  {
    "path": "tests/qlora/test_unsloth_qlora_train_and_merge.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\nREPO_ROOT = Path(__file__).parents[2]\nsys.path.append(str(REPO_ROOT))\n\nimport itertools\nfrom unsloth import FastLanguageModel\n\nimport torch\nfrom datasets import Dataset\nfrom trl import SFTConfig\nfrom tests.utils import header_footer_context\nfrom tests.utils.data_utils import (\n    DEFAULT_MESSAGES,\n    USER_MESSAGE,\n    ANSWER,\n    create_dataset,\n    describe_peft_weights,\n    check_responses,\n)\nfrom tests.utils.hf_utils import (\n    sample_responses,\n    setup_trainer,\n)\n\n\ndef get_unsloth_model_and_tokenizer(\n    model_name: str,\n    max_seq_length: int,\n    load_in_4bit: bool,\n    fast_inference: bool,\n    max_lora_rank: int = None,\n    gpu_memory_utilization: float = 0.5,\n    dtype: torch.dtype = torch.bfloat16,\n):\n    return FastLanguageModel.from_pretrained(\n        model_name = model_name,\n        max_seq_length = max_seq_length,\n        load_in_4bit = load_in_4bit,\n        fast_inference = fast_inference,\n        max_lora_rank = max_lora_rank,\n        gpu_memory_utilization = gpu_memory_utilization,\n        dtype = dtype,\n    )\n\n\ndef get_unsloth_peft_model(\n    model,\n    lora_rank: int,\n    target_modules: list[str] = \"all-linear\",\n    use_gradient_checkpointing: str = False,\n    random_state: int = 42,\n):\n    return FastLanguageModel.get_peft_model(\n        model,\n        r = lora_rank,\n        target_modules = target_modules,\n        lora_alpha = lora_rank,\n        use_gradient_checkpointing = use_gradient_checkpointing,\n        random_state = random_state,\n    )\n\n\nif __name__ == \"__main__\":\n    model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n    dtype = torch.bfloat16\n    max_steps = 100\n    num_examples = 1000\n    lora_rank = 64\n    output_dir = \"sft_test\"\n    seed = 42\n    batch_size = 5\n    num_generations = 5\n    target_modules = [\n        \"q_proj\",\n        \"k_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"up_proj\",\n        \"down_proj\",\n    ]\n    gradient_checkpointing = False\n    unsloth_merged_path = \"unsloth_merged_16bit\"\n\n    model, tokenizer = get_unsloth_model_and_tokenizer(\n        model_name,\n        max_seq_length = 512,\n        load_in_4bit = True,\n        fast_inference = False,\n        max_lora_rank = lora_rank,\n        dtype = dtype,\n    )\n    temperature = 0.8\n    max_new_tokens = 20\n\n    model = get_unsloth_peft_model(\n        model,\n        lora_rank = lora_rank,\n        target_modules = target_modules,\n        use_gradient_checkpointing = gradient_checkpointing,\n        random_state = seed,\n    )\n\n    prompt = tokenizer.apply_chat_template(\n        [USER_MESSAGE], tokenize = False, add_generation_prompt = True\n    )\n\n    with header_footer_context(\"Test Prompt and Answer\"):\n        print(f\"Test 
Prompt:\\n{prompt}\\nExpected Answer:\\n{ANSWER}\")\n\n    dataset: Dataset = create_dataset(\n        tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES\n    )\n    with header_footer_context(\"Dataset\"):\n        print(f\"Dataset: {next(iter(dataset))}\")\n\n    training_args = SFTConfig(\n        output_dir = output_dir,\n        max_steps = max_steps,\n        per_device_train_batch_size = batch_size,\n        log_level = \"info\",\n        report_to = \"none\",\n        num_train_epochs = 1,\n        logging_steps = 1,\n        seed = seed,\n        bf16 = dtype == torch.bfloat16,\n        fp16 = dtype == torch.float16,\n        save_strategy = \"no\",\n    )\n\n    with header_footer_context(\"Train Args\"):\n        print(training_args)\n\n    trainer = setup_trainer(model, tokenizer, dataset, training_args)\n\n    with header_footer_context(\"Model\"):\n        print(type(model.model))\n\n    generation_args = {\n        \"num_generations\": num_generations,\n        \"max_new_tokens\": max_new_tokens,\n        \"temperature\": temperature,\n        \"skip_special_tokens\": False,\n        \"dtype\": dtype,\n    }\n    responses = sample_responses(\n        model,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses before training\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n    with header_footer_context(\"Peft Weights before training\"):\n        for name, stats in itertools.islice(describe_peft_weights(model), 2):\n            print(f\"{name}:\\n{stats}\")\n\n    output = trainer.train()\n    with header_footer_context(\"Peft Weights after training\"):\n        for name, stats in itertools.islice(describe_peft_weights(model), 2):\n            print(f\"{name}:\\n{stats}\")\n\n    with header_footer_context(\"Trainer Output\"):\n        print(output)\n\n    responses = sample_responses(\n        model,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses after training\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n\n    model.save_pretrained_merged(\n        unsloth_merged_path,\n        tokenizer,\n        save_method = \"merged_16bit\",\n    )\n    merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(\n        unsloth_merged_path,\n        max_seq_length = 512,\n        load_in_4bit = False,\n        fast_inference = False,\n        dtype = dtype,\n    )\n    responses = sample_responses(\n        merged_model_unsloth,\n        tokenizer,\n        prompt = prompt,\n        **generation_args,\n    )\n    with header_footer_context(\"Responses after unsloth merge to 16bit\"):\n        check_responses(responses, answer = ANSWER, prompt = prompt)\n"
  },
  {
    "path": "tests/saving/gpt-oss-merge/run_test.sh",
    "content": "#!/bin/bash\nset -e\n\necho \"================================================================\"\necho \"🚀 STEP 1: Running the training and merging script...\"\necho \"================================================================\"\npython train_and_merge.py\n\necho \"\"\necho \"================================================================\"\necho \"✅ STEP 2: Training complete. Running the inference script...\"\necho \"================================================================\"\npython test_merged_model.py\n\necho \"\"\necho \"================================================================\"\necho \"🎉 All steps completed successfully!\"\necho \"================================================================\"\n"
  },
  {
    "path": "tests/saving/gpt-oss-merge/test_merged_model.py",
    "content": "# inference_on_merged.py\nfrom unsloth import FastLanguageModel\nfrom transformers import TextStreamer\nimport torch\nimport gc\nimport os\nimport shutil\n\n\ndef safe_remove_directory(path):\n    try:\n        if os.path.exists(path) and os.path.isdir(path):\n            shutil.rmtree(path)\n            return True\n        else:\n            print(f\"Path {path} is not a valid directory\")\n            return False\n    except Exception as e:\n        print(f\"Failed to remove directory {path}: {e}\")\n        return False\n\n\nprint(\"🔥 Loading the 16-bit merged model from disk...\")\nmerged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"./gpt-oss-finetuned-merged\",\n    max_seq_length = 1024,\n    load_in_4bit = True,\n    load_in_8bit = False,\n)\nprint(\"✅ Merged model loaded successfully.\")\n\n# --- Run Inference ---\nprint(\"\\n🚀 Running inference...\")\nmessages = [\n    {\"role\": \"user\", \"content\": \"Solve x^5 + 3x^4 - 10 = 3.\"},\n]\ninputs = merged_tokenizer.apply_chat_template(\n    messages,\n    add_generation_prompt = True,\n    return_tensors = \"pt\",\n    return_dict = True,\n    reasoning_effort = \"low\",  # **NEW!** Set reasoning effort to low, medium or high\n).to(merged_model.device)\n\n_ = merged_model.generate(\n    **inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer)\n)\nprint(\"\\n✅ Inference complete.\")\n\n# --- Final Cleanup ---\nprint(\"\\n🧹 Cleaning up merged model directory and cache...\")\ndel merged_model, merged_tokenizer\ntorch.cuda.empty_cache()\ngc.collect()\n\nsafe_remove_directory(\"./gpt-oss-finetuned-merged\")\nsafe_remove_directory(\n    \"./unsloth_compiled_cache\"\n)  # Clean up cache created by this process\nprint(\"✅ Final cleanup complete. Exiting inference script.\")\n"
  },
  {
    "path": "tests/saving/gpt-oss-merge/train_and_merge.py",
    "content": "# train_and_merge.py\nfrom unsloth import FastLanguageModel\nfrom trl import SFTTrainer, SFTConfig\nfrom datasets import load_dataset\nimport torch\nimport gc\nimport os\nimport shutil\n\n\ndef safe_remove_directory(path):\n    try:\n        if os.path.exists(path) and os.path.isdir(path):\n            shutil.rmtree(path)\n            return True\n        else:\n            print(f\"Path {path} is not a valid directory\")\n            return False\n    except Exception as e:\n        print(f\"Failed to remove directory {path}: {e}\")\n        return False\n\n\n# This tokenizer will be used by the mapping function\ntokenizer = None\n\n\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\n# --- Load 4-bit Model and Train ---\nprint(\"Loading 4-bit Mxfp4 gpt-oss model for training...\")\nmax_seq_length = 1024\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    \"unsloth/gpt-oss-20b\", max_seq_length = max_seq_length, load_in_4bit = True\n)\n\ndataset = load_dataset(\"HuggingFaceH4/Multilingual-Thinking\", split = \"train[:50]\").map(\n    formatting_prompts_func, batched = True\n)\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 8,\n    target_modules = [\n        \"q_proj\",\n        \"k_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"up_proj\",\n        \"down_proj\",\n    ],\n    lora_alpha = 16,\n    use_gradient_checkpointing = \"unsloth\",\n    random_state = 3407,\n)\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    train_dataset = dataset,\n    args = SFTConfig(\n        per_device_train_batch_size = 1,\n        gradient_accumulation_steps = 4,\n        max_steps = 10,\n        learning_rate = 2e-4,\n        output_dir = \"outputs\",\n        report_to = \"none\",\n    ),\n)\n\nprint(\"Starting fine-tuning...\")\ntrainer.train()\nprint(\"Fine-tuning complete.\")\n\n# --- Merge and Save ---\nprint(\"\\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...\")\nmodel.save_pretrained_merged(\n    save_directory = \"./gpt-oss-finetuned-merged\", tokenizer = tokenizer\n)\nprint(\"✅ Model merged and saved.\")\n\n# --- Cleanup ---\nprint(\"\\n🧹 Cleaning up training artifacts...\")\ndel model, trainer, tokenizer, dataset\ntorch.cuda.empty_cache()\ngc.collect()\n\nsafe_remove_directory(\"./outputs\")\nsafe_remove_directory(\n    \"./unsloth_compiled_cache\"\n)  # Clean up the cache created by this process\nprint(\"✅ Cleanup complete. Exiting training script.\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merge_4bit_validation.py",
    "content": "from unsloth import FastLanguageModel\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import DataCollatorForSeq2Seq, TrainingArguments\nfrom datasets import load_dataset\nimport torch\nimport sys\nfrom pathlib import Path\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\n\n\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 1: Loading Base Model and Initial Training\")\nprint(f\"{'='*80}\")\n\nif torch.cuda.is_bf16_supported():\n    compute_dtype = torch.bfloat16\n    attn_implementation = \"flash_attention_2\"\nelse:\n    compute_dtype = torch.float16\n    attn_implementation = \"sdpa\"\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/Llama-3.1-8B-Instruct\",\n    max_seq_length = 2048,\n    dtype = compute_dtype,\n    load_in_4bit = True,\n    load_in_8bit = False,\n    full_finetuning = False,\n    attn_implementation = attn_implementation,\n)\n\ntokenizer = get_chat_template(\n    tokenizer,\n    chat_template = \"llama-3.1\",\n)\n\n# Load small dataset for quick training\ndataset_train = load_dataset(\n    \"allenai/openassistant-guanaco-reformatted\", split = \"train[:100]\"\n)\ndataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n\nprint(\"✅ Base model loaded successfully!\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 2: First Fine-tuning\")\nprint(f\"{'='*80}\")\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 16,\n    target_modules = [\n        \"k_proj\",\n        \"q_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"down_proj\",\n        \"up_proj\",\n    ],\n    lora_alpha = 16,\n    lora_dropout = 0,\n    bias = \"none\",\n    use_gradient_checkpointing = \"unsloth\",\n    random_state = 3407,\n    use_rslora = False,\n    loftq_config = None,\n)\n\nfrom unsloth import is_bfloat16_supported\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    train_dataset = dataset_train,\n    dataset_text_field = \"text\",\n    max_seq_length = 2048,\n    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n    dataset_num_proc = 2,\n    packing = False,\n    args = TrainingArguments(\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        warmup_ratio = 0.1,\n        max_steps = 10,  # Very short training for test\n        learning_rate = 2e-4,\n        fp16 = not is_bfloat16_supported(),\n        bf16 = is_bfloat16_supported(),\n        logging_steps = 5,\n        optim = \"adamw_8bit\",\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"outputs\",\n        report_to = \"none\",\n    ),\n)\n\ntrainer_stats = trainer.train()\nprint(\"✅ First fine-tuning completed!\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 3: Save with Forced 4bit Merge\")\nprint(f\"{'='*80}\")\n\nmodel.save_pretrained_merged(\n    save_directory = \"./test_4bit_model\",\n    tokenizer = tokenizer,\n    save_method = \"forced_merged_4bit\",\n)\n\nprint(\"✅ Model saved with forced 4bit merge!\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 4: Loading 4bit Model and 
Second Fine-tuning\")\nprint(f\"{'='*80}\")\n\n# Clean up first model\ndel model\ndel tokenizer\ntorch.cuda.empty_cache()\n\n# Load the 4bit merged model\nmodel_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(\n    model_name = \"./test_4bit_model\",\n    max_seq_length = 2048,\n    load_in_4bit = True,\n    load_in_8bit = False,\n)\n\ntokenizer_4bit = get_chat_template(\n    tokenizer_4bit,\n    chat_template = \"llama-3.1\",\n)\n\nprint(\"✅ 4bit model loaded successfully!\")\n\n# Add LoRA adapters to the 4bit model\nmodel_4bit = FastLanguageModel.get_peft_model(\n    model_4bit,\n    r = 16,\n    target_modules = [\n        \"k_proj\",\n        \"q_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"down_proj\",\n        \"up_proj\",\n    ],\n    lora_alpha = 16,\n    lora_dropout = 0,\n    bias = \"none\",\n    use_gradient_checkpointing = \"unsloth\",\n    random_state = 3407,\n    use_rslora = False,\n    loftq_config = None,\n)\n\n# Second fine-tuning\ntrainer_4bit = SFTTrainer(\n    model = model_4bit,\n    tokenizer = tokenizer_4bit,\n    train_dataset = dataset_train,\n    dataset_text_field = \"text\",\n    max_seq_length = 2048,\n    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer_4bit),\n    dataset_num_proc = 2,\n    packing = False,\n    args = TrainingArguments(\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        warmup_ratio = 0.1,\n        max_steps = 10,  # Very short training for test\n        learning_rate = 2e-4,\n        fp16 = not is_bfloat16_supported(),\n        bf16 = is_bfloat16_supported(),\n        logging_steps = 5,\n        optim = \"adamw_8bit\",\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"outputs_4bit\",\n        report_to = \"none\",\n    ),\n)\n\ntrainer_4bit.train()\nprint(\"✅ Second fine-tuning on 4bit model completed!\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)\")\nprint(f\"{'='*80}\")\n\ntry:\n    model_4bit.save_pretrained_merged(\n        save_directory = \"./test_should_fail\",\n        tokenizer = tokenizer_4bit,\n        # No save_method specified, should default to regular merge\n    )\n    assert False, \"Expected TypeError but merge succeeded!\"\nexcept TypeError as e:\n    expected_error = \"Base model should be a 16bits or mxfp4 base model for a 16bit model merge. 
Use `save_method=forced_merged_4bit` instead\"\n    assert expected_error in str(e), f\"Unexpected error message: {str(e)}\"\n    print(\"✅ Correct TypeError raised for 4bit base model regular merge attempt!\")\n    print(f\"Error message: {str(e)}\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 6: Successful Save with Forced 4bit Method\")\nprint(f\"{'='*80}\")\n\ntry:\n    model_4bit.save_pretrained_merged(\n        save_directory = \"./test_4bit_second\",\n        tokenizer = tokenizer_4bit,\n        save_method = \"forced_merged_4bit\",\n    )\n    print(\"✅ Successfully saved 4bit model with forced 4bit method!\")\nexcept Exception as e:\n    assert False, f\"Phase 6 failed unexpectedly: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 CLEANUP\")\nprint(f\"{'='*80}\")\n\n# Cleanup\nsafe_remove_directory(\"./outputs\")\nsafe_remove_directory(\"./outputs_4bit\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./test_4bit_model\")\nsafe_remove_directory(\"./test_4bit_second\")\nsafe_remove_directory(\"./test_should_fail\")\n\nprint(\"✅ All tests passed successfully!\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\ndef load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):\n    \"\"\"Load model and compute perplexity in subprocess\"\"\"\n    from unsloth import FastLanguageModel\n    from unsloth.chat_templates import get_chat_template\n    from tests.utils.perplexity_eval import ppl_model\n\n    # Load model\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = load_in_4bit,\n        load_in_8bit = load_in_8bit,\n    )\n    # Set up tokenizer\n    merged_tokenizer = get_chat_template(\n        merged_tokenizer,\n        chat_template = \"llama-3.1\",\n    )\n\n    # Load dataset fresh in subprocess\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    # Format the dataset\n    def formatting_prompts_func(examples):\n        convos = examples[\"messages\"]\n        texts = [\n            merged_tokenizer.apply_chat_template(\n                convo, tokenize = False, add_generation_prompt = False\n            )\n            for convo in convos\n        ]\n        return {\"text\": texts}\n\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    # Compute perplexity using the passed dataset\n    ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n\n    # IMPORTANT: Convert to Python float if it's a tensor\n    if torch.is_tensor(ppl_value):\n        ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar\n    elif hasattr(ppl_value, \"item\"):\n        ppl_value = ppl_value.item()  # Convert numpy or other array types\n    else:\n        ppl_value = float(ppl_value)  # Ensure it's a float\n\n    # Return only the perplexity value\n    result_queue.put(ppl_value)\n\n    # Clean up\n    del merged_model\n    del merged_tokenizer\n    del dataset_ppl\n    torch.cuda.empty_cache()\n    gc.collect()\n\n\n# Main execution code should be wrapped in this guard\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n\n    if torch.cuda.is_bf16_supported():\n        compute_dtype = torch.bfloat16\n        attn_implementation = \"flash_attention_2\"\n    else:\n        compute_dtype = torch.float16\n        attn_implementation = \"sdpa\"\n\n    model, tokenizer = 
FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/Llama-3.2-3B-Instruct\",\n        max_seq_length = 2048,\n        dtype = compute_dtype,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        full_finetuning = False,\n        attn_implementation = attn_implementation,\n    )\n\n    tokenizer = get_chat_template(\n        tokenizer,\n        chat_template = \"llama-3.1\",\n    )\n\n    from unsloth.chat_templates import standardize_sharegpt\n\n    dataset_train = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"train\"\n    )\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    add_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"k_proj\",\n            \"q_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"down_proj\",\n            \"up_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0,\n        bias = \"none\",\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        use_rslora = False,\n        loftq_config = None,\n    )\n\n    from unsloth import is_bfloat16_supported\n\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        train_dataset = dataset_train,\n        dataset_text_field = \"text\",\n        max_seq_length = 2048,\n        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n        dataset_num_proc = 2,\n        packing = False,\n        args = TrainingArguments(\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            warmup_ratio = 0.1,\n            max_steps = 10,\n            learning_rate = 2e-4,\n            fp16 = not is_bfloat16_supported(),\n            bf16 = is_bfloat16_supported(),\n            logging_steps = 50,\n            optim = \"adamw_8bit\",\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"outputs\",\n            report_to = \"none\",\n        ),\n    )\n\n    from unsloth.chat_templates import train_on_responses_only\n\n    trainer = train_on_responses_only(\n        trainer,\n        instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    )\n\n    # run training\n    trainer_stats = trainer.train()\n\n    add_to_comparison(\"Qlora model\", ppl_model(model, tokenizer, dataset_ppl))\n\n    # saving and merging the model to local disk\n    print(\"merge and save to local disk\")\n    model.save_pretrained_merged(\n        save_directory = \"./unsloth_out/merged_llama_text_model\", tokenizer = tokenizer\n    )\n\n    # print(\"cleaning\")\n    # del model\n    # del tokenizer\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n    # load model from local disk and test\n    print(\"Loading merged model in 4 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = True,\n        load_in_8bit = False,\n    )\n\n  
  add_to_comparison(\n        \"merged model load 4bit\", ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n    )\n\n    print(\"Computing 8-bit model perplexity in subprocess...\")\n    result_queue = mp.Queue()\n    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    ppl_8bit = result_queue.get()\n    add_to_comparison(\"merged model loaded 8bits\", ppl_8bit)\n\n    print(\"Loading merged model in 16 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = False,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model loaded 16bits\",\n        ppl_model(merged_model, merged_tokenizer, dataset_ppl),\n    )\n\n    print_model_comparison()\n\n    # final cleanup\n    safe_remove_directory(\"./outputs\")\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n    safe_remove_directory(\"./unsloth_out\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merge_model_perplexity_mistral.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\ndef load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):\n    \"\"\"Load model and compute perplexity in subprocess\"\"\"\n    from unsloth import FastLanguageModel\n    from tests.utils.perplexity_eval import ppl_model\n\n    # Load model\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_mistral_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = load_in_4bit,\n        load_in_8bit = load_in_8bit,\n    )\n    # Set up tokenizer\n    # merged_tokenizer = get_chat_template(\n    #     merged_tokenizer,\n    #     chat_template=\"llama-3.1\",\n    # )\n\n    # Load dataset fresh in subprocess\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n    ### Instruction:\n    {}\n\n    ### Input:\n    {}\n\n    ### Response:\n    {}\"\"\"\n\n    EOS_TOKEN = merged_tokenizer.eos_token\n\n    def formatting_prompts_func(examples):\n        instructions = []\n        inputs = []\n        outputs = []\n        texts = []\n\n        for conversation in examples[\"messages\"]:\n            # Extract user message and assistant response\n            user_message = \"\"\n            assistant_message = \"\"\n\n            for turn in conversation:\n                if turn[\"role\"] == \"user\":\n                    user_message = turn[\"content\"]\n                elif turn[\"role\"] == \"assistant\":\n                    assistant_message = turn[\"content\"]\n\n            # Store intermediate format\n            instruction = \"Complete the statement\"\n            instructions.append(instruction)\n            inputs.append(user_message)\n            outputs.append(assistant_message)\n\n            # Create formatted text\n            text = (\n                alpaca_prompt.format(instruction, user_message, assistant_message)\n                + EOS_TOKEN\n            )\n            texts.append(text)\n\n        return {\n            \"instruction\": instructions,\n            \"input\": inputs,\n            \"output\": outputs,\n            \"text\": texts,\n        }\n\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    # Compute perplexity using the passed dataset\n    ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n\n    # IMPORTANT: Convert to Python float if it's a tensor\n    if torch.is_tensor(ppl_value):\n        ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python 
scalar\n    elif hasattr(ppl_value, \"item\"):\n        ppl_value = ppl_value.item()  # Convert numpy or other array types\n    else:\n        ppl_value = float(ppl_value)  # Ensure it's a float\n\n    # Return only the perplexity value\n    result_queue.put(ppl_value)\n\n    # Clean up\n    del merged_model\n    del merged_tokenizer\n    del dataset_ppl\n    torch.cuda.empty_cache()\n    gc.collect()\n\n\n# Main execution code should be wrapped in this guard\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n\n    if torch.cuda.is_bf16_supported():\n        compute_dtype = torch.bfloat16\n        attn_implementation = \"flash_attention_2\"\n    else:\n        compute_dtype = torch.float16\n        attn_implementation = \"sdpa\"\n\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/mistral-7b-v0.3\",\n        max_seq_length = 2048,\n        dtype = compute_dtype,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        full_finetuning = False,\n        attn_implementation = attn_implementation,\n    )\n\n    EOS_TOKEN = tokenizer.eos_token\n\n    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n    ### Instruction:\n    {}\n\n    ### Input:\n    {}\n\n    ### Response:\n    {}\"\"\"\n\n    # Define helper functions outside of main\n    def formatting_prompts_func(examples):\n        instructions = []\n        inputs = []\n        outputs = []\n        texts = []\n\n        for conversation in examples[\"messages\"]:\n            # Extract user message and assistant response\n            user_message = \"\"\n            assistant_message = \"\"\n\n            for turn in conversation:\n                if turn[\"role\"] == \"user\":\n                    user_message = turn[\"content\"]\n                elif turn[\"role\"] == \"assistant\":\n                    assistant_message = turn[\"content\"]\n\n            # Store intermediate format\n            instruction = \"Complete the statement\"\n            instructions.append(instruction)\n            inputs.append(user_message)\n            outputs.append(assistant_message)\n\n            # Create formatted text\n            text = (\n                alpaca_prompt.format(instruction, user_message, assistant_message)\n                + EOS_TOKEN\n            )\n            texts.append(text)\n\n        return {\n            \"instruction\": instructions,\n            \"input\": inputs,\n            \"output\": outputs,\n            \"text\": texts,\n        }\n\n    dataset_train = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"train\"\n    )\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    add_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"k_proj\",\n            \"q_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"down_proj\",\n            \"up_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0,\n        bias = \"none\",\n        
use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        use_rslora = False,\n        loftq_config = None,\n    )\n\n    from unsloth import is_bfloat16_supported\n\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        train_dataset = dataset_train,\n        dataset_text_field = \"text\",\n        max_seq_length = 2048,\n        dataset_num_proc = 2,\n        packing = False,\n        args = TrainingArguments(\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            warmup_ratio = 0.1,\n            max_steps = 200,\n            learning_rate = 2e-4,\n            fp16 = not is_bfloat16_supported(),\n            bf16 = is_bfloat16_supported(),\n            logging_steps = 50,\n            optim = \"adamw_8bit\",\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"outputs\",\n            report_to = \"none\",\n        ),\n    )\n\n    # run training\n    trainer_stats = trainer.train()\n\n    add_to_comparison(\"Qlora model\", ppl_model(model, tokenizer, dataset_ppl))\n\n    # saving and merging the model to local disk\n    print(\"merge and save to local disk\")\n    model.save_pretrained_merged(\n        save_directory = \"./unsloth_out/merged_mistral_text_model\", tokenizer = tokenizer\n    )\n\n    # print(\"cleaning\")\n    # del model\n    # del tokenizer\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n    # load model from local disk and test\n    print(\"Loading merged model in 4 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_mistral_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = True,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model load 4bit\", ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n    )\n\n    print(\"Computing 8-bit model perplexity in subprocess...\")\n    result_queue = mp.Queue()\n    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    ppl_8bit = result_queue.get()\n    add_to_comparison(\"merged model loaded 8bits\", ppl_8bit)\n\n    print(\"Loading merged model in 16 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_mistral_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = False,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model loaded 16bits\",\n        ppl_model(merged_model, merged_tokenizer, dataset_ppl),\n    )\n\n    print_model_comparison()\n\n    safe_remove_directory(\"./outputs\")\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n    safe_remove_directory(\"./unsloth_out\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merge_model_perplexity_phi_4.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\n        \"text\": texts,\n    }\n\n\ndef load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):\n    \"\"\"Load model and compute perplexity in subprocess\"\"\"\n    from unsloth import FastLanguageModel\n    from unsloth.chat_templates import get_chat_template\n    from tests.utils.perplexity_eval import ppl_model\n\n    # Load model\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_phi4_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = load_in_4bit,\n        load_in_8bit = load_in_8bit,\n    )\n    # Set up tokenizer\n    merged_tokenizer = get_chat_template(\n        merged_tokenizer,\n        chat_template = \"phi-4\",\n    )\n\n    # Load dataset fresh in subprocess\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    # Format the dataset\n    def formatting_prompts_func(examples):\n        convos = examples[\"messages\"]\n        texts = [\n            merged_tokenizer.apply_chat_template(\n                convo, tokenize = False, add_generation_prompt = False\n            )\n            for convo in convos\n        ]\n        return {\"text\": texts}\n\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    # Compute perplexity using the passed dataset\n    ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n\n    # IMPORTANT: Convert to Python float if it's a tensor\n    if torch.is_tensor(ppl_value):\n        ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar\n    elif hasattr(ppl_value, \"item\"):\n        ppl_value = ppl_value.item()  # Convert numpy or other array types\n    else:\n        ppl_value = float(ppl_value)  # Ensure it's a float\n\n    # Return only the perplexity value\n    result_queue.put(ppl_value)\n\n    # Clean up\n    del merged_model\n    del merged_tokenizer\n    del dataset_ppl\n    torch.cuda.empty_cache()\n    gc.collect()\n\n\n# Main execution code should be wrapped in this guard\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n\n    if torch.cuda.is_bf16_supported():\n        compute_dtype = torch.bfloat16\n        attn_implementation = \"flash_attention_2\"\n    else:\n        compute_dtype = torch.float16\n        attn_implementation = \"sdpa\"\n\n    model, tokenizer 
= FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/Phi-4\",\n        max_seq_length = 2048,\n        dtype = compute_dtype,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        full_finetuning = False,\n        attn_implementation = attn_implementation,\n    )\n\n    tokenizer = get_chat_template(\n        tokenizer,\n        chat_template = \"phi-4\",\n    )\n\n    dataset_train = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"train\"\n    )\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    add_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"k_proj\",\n            \"q_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"down_proj\",\n            \"up_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0,\n        bias = \"none\",\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        use_rslora = False,\n        loftq_config = None,\n    )\n\n    from unsloth import is_bfloat16_supported\n\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        train_dataset = dataset_train,\n        dataset_text_field = \"text\",\n        max_seq_length = 2048,\n        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n        dataset_num_proc = 2,\n        packing = False,\n        args = TrainingArguments(\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            warmup_ratio = 0.1,\n            max_steps = 200,\n            learning_rate = 2e-4,\n            fp16 = not is_bfloat16_supported(),\n            bf16 = is_bfloat16_supported(),\n            logging_steps = 50,\n            optim = \"adamw_8bit\",\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"outputs\",\n            report_to = \"none\",\n        ),\n    )\n\n    from unsloth.chat_templates import train_on_responses_only\n\n    trainer = train_on_responses_only(\n        trainer,\n        instruction_part = \"<|im_start|>user<|im_sep|>\\n\\n\",\n        response_part = \"<|im_start|>assistant<|im_sep|>\\n\\n\",\n    )\n\n    # run training\n    trainer_stats = trainer.train()\n\n    add_to_comparison(\"Qlora model\", ppl_model(model, tokenizer, dataset_ppl))\n\n    # saving and merging the model to local disk\n    print(\"merge and save to local disk\")\n    model.save_pretrained_merged(\n        save_directory = \"./unsloth_out/merged_phi4_text_model\", tokenizer = tokenizer\n    )\n\n    # print(\"cleaning\")\n    # del model\n    # del tokenizer\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n    # load model from local disk and test\n    print(\"Loading merged model in 4 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_phi4_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = True,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model load 4bit\", ppl_model(merged_model, merged_tokenizer, 
dataset_ppl)\n    )\n\n    print(\"Computing 8-bit model perplexity in subprocess...\")\n    result_queue = mp.Queue()\n    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    ppl_8bit = result_queue.get()\n    add_to_comparison(\"merged model loaded 8bits\", ppl_8bit)\n\n    print(\"Loading merged model in 16 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_phi4_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = False,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model loaded 16bits\",\n        ppl_model(merged_model, merged_tokenizer, dataset_ppl),\n    )\n\n    print_model_comparison()\n\n    # final cleanup\n    safe_remove_directory(\"./outputs\")\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n    safe_remove_directory(\"./unsloth_out\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\ndef load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):\n    \"\"\"Load model and compute perplexity in subprocess\"\"\"\n    from unsloth import FastLanguageModel\n    from unsloth.chat_templates import get_chat_template\n    from tests.utils.perplexity_eval import ppl_model\n\n    # Load model\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = load_in_4bit,\n        load_in_8bit = load_in_8bit,\n    )\n    # Set up tokenizer\n    merged_tokenizer = get_chat_template(\n        merged_tokenizer,\n        chat_template = \"llama-3.1\",\n    )\n\n    # Load dataset fresh in subprocess\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    # Format the dataset\n    def formatting_prompts_func(examples):\n        convos = examples[\"messages\"]\n        texts = [\n            merged_tokenizer.apply_chat_template(\n                convo, tokenize = False, add_generation_prompt = False\n            )\n            for convo in convos\n        ]\n        return {\"text\": texts}\n\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    # Compute perplexity using the passed dataset\n    ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n\n    # IMPORTANT: Convert to Python float if it's a tensor\n    if torch.is_tensor(ppl_value):\n        ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar\n    elif hasattr(ppl_value, \"item\"):\n        ppl_value = ppl_value.item()  # Convert numpy or other array types\n    else:\n        ppl_value = float(ppl_value)  # Ensure it's a float\n\n    # Return only the perplexity value\n    result_queue.put(ppl_value)\n\n    # Clean up\n    del merged_model\n    del merged_tokenizer\n    del dataset_ppl\n    torch.cuda.empty_cache()\n    gc.collect()\n\n\n# Main execution code should be wrapped in this guard\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n\n    if torch.cuda.is_bf16_supported():\n        compute_dtype = torch.bfloat16\n        attn_implementation = \"flash_attention_2\"\n    else:\n        compute_dtype = torch.float16\n        attn_implementation = \"sdpa\"\n\n    model, tokenizer = 
FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/Llama-3.1-8B-Instruct\",\n        max_seq_length = 2048,\n        dtype = compute_dtype,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        full_finetuning = False,\n        attn_implementation = attn_implementation,\n    )\n\n    tokenizer = get_chat_template(\n        tokenizer,\n        chat_template = \"llama-3.1\",\n    )\n\n    from unsloth.chat_templates import standardize_sharegpt\n\n    dataset_train = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"train\"\n    )\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    print(\"\\n dataset sample [0]\")\n    print(dataset_train[0])\n\n    add_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"k_proj\",\n            \"q_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"down_proj\",\n            \"up_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0,\n        bias = \"none\",\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        use_rslora = False,\n        loftq_config = None,\n    )\n\n    from unsloth import is_bfloat16_supported\n\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        train_dataset = dataset_train,\n        dataset_text_field = \"text\",\n        max_seq_length = 2048,\n        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n        dataset_num_proc = 2,\n        packing = False,\n        args = TrainingArguments(\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            warmup_ratio = 0.1,\n            max_steps = 200,\n            learning_rate = 2e-4,\n            fp16 = not is_bfloat16_supported(),\n            bf16 = is_bfloat16_supported(),\n            logging_steps = 50,\n            optim = \"adamw_8bit\",\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"outputs\",\n            report_to = \"none\",\n        ),\n    )\n\n    from unsloth.chat_templates import train_on_responses_only\n\n    trainer = train_on_responses_only(\n        trainer,\n        instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n        response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n    )\n\n    tokenizer.decode(trainer.train_dataset[0][\"input_ids\"])\n\n    # run training\n    trainer_stats = trainer.train()\n\n    add_to_comparison(\"Qlora model\", ppl_model(model, tokenizer, dataset_ppl))\n\n    # saving and merging the model to local disk\n    print(\"merge and save to local disk\")\n    model.save_pretrained_merged(\n        save_directory = \"./unsloth_out/merged_llama_text_model\", tokenizer = tokenizer\n    )\n\n    # print(\"cleaning\")\n    # del model\n    # del tokenizer\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n    # load model from local disk and test\n    print(\"Loading merged model in 4 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = 
\"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = True,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model load 4bit\", ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n    )\n\n    print(\"Computing 8-bit model perplexity in subprocess...\")\n    result_queue = mp.Queue()\n    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    ppl_8bit = result_queue.get()\n    add_to_comparison(\"merged model loaded 8bits\", ppl_8bit)\n\n    print(\"Loading merged model in 16 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_llama_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = False,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model loaded 16bits\",\n        ppl_model(merged_model, merged_tokenizer, dataset_ppl),\n    )\n\n    print_model_comparison()\n\n    # final cleanup\n    safe_remove_directory(\"./outputs\")\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n    safe_remove_directory(\"./unsloth_out\")\n"
  },
  {
    "path": "tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\nalpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}\"\"\"\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    instructions = []\n    inputs = []\n    outputs = []\n    texts = []\n\n    for conversation in examples[\"messages\"]:\n        # Extract user message and assistant response\n        user_message = \"\"\n        assistant_message = \"\"\n\n        for turn in conversation:\n            if turn[\"role\"] == \"user\":\n                user_message = turn[\"content\"]\n            elif turn[\"role\"] == \"assistant\":\n                assistant_message = turn[\"content\"]\n\n        # Store intermediate format\n        instruction = \"Complete the statement\"\n        instructions.append(instruction)\n        inputs.append(user_message)\n        outputs.append(assistant_message)\n\n        # Create formatted text\n        text = alpaca_prompt.format(instruction, user_message, assistant_message)\n        texts.append(text)\n\n    return {\n        \"instruction\": instructions,\n        \"input\": inputs,\n        \"output\": outputs,\n        \"text\": texts,\n    }\n\n\ndef load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):\n    \"\"\"Load model and compute perplexity in subprocess\"\"\"\n    from unsloth import FastLanguageModel\n    from tests.utils.perplexity_eval import ppl_model\n\n    # Load model\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_qwen_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = load_in_4bit,\n        load_in_8bit = load_in_8bit,\n    )\n    # Set up tokenizer\n    # merged_tokenizer = get_chat_template(\n    #     merged_tokenizer,\n    #     chat_template=\"llama-3.1\",\n    # )\n\n    # Load dataset fresh in subprocess\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n\n    ### Instruction:\n    {}\n\n    ### Input:\n    {}\n\n    ### Response:\n    {}\"\"\"\n\n    def formatting_prompts_func(examples):\n        instructions = []\n        inputs = []\n        outputs = []\n        texts = []\n\n        for conversation in examples[\"messages\"]:\n            # Extract user message and assistant response\n            user_message = \"\"\n            assistant_message = \"\"\n\n            for turn in conversation:\n                if turn[\"role\"] == \"user\":\n                    user_message = turn[\"content\"]\n                elif turn[\"role\"] == \"assistant\":\n                    assistant_message = turn[\"content\"]\n\n            # Store intermediate format\n            instruction = \"Complete the statement\"\n            instructions.append(instruction)\n            inputs.append(user_message)\n            outputs.append(assistant_message)\n\n            # Create formatted text\n            text = alpaca_prompt.format(instruction, user_message, assistant_message)\n            texts.append(text)\n\n        return {\n            \"instruction\": instructions,\n            \"input\": inputs,\n            \"output\": outputs,\n            \"text\": texts,\n        }\n\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    # Compute perplexity using the passed dataset\n    ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n\n    # IMPORTANT: Convert to Python float if it's a tensor\n    if torch.is_tensor(ppl_value):\n        ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar\n    elif hasattr(ppl_value, \"item\"):\n        ppl_value = ppl_value.item()  # Convert numpy or other array types\n    else:\n        ppl_value = float(ppl_value)  # Ensure it's a float\n\n    # Return only the perplexity value\n    result_queue.put(ppl_value)\n\n    # Clean up\n    # del merged_model\n    # del merged_tokenizer\n    # del dataset_ppl\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n\n# Main execution code should be wrapped in this guard\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n\n    if torch.cuda.is_bf16_supported():\n        compute_dtype = torch.bfloat16\n        attn_implementation = \"flash_attention_2\"\n    else:\n        compute_dtype = torch.float16\n        attn_implementation = \"sdpa\"\n\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/Qwen2.5-7B-Instruct\",\n        max_seq_length = 2048,\n        dtype = compute_dtype,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        full_finetuning = False,\n        attn_implementation = attn_implementation,\n    )\n\n    dataset_train = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"train\"\n    )\n    dataset_ppl = load_dataset(\n        \"allenai/openassistant-guanaco-reformatted\", split = \"eval\"\n    )\n\n    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)\n    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\n    add_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"k_proj\",\n            \"q_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"down_proj\",\n            
\"up_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0,\n        bias = \"none\",\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        use_rslora = False,\n        loftq_config = None,\n    )\n\n    from unsloth import is_bfloat16_supported\n\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        train_dataset = dataset_train,\n        dataset_text_field = \"text\",\n        max_seq_length = 2048,\n        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n        dataset_num_proc = 2,\n        packing = False,\n        args = TrainingArguments(\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            warmup_ratio = 0.1,\n            max_steps = 200,\n            learning_rate = 2e-4,\n            fp16 = not is_bfloat16_supported(),\n            bf16 = is_bfloat16_supported(),\n            logging_steps = 50,\n            optim = \"adamw_8bit\",\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"outputs\",\n            report_to = \"none\",\n        ),\n    )\n\n    # run training\n    trainer_stats = trainer.train()\n\n    add_to_comparison(\"Qlora model\", ppl_model(model, tokenizer, dataset_ppl))\n\n    # saving and merging the model to local disk\n    print(\"merge and save to local disk\")\n    model.save_pretrained_merged(\n        save_directory = \"./unsloth_out/merged_qwen_text_model\", tokenizer = tokenizer\n    )\n\n    # print(\"cleaning\")\n    # del model\n    # del tokenizer\n    # torch.cuda.empty_cache()\n    # gc.collect()\n\n    # load model from local disk and test\n    print(\"Loading merged model in 4 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_qwen_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = True,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model load 4bit\", ppl_model(merged_model, merged_tokenizer, dataset_ppl)\n    )\n\n    print(\"Computing 8-bit model perplexity in subprocess...\")\n    result_queue = mp.Queue()\n    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    ppl_8bit = result_queue.get()\n    add_to_comparison(\"merged model loaded 8bits\", ppl_8bit)\n\n    print(\"Loading merged model in 16 bit for perplexity test\")\n    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./unsloth_out/merged_qwen_text_model\",\n        max_seq_length = 2048,\n        load_in_4bit = False,\n        load_in_8bit = False,\n    )\n\n    add_to_comparison(\n        \"merged model loaded 16bits\",\n        ppl_model(merged_model, merged_tokenizer, dataset_ppl),\n    )\n\n    print_model_comparison()\n\n    safe_remove_directory(\"./outputs\")\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n    safe_remove_directory(\"./unsloth_out\")\n"
  },
  {
    "path": "tests/saving/language_models/test_push_to_hub_merged.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\nimport os\nfrom huggingface_hub import HfFileSystem, hf_hub_download\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\nif torch.cuda.is_bf16_supported():\n    compute_dtype = torch.bfloat16\n    attn_implementation = \"flash_attention_2\"\nelse:\n    compute_dtype = torch.float16\n    attn_implementation = \"sdpa\"\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/Llama-3.2-1B-Instruct\",\n    max_seq_length = 2048,\n    dtype = compute_dtype,\n    load_in_4bit = True,\n    load_in_8bit = False,\n    full_finetuning = False,\n    attn_implementation = attn_implementation,\n)\n\ntokenizer = get_chat_template(\n    tokenizer,\n    chat_template = \"llama-3.1\",\n)\n\nfrom unsloth.chat_templates import standardize_sharegpt\n\ndataset_train = load_dataset(\"allenai/openassistant-guanaco-reformatted\", split = \"train\")\ndataset_ppl = load_dataset(\"allenai/openassistant-guanaco-reformatted\", split = \"eval\")\n\ndataset_train = dataset_train.map(formatting_prompts_func, batched = True)\ndataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\nadd_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 16,\n    target_modules = [\n        \"k_proj\",\n        \"q_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"down_proj\",\n        \"up_proj\",\n    ],\n    lora_alpha = 16,\n    lora_dropout = 0,\n    bias = \"none\",\n    use_gradient_checkpointing = \"unsloth\",\n    random_state = 3407,\n    use_rslora = False,\n    loftq_config = None,\n)\n\nfrom unsloth import is_bfloat16_supported\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    train_dataset = dataset_train,\n    dataset_text_field = \"text\",\n    max_seq_length = 2048,\n    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n    dataset_num_proc = 2,\n    packing = False,\n    args = TrainingArguments(\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        warmup_ratio = 0.1,\n        max_steps = 30,\n        learning_rate = 2e-4,\n        fp16 = not is_bfloat16_supported(),\n        bf16 = is_bfloat16_supported(),\n        logging_steps = 50,\n        optim = \"adamw_8bit\",\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"outputs\",\n        report_to = \"none\",\n    
),\n)\n\nfrom unsloth.chat_templates import train_on_responses_only\n\ntrainer = train_on_responses_only(\n    trainer,\n    instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n    response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n)\n\n# run training\ntrainer_stats = trainer.train()\n\n\n# saving and merging the model to local disk\nhf_username = os.environ.get(\"HF_USER\", \"\")\nif not hf_username:\n    hf_username = input(\"Please enter your Hugging Face username: \").strip()\n    os.environ[\"HF_USER\"] = hf_username\n\nhf_token = os.environ.get(\"HF_TOKEN\", \"\")\nif not hf_token:\n    hf_token = input(\"Please enter your Hugging Face token: \").strip()\n    os.environ[\"HF_TOKEN\"] = hf_token\n\n\nrepo_name = f\"{hf_username}/merged_llama_text_model\"\nsuccess = {\n    \"upload\": False,\n    \"download\": False,\n}\n\n# Stage 1: Upload model to Hub\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== UPLOADING MODEL TO HUB ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)\n    success[\"upload\"] = True\n    print(\"✅ Model uploaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Failed to upload model: {e}\")\n    raise Exception(\"Model upload failed.\")\n\n# Stage 2: Test downloading the model (even if cached)\nsafe_remove_directory(f\"./{hf_username}\")\n\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== TESTING MODEL DOWNLOAD ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    # Force download even if cached\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        f\"{hf_username}/merged_llama_text_model\"\n    )\n    success[\"download\"] = True\n    print(\"✅ Model downloaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Download failed: {e}\")\n    raise Exception(\"Model download failed.\")\n\n# Final report\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== VALIDATION REPORT ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\nfor stage, passed in success.items():\n    status = \"✓\" if passed else \"✗\"\n    print(f\"{status} {stage.replace('_', ' ').title()}\")\nprint(\"\\n\" + \"=\" * 80)\n\nif all(success.values()):\n    print(\"\\n🎉 All stages completed successfully!\")\nelse:\n    raise Exception(\"Validation failed for one or more stages.\")\n\n# final cleanup\nsafe_remove_directory(\"./outputs\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\n"
  },
  {
    "path": "tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py",
    "content": "from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator\nfrom unsloth.chat_templates import get_chat_template\nfrom trl import SFTTrainer, SFTConfig\nfrom transformers import (\n    DataCollatorForLanguageModeling,\n    DataCollatorForSeq2Seq,\n    TrainingArguments,\n)\nfrom datasets import load_dataset, Dataset\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nimport multiprocessing as mp\nfrom multiprocessing import Process, Queue\nimport gc\nimport os\nfrom huggingface_hub import HfFileSystem, hf_hub_download\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.perplexity_eval import (\n    ppl_model,\n    add_to_comparison,\n    print_model_comparison,\n)\n\n\n# Define helper functions outside of main\ndef formatting_prompts_func(examples):\n    convos = examples[\"messages\"]\n    texts = [\n        tokenizer.apply_chat_template(\n            convo, tokenize = False, add_generation_prompt = False\n        )\n        for convo in convos\n    ]\n    return {\"text\": texts}\n\n\nif torch.cuda.is_bf16_supported():\n    compute_dtype = torch.bfloat16\n    attn_implementation = \"flash_attention_2\"\nelse:\n    compute_dtype = torch.float16\n    attn_implementation = \"sdpa\"\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/Llama-3.1-8B-Instruct\",\n    max_seq_length = 2048,\n    dtype = compute_dtype,\n    load_in_4bit = True,\n    load_in_8bit = False,\n    full_finetuning = False,\n    attn_implementation = attn_implementation,\n)\n\ntokenizer = get_chat_template(\n    tokenizer,\n    chat_template = \"llama-3.1\",\n)\n\nfrom unsloth.chat_templates import standardize_sharegpt\n\ndataset_train = load_dataset(\"allenai/openassistant-guanaco-reformatted\", split = \"train\")\ndataset_ppl = load_dataset(\"allenai/openassistant-guanaco-reformatted\", split = \"eval\")\n\ndataset_train = dataset_train.map(formatting_prompts_func, batched = True)\ndataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)\n\nadd_to_comparison(\"Base model 4 bits\", ppl_model(model, tokenizer, dataset_ppl))\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 16,\n    target_modules = [\n        \"k_proj\",\n        \"q_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"down_proj\",\n        \"up_proj\",\n    ],\n    lora_alpha = 16,\n    lora_dropout = 0,\n    bias = \"none\",\n    use_gradient_checkpointing = \"unsloth\",\n    random_state = 3407,\n    use_rslora = False,\n    loftq_config = None,\n)\n\nfrom unsloth import is_bfloat16_supported\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    train_dataset = dataset_train,\n    dataset_text_field = \"text\",\n    max_seq_length = 2048,\n    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n    dataset_num_proc = 2,\n    packing = False,\n    args = TrainingArguments(\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        warmup_ratio = 0.1,\n        max_steps = 30,\n        learning_rate = 2e-4,\n        fp16 = not is_bfloat16_supported(),\n        bf16 = is_bfloat16_supported(),\n        logging_steps = 50,\n        optim = \"adamw_8bit\",\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"outputs\",\n        report_to = \"none\",\n    
),\n)\n\nfrom unsloth.chat_templates import train_on_responses_only\n\ntrainer = train_on_responses_only(\n    trainer,\n    instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n    response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n)\n\n# run training\ntrainer_stats = trainer.train()\n\n\n# saving and merging the model to local disk\nhf_username = os.environ.get(\"HF_USER\", \"\")\nif not hf_username:\n    hf_username = input(\"Please enter your Hugging Face username: \").strip()\n    os.environ[\"HF_USER\"] = hf_username\n\nhf_token = os.environ.get(\"HF_TOKEN\", \"\")\nif not hf_token:\n    hf_token = input(\"Please enter your Hugging Face token: \").strip()\n    os.environ[\"HF_TOKEN\"] = hf_token\n\n\nrepo_name = f\"{hf_username}/merged_llama_text_model\"\nsuccess = {\n    \"upload\": False,\n    \"safetensors_check\": False,\n    \"download\": False,\n}\n\n# Stage 1: Upload model to Hub\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== UPLOADING MODEL TO HUB ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)\n    success[\"upload\"] = True\n    print(\"✅ Model uploaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Failed to upload model: {e}\")\n    raise Exception(\"Model upload failed.\")\n\n# Stage 2: Verify safetensors.index.json exists\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== VERIFYING REPO CONTENTS ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    fs = HfFileSystem(token = hf_token)\n    file_list = fs.ls(repo_name, detail = True)\n    safetensors_found = any(\n        file[\"name\"].endswith(\"model.safetensors.index.json\") for file in file_list\n    )\n    if safetensors_found:\n        success[\"safetensors_check\"] = True\n        print(\"✅ model.safetensors.index.json found in repo!\")\n    else:\n        raise Exception(\"model.safetensors.index.json not found in repo.\")\nexcept Exception as e:\n    print(f\"❌ Verification failed: {e}\")\n    raise Exception(\"Repo verification failed.\")\n\n# Stage 3: Test downloading the model (even if cached)\nsafe_remove_directory(f\"./{hf_username}\")\n\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== TESTING MODEL DOWNLOAD ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    # Force download even if cached\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        f\"{hf_username}/merged_llama_text_model\"\n    )\n    success[\"download\"] = True\n    print(\"✅ Model downloaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Download failed: {e}\")\n    raise Exception(\"Model download failed.\")\n\n# Final report\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== VALIDATION REPORT ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\nfor stage, passed in success.items():\n    status = \"✓\" if passed else \"✗\"\n    print(f\"{status} {stage.replace('_', ' ').title()}\")\nprint(\"\\n\" + \"=\" * 80)\n\nif all(success.values()):\n    print(\"\\n🎉 All stages completed successfully!\")\nelse:\n    raise Exception(\"Validation failed for one or more stages.\")\n\n# final cleanup\nsafe_remove_directory(\"./outputs\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\n"
  },
  {
    "path": "tests/saving/language_models/test_save_merged_grpo_model.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"test_Llama3_1_(3B)_GRPO_LoRA (1).ipynb\n\n### Unsloth\n\n\"\"\"\n\nfrom unsloth import FastLanguageModel\nimport torch\nimport sys\nfrom pathlib import Path\nimport multiprocessing as mp\nimport gc\nfrom multiprocessing import Queue\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.aime_eval import evaluate_model_aime, compare_aime_results\n\n\nmax_seq_length = 2048  # Can increase for longer reasoning traces\nlora_rank = 64  # Larger rank = smarter, but slower\n\n\ndef evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):\n    from unsloth import FastLanguageModel\n    from tests.utils.aime_eval import evaluate_model_aime\n\n    max_seq_length = 2048  # Can increase for longer reasoning traces\n    lora_rank = 64  # Larger rank = smarter, but slower\n\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"./final_merged_model\",\n        max_seq_length = max_seq_length,\n        load_in_4bit = True,  # False for LoRA 16bit\n        fast_inference = True,  # Enable vLLM fast inference\n        max_lora_rank = lora_rank,\n        gpu_memory_utilization = 0.8,  # Reduce if out of memory\n    )\n\n    print(f\"\\n{'='*60}\")\n    if load_in_4bit:\n        print(\"🔍 EVALUATION Merged model: 4 bits load\")\n        model_type = \"merged_model_4bits\"\n    elif load_in_8bit:\n        print(\"🔍 EVALUATION Merged model: 8 bits load\")\n        model_type = \"merged_model_8bits\"\n    else:\n        print(\"🔍 EVALUATION Merged model: 16 bits load\")\n        model_type = \"merged_model_16bits\"\n    print(f\"{'='*60}\")\n\n    evaluate_model_aime(\n        model = model,\n        tokenizer = tokenizer,\n        model_type = model_type,\n        temperature = 0.3,\n        n_sampling = 8,\n        max_tokens = 32768,\n        top_p = 0.95,\n        seed = 0,\n    )\n\n    result_queue.put(results)\n\n    del model\n    del tokenizer\n    torch.cuda.empty_cache()\n    gc.collect()\n\n\n# Main execution code should be wrapped in this guard\ndef training_run(result_queue):\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"meta-llama/Llama-3.2-3B-Instruct\",\n        max_seq_length = max_seq_length,\n        load_in_4bit = False,  # False for LoRA 16bit\n        fast_inference = True,  # Enable vLLM fast inference\n        max_lora_rank = lora_rank,\n        gpu_memory_utilization = 0.8,  # Reduce if out of memory\n    )\n\n    \"\"\"### Helper Functions\n    <a name=\"Data\"></a>\n\n#### Helper functions - Data Prep\n    \"\"\"\n\n    import re\n    import json\n\n    reasoning_start = \"<reasoning>\"\n    reasoning_end = \"</reasoning>\"\n    solution_start = \"<answer>\"\n    solution_end = \"</answer>\"\n\n    def extract_hash_answer(text):\n        \"\"\"Extract answer from GSM8K format\"\"\"\n        if \"####\" not in text:\n            return None\n        return text.split(\"####\")[1].strip()\n\n    def prepare_gsm8k_dataset(dataset):\n        \"\"\"Format GSM8K dataset for training\"\"\"\n        reasoning_start = \"<reasoning>\"\n        reasoning_end = \"</reasoning>\"\n        solution_start = \"<answer>\"\n        solution_end = \"</answer>\"\n\n        system_prompt = (\n            f\"You are given a problem. Think about the problem and reason step by step. 
\"\n            f\"Place your thinking process between {reasoning_start} and {reasoning_end}. \"\n            f\"Then, provide your final numerical solution between {solution_start}{solution_end}\"\n        )\n\n        def format_gsm8k(example):\n            return {\n                \"prompt\": [\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": example[\"question\"]},\n                ],\n                \"answer\": extract_hash_answer(example[\"answer\"]),\n            }\n\n        return dataset.map(format_gsm8k)\n\n    def prepare_limo_dataset(dataset):\n        \"\"\"Format LIMO dataset for SFT training\"\"\"\n        if dataset is None:\n            return None\n\n        system_prompt = \"\"\"You are a helpful reasoning assistant. When given a problem, think through it step by step and provide your answer in the following format:\n\n    <reasoning>\n    [Your detailed step-by-step reasoning and solution process]\n    </reasoning>\n    <answer>\n    [Your final numerical answer]\n    </answer>\"\"\"\n\n        def format_limo(example):\n            # Create the assistant response\n            assistant_response = f\"<reasoning>\\n{example['solution']}\\n</reasoning>\\n<answer>\\n{example['answer']}\\n</answer>\"\n\n            # Return a DICTIONARY with the conversation in a field\n            return {\n                \"prompt\": [  # ← This is the key change - wrap in a dict\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": example[\"question\"]},\n                    {\"role\": \"assistant\", \"content\": assistant_response},\n                ]\n            }\n\n        return dataset.map(format_limo)\n\n    print(\"\\n✅ Dataset preparation functions defined!\")\n\n    \"\"\"#### Helper functions - Evaluation\"\"\"\n\n    def get_max_prompt_length(dataset, tokenizer):\n        \"\"\"Calculate maximum and average prompt length in dataset\"\"\"\n        print(\"Analyzing prompt lengths...\")\n\n        lengths = dataset.map(\n            lambda x: {\n                \"tokens\": tokenizer.apply_chat_template(\n                    x[\"prompt\"], add_generation_prompt = True, tokenize = True\n                )\n            },\n            batched = True,\n        ).map(lambda x: {\"length\": len(x[\"tokens\"])})[\"length\"]\n\n        max_length = max(lengths)\n        avg_length = sum(lengths) / len(lengths)\n        min_length = min(lengths)\n\n        print(\n            f\"Prompt lengths - Min: {min_length}, Max: {max_length}, Avg: {avg_length:.1f}\"\n        )\n        return max_length, avg_length\n\n    def extract_unsloth_answer(text, start_tag = \"<SOLUTION>\", end_tag = \"</SOLUTION>\"):\n        \"\"\"Extract answer from Unsloth SOLUTION tags\"\"\"\n        pattern = re.escape(start_tag) + r\"(.*?)\" + re.escape(end_tag)\n        matches = re.findall(pattern, text, re.DOTALL)\n\n        if matches:\n            answer = matches[-1]  # Get the last match\n            answer = re.sub(r\"[%$,]\", \"\", answer).strip()\n            return answer\n        return \"\"\n\n    def find_number(search_string):\n        \"\"\"Find the last number in a string\"\"\"\n        numbers = re.compile(\n            r\"-?[\\d,]*\\.?\\d+\",\n            re.MULTILINE | re.DOTALL | re.IGNORECASE,\n        ).findall(search_string)\n\n        if numbers:\n            return numbers[-1].replace(\",\", \"\").strip()\n        return \"\"\n\n    
def remove_symbols(x: str) -> str:\n        \"\"\"Remove commas, percent and dollar symbols\"\"\"\n        if not x:\n            return \"\"\n        return x.replace(\",\", \"\").replace(\"%\", \"\").replace(\"$\", \"\").strip()\n\n    def get_num_tokens(text, tokenizer_instance):\n        \"\"\"Count tokens in text\"\"\"\n        if not text:\n            return 0\n        encoding = tokenizer_instance(text, return_tensors = \"pt\")\n        return len(encoding[\"input_ids\"][0])\n\n    def check_format_compliance(text, format_type = \"unsloth\"):\n        \"\"\"Check if response follows expected format\"\"\"\n        if format_type == \"unsloth\":\n            reasoning_start = \"<start_reasoning>\"\n            reasoning_end = \"<end_reasoning>\"\n            solution_start = \"<SOLUTION>\"\n            solution_end = \"</SOLUTION>\"\n\n            pattern = (\n                rf\"^[\\s]*{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?\"\n                rf\"{re.escape(solution_start)}.+?{re.escape(solution_end)}[\\s]*$\"\n            )\n        else:\n            return False\n\n        return bool(re.match(pattern, text.strip(), re.DOTALL))\n\n    def normalize_answer(answer):\n        \"\"\"Normalize answer for comparison\"\"\"\n        if not answer:\n            return \"\"\n\n        normalized = remove_symbols(str(answer))\n\n        try:\n            float_val = float(normalized)\n            if float_val.is_integer():\n                return str(int(float_val))\n            else:\n                return str(float_val)\n        except (ValueError, TypeError):\n            return normalized\n\n    def evaluate_answer_correctness(extracted_answer, ground_truth):\n        \"\"\"Evaluate answer correctness with multiple criteria\"\"\"\n        if not extracted_answer or not ground_truth:\n            return False, False, 0.0\n\n        norm_extracted = normalize_answer(extracted_answer)\n        norm_ground_truth = normalize_answer(ground_truth)\n\n        if norm_extracted == norm_ground_truth:\n            return True, True, 1.0\n\n        try:\n            extracted_num = float(norm_extracted)\n            ground_truth_num = float(norm_ground_truth)\n\n            if ground_truth_num != 0:\n                relative_error = abs(extracted_num - ground_truth_num) / abs(\n                    ground_truth_num\n                )\n\n                if relative_error < 0.01:\n                    return True, True, 0.9\n                elif relative_error < 0.05:\n                    return False, True, 0.7\n                elif relative_error < 0.10:\n                    return False, True, 0.5\n            else:\n                if extracted_num == 0:\n                    return True, True, 1.0\n                elif abs(extracted_num) < 0.01:\n                    return False, True, 0.7\n\n        except (ValueError, TypeError):\n            if norm_extracted.lower() == norm_ground_truth.lower():\n                return True, True, 1.0\n\n        return False, False, 0.0\n\n    \"\"\"#### Reward Functions for GRPO\"\"\"\n\n    def match_format_exactly(completions, **kwargs):\n        \"\"\"Reward function for exact format matching\"\"\"\n        reasoning_start = \"<reasoning>\"\n        reasoning_end = \"</reasoning>\"\n        solution_start = \"<answer>\"\n        solution_end = \"</answer>\"\n\n        pattern = (\n            rf\"^[\\s]*{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?\"\n            
rf\"{re.escape(solution_start)}.+?{re.escape(solution_end)}[\\s]*$\"\n        )\n\n        responses = [completion[0][\"content\"] for completion in completions]\n        rewards = [\n            3.0 if re.match(pattern, response, re.DOTALL) else 0.0\n            for response in responses\n        ]\n        return rewards\n\n    def match_format_approximately(completions, **kwargs):\n        \"\"\"Reward function for approximate format matching\"\"\"\n        reasoning_start = \"<reasoning>\"\n        reasoning_end = \"</reasoning>\"\n        solution_start = \"<answerr>\"\n        solution_end = \"</answer>\"\n\n        scores = []\n        for completion in completions:\n            score = 0\n            response = completion[0][\"content\"]\n            score += 0.5 if response.count(reasoning_start) == 1 else -1.0\n            score += 0.5 if response.count(reasoning_end) == 1 else -1.0\n            score += 0.5 if response.count(solution_start) == 1 else -1.0\n            score += 0.5 if response.count(solution_end) == 1 else -1.0\n            scores.append(score)\n        return scores\n\n    def check_answer_correctness(prompts, completions, answer, **kwargs):\n        \"\"\"Reward function for answer correctness\"\"\"\n\n        def extract_solution_answer(text):\n            pattern = r\"<answer>(.*?)</answer>\"\n            match = re.search(pattern, text, re.DOTALL)\n            if match:\n                return re.sub(r\"[%$,]\", \"\", match.group(1)).strip()\n            return \"\"\n\n        responses = [completion[0][\"content\"] for completion in completions]\n        extracted_responses = [extract_solution_answer(r) for r in responses]\n\n        scores = []\n        for guess, true_answer in zip(extracted_responses, answer):\n            score = 0\n            if not guess:\n                scores.append(0)\n                continue\n\n            if guess == true_answer:\n                score += 3.0\n            elif guess.strip() == true_answer.strip():\n                score += 1.5\n            else:\n                try:\n                    ratio = float(guess) / float(true_answer)\n                    if 0.9 <= ratio <= 1.1:\n                        score += 1.0\n                    elif 0.8 <= ratio <= 1.2:\n                        score += 0.5\n                    else:\n                        score -= 1.5\n                except:\n                    score -= 1.5\n            scores.append(score)\n        return scores\n\n    print(\"✅ Reward functions defined!\")\n\n    \"\"\"#### Main Evaluation Function\"\"\"\n\n    import gc\n\n    \"\"\"#### Comparison and Memory Management\"\"\"\n\n    def compare_model_results(all_results):\n        \"\"\"Generate comprehensive comparison of multiple model results\"\"\"\n        print(f\"\\n{'='*80}\")\n        print(\"COMPREHENSIVE MODEL COMPARISON\")\n        print(f\"{'='*80}\")\n\n        # Main table\n        print(\n            f\"{'Model':<15} {'Format %':<10} {'Exact %':<10} {'Plausible %':<12} {'Confidence':<12}\"\n        )\n        print(\"-\" * 80)\n\n        for result in all_results:\n            print(\n                f\"{result['model_type']:<15} \"\n                f\"{result['correct_format_pct']:<10.1f} \"\n                f\"{result['exact_match_pct']:<10.1f} \"\n                f\"{result['plausible_match_pct']:<12.1f} \"\n                f\"{result['avg_confidence']:<12.3f}\"\n            )\n\n        # Improvement analysis\n        if len(all_results) > 1:\n            print(f\"\\n{'='*50}\")\n 
           print(\"IMPROVEMENT ANALYSIS\")\n            print(f\"{'='*50}\")\n\n            base_result = all_results[0]\n            for result in all_results[1:]:\n                print(f\"\\n{result['model_type']} vs {base_result['model_type']}:\")\n                format_improvement = (\n                    result[\"correct_format_pct\"] - base_result[\"correct_format_pct\"]\n                )\n                exact_improvement = (\n                    result[\"exact_match_pct\"] - base_result[\"exact_match_pct\"]\n                )\n                plausible_improvement = (\n                    result[\"plausible_match_pct\"] - base_result[\"plausible_match_pct\"]\n                )\n\n                print(f\"  Format compliance: {format_improvement:+.1f}%\")\n                print(f\"  Exact matches:     {exact_improvement:+.1f}%\")\n                print(f\"  Plausible matches: {plausible_improvement:+.1f}%\")\n\n        # Save comparison\n        comparison_data = {\n            \"summary\": all_results,\n            \"best_model\": max(all_results, key = lambda x: x[\"exact_match_pct\"]),\n        }\n\n        with open(\"model_comparison_comprehensive.json\", \"w\") as f:\n            json.dump(comparison_data, f, indent = 4)\n\n        print(\n            f\"\\nBest performing model: {comparison_data['best_model']['model_type']} \"\n            f\"({comparison_data['best_model']['exact_match_pct']:.1f}% exact matches)\"\n        )\n\n    def cleanup_memory():\n        \"\"\"Comprehensive memory cleanup\"\"\"\n        print(\"🧹 Cleaning up GPU memory...\")\n        for _ in range(10):\n            torch.cuda.empty_cache()\n            gc.collect()\n\n        if torch.cuda.is_available():\n            allocated = torch.cuda.memory_allocated() / 1024**3\n            reserved = torch.cuda.memory_reserved() / 1024**3\n            print(\n                f\"GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB\"\n            )\n\n    \"\"\"#### Data Loading and Preparation\"\"\"\n\n    from datasets import load_dataset\n\n    # Load GSM8K\n    gsm8k_dataset = load_dataset(\"openai/gsm8k\", \"main\", split = \"train\")\n\n    # Load LIMO (adjust this based on your access method)\n    limo_train = load_dataset(\"GAIR/LIMO\", split = \"train\")\n\n    # Prepare datasets\n    gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)\n    limo_train = prepare_limo_dataset(limo_train)\n\n    print(f\"  GSM8K train: {len(gsm8k_train)}\")\n    print(f\"  LIMO train:  {len(limo_train) if limo_train else 0}\")\n\n    # Store results\n    all_results = []\n\n    # Single temperature evaluation on combined dataset\n    results = evaluate_model_aime(\n        model = model,\n        tokenizer = tokenizer,\n        model_type = \"base\",\n        temperature = 0.3,\n        n_sampling = 8,\n        max_tokens = 32768,\n        top_p = 0.95,\n        seed = 0,\n    )\n\n    from unsloth.chat_templates import get_chat_template\n\n    tokenizer = get_chat_template(\n        tokenizer,\n        chat_template = \"llama-3.1\",\n    )\n\n    def formatting_prompts_func(examples):\n        convos = examples[\"prompt\"]\n        texts = [\n            tokenizer.apply_chat_template(\n                convo, tokenize = False, add_generation_prompt = False\n            )\n            for convo in convos\n        ]\n        return {\n            \"text\": texts,\n        }\n\n    limo_train = limo_train.map(\n        formatting_prompts_func,\n        batched = True,\n    )\n\n    from trl import 
SFTTrainer\n    from transformers import DataCollatorForSeq2Seq, TrainingArguments\n    from unsloth import is_bfloat16_supported\n\n    print(f\"\\n{'*'*60}\")\n    print(\"🎯 STAGE 1: Qlora Fine-Tuning on LIMO\")\n    print(f\"{'*'*60}\")\n\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n        target_modules = [\n            \"q_proj\",\n            \"k_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"up_proj\",\n            \"down_proj\",\n        ],  # Remove QKVO if out of memory\n        lora_alpha = lora_rank,\n        use_gradient_checkpointing = \"unsloth\",  # Enable long context finetuning\n        random_state = 3407,\n    )\n\n    if limo_train is not None:\n        trainer = SFTTrainer(\n            model = model,\n            tokenizer = tokenizer,\n            train_dataset = limo_train,\n            dataset_text_field = \"text\",\n            max_seq_length = max_seq_length,\n            data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n            dataset_num_proc = 2,\n            packing = False,  # Can make training 5x faster for short sequences.\n            args = TrainingArguments(\n                per_device_train_batch_size = 2,\n                gradient_accumulation_steps = 4,\n                warmup_steps = 5,\n                num_train_epochs = 1,  # Set this for 1 full training run.\n                # max_steps = 60,\n                learning_rate = 2e-4,\n                fp16 = not is_bfloat16_supported(),\n                bf16 = is_bfloat16_supported(),\n                logging_steps = 1,\n                optim = \"adamw_8bit\",\n                weight_decay = 0.01,\n                lr_scheduler_type = \"linear\",\n                seed = 3407,\n                output_dir = \"outputs\",\n                report_to = \"none\",  # Use this for WandB etc\n            ),\n        )\n\n        from unsloth.chat_templates import train_on_responses_only\n\n        trainer = train_on_responses_only(\n            trainer,\n            instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n            response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n        )\n\n        # Train\n        print(f\"🚂 Starting SFT training on {len(limo_train)} examples...\")\n        trainer.train()\n\n        # Save checkpoint\n        model.save_pretrained(\"qlora_checkpoint\")\n        tokenizer.save_pretrained(\"qlora_checkpoint\")\n        print(\"💾 Qlora checkpoint saved!\")\n\n        # Cleanup\n        del trainer\n        cleanup_memory()\n\n        print(\"✅ Qlora training completed!\")\n    else:\n        print(\"⚠️ Skipping Qlora training - no LIMO dataset available\")\n\n    # Cleanup\n    cleanup_memory()\n\n    global PRINTED_TIMES\n    PRINTED_TIMES = 0\n    global PRINT_EVERY_STEPS\n    PRINT_EVERY_STEPS = 5\n\n    match_numbers = re.compile(\n        solution_start + r\".*?([\\d\\.\\,]{1,})\", flags = re.MULTILINE | re.DOTALL\n    )\n\n    def check_numbers(prompts, completions, answer, **kwargs):\n        question = prompts[0][-1][\"content\"]\n        responses = [completion[0][\"content\"] for completion in completions]\n\n        extracted_responses = [\n            guess.group(1) if (guess := match_numbers.search(r)) is not None else None\n            for r in responses\n        ]\n\n        scores = []\n        # Print only every few steps\n        global 
PRINTED_TIMES\n        global PRINT_EVERY_STEPS\n        if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:\n            print(\n                \"*\" * 20,\n                f\"Question:\\n{question}\",\n                f\"\\nAnswer:\\n{answer[0]}\",\n                f\"\\nResponse:\\n{responses[0]}\",\n                f\"\\nExtracted:\\n{extracted_responses[0]}\",\n            )\n        PRINTED_TIMES += 1\n\n        for guess, true_answer in zip(extracted_responses, answer):\n            if guess is None:\n                scores.append(0)\n                continue\n            # Convert to numbers\n            try:\n                true_answer = float(true_answer.strip())\n                # Remove commas like in 123,456\n                guess = float(guess.strip().replace(\",\", \"\"))\n                scores.append(1.5 if guess == true_answer else -0.5)\n            except:\n                scores.append(0)\n                continue\n        return scores\n\n    print(f\"\\n{'*'*60}\")\n    print(\"🎯 STAGE 2: GRPO Fine-Tuning on GSM8K\")\n    print(f\"{'*'*60}\")\n\n    # Get max prompt length\n    max_prompt_length, _ = get_max_prompt_length(gsm8k_train, tokenizer)\n    max_prompt_length = min(max_prompt_length + 10, 512)  # Add buffer, cap at 512\n\n    print(f\"Using max_prompt_length: {max_prompt_length}\")\n\n    from trl import GRPOConfig, GRPOTrainer\n\n    training_args = GRPOConfig(\n        learning_rate = 5e-6,\n        weight_decay = 0.1,\n        warmup_ratio = 0.1,\n        lr_scheduler_type = \"cosine\",\n        optim = \"adamw_torch_fused\",\n        logging_steps = 1,\n        per_device_train_batch_size = 1,\n        gradient_accumulation_steps = 4,  # Increase to 4 for smoother training\n        num_generations = 8,  # Decrease if out of memory\n        max_prompt_length = max_prompt_length,\n        max_completion_length = max_seq_length - max_prompt_length,\n        # num_train_epochs = 1, # Set to 1 for a full training run\n        # max_steps = 250,\n        max_steps = 1000,\n        save_steps = 250,\n        max_grad_norm = 0.1,\n        report_to = \"none\",  # Can use Weights & Biases\n        output_dir = \"outputs\",\n    )\n\n    trainer = GRPOTrainer(\n        model = model,\n        processing_class = tokenizer,\n        reward_funcs = [\n            match_format_exactly,\n            match_format_approximately,\n            check_answer_correctness,\n            check_numbers,\n        ],\n        args = training_args,\n        train_dataset = gsm8k_train,\n    )\n\n    # Train\n    print(f\"🚂 Starting GRPO training on {len(gsm8k_train)} examples...\")\n    trainer.train()\n\n    # Save checkpoint\n    model.save_pretrained(\"grpo_checkpoint\")\n    tokenizer.save_pretrained(\"grpo_checkpoint\")\n    print(\"💾 GRPO checkpoint saved!\")\n\n    # Cleanup\n    del trainer\n    del training_args\n    cleanup_memory()\n\n    print(\"✅ GRPO training completed!\")\n\n    print(f\"\\n{'='*60}\")\n    print(\"🔍 EVALUATION 3: Final GRPO Model\")\n    print(f\"{'='*60}\")\n\n    grpo_results = evaluate_model_aime(\n        model = model,\n        tokenizer = tokenizer,\n        model_type = \"grpo\",\n        temperature = 0.3,\n        n_sampling = 8,\n        max_tokens = 32768,\n        top_p = 0.95,\n        seed = 0,\n    )\n\n    all_results.append(grpo_results)\n    print(\"✅ Final model evaluation complete!\")\n\n    print(f\"\\n{'='*60}\")\n    print(\"💾 SAVING FINAL MODEL\")\n    print(f\"{'='*60}\")\n\n    # Save as merged model\n    try:\n        
model.save_pretrained_merged(\n            \"final_merged_model\", tokenizer, save_method = \"merged_16bit\"\n        )\n        print(\"✅ Merged model saved to: final_merged_model/\")\n    except Exception as e:\n        print(f\"⚠️ Could not save merged model: {e}\")\n        print(\"Final model saved as LoRA adapter only\")\n\n    print(\"💾 Model saving complete!\")\n\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n\n    result_queue.put(results)\n\n    # Clean up\n    del model\n    del tokenizer\n    torch.cuda.empty_cache()\n    gc.collect()\n\n    # # Merged model load 16 bits model AIME eval\n    # result_queue = mp.Queue()\n    # p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))\n    # p.start()\n    # p.join()\n    #\n    # merged_16bits = result_queue.get()\n    # all_results.append(merged_16bits)\n    #\n    # # Clean up\n    # del merged_model\n    # del merged_tokenizer\n    # del dataset_ppl\n    # torch.cuda.empty_cache()\n    # gc.collect()\n    #\n    # safe_remove_directory(\"./unsloth_compiled_cache\")\n    #\n    # # Merged model load 8 bits model AIME eval\n    #\n    # result_queue = mp.Queue()\n    # p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))\n    # p.start()\n    # p.join()\n    #\n    # merged_16bits = result_queue.get()\n    # all_results.append(merged_16bits)\n\n    # Merged model load 4 bits AIME eval\n    # result_queue = mp.Queue()\n    # p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))\n    # p.start()\n    # p.join()\n    #\n    # merged_16bits = result_queue.get()\n    # all_results.append(merged_16bits)\n\n\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\", force = True)\n    result_queue = mp.Queue()\n    all_results = []\n\n    # run main finetuning and grpo loop\n    p = mp.Process(target = training_run, args = (result_queue,))\n    p.start()\n    p.join()\n\n    results = result_queue.get()\n    all_results = results\n\n    # evaluate merged model loaded 16bits\n    p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))\n    p.start()\n    p.join()\n\n    merged_load_16bits = result_queue.get()\n    all_results.append(merged_load_16bits)\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n\n    # Merged model load 8 bits model AIME eval\n    p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))\n    p.start()\n    p.join()\n\n    merged_load_8bits = result_queue.get()\n    all_results.append(merged_load_8bits)\n\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n\n    # Merged model load 4 bits model AIME eval\n    p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))\n    p.start()\n    p.join()\n\n    merged_load_4bits = result_queue.get()\n    all_results.append(merged_load_4bits)\n\n    safe_remove_directory(\"./unsloth_compiled_cache\")\n\n    # AIME-specific comparison function\n\n    print(f\"\\n{'='*80}\")\n    print(\"🏆 FINAL TRAINING PIPELINE RESULTS\")\n    print(f\"{'='*80}\")\n\n    # Use the AIME-specific comparison\n    compare_aime_results(all_results)\n"
  },
  {
    "path": "tests/saving/non_peft/test_mistral_non_peft.py",
    "content": "from unsloth import FastLanguageModel\nfrom transformers import AutoModelForCausalLM\nfrom peft import PeftModel\nfrom pathlib import Path\nimport sys\nimport warnings\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 1: Loading Base Model\")\nprint(f\"{'='*80}\")\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/mistral-7b-v0.3\",\n    max_seq_length = 2048,\n    dtype = None,\n    load_in_4bit = True,\n    load_in_8bit = False,\n    full_finetuning = False,\n)\n\n\nprint(\"✅ Base model loaded successfully!\")\n\n### Attemtping save merge\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings(record = True) as w:\n    warnings.simplefilter(\"always\")\n    model.save_pretrained_merged(\"test_output\", tokenizer)\n\n    # Verify warning\n    assert len(w) >= 1, \"Expected warning but none raised\"\n    warning_msg = str(w[0].message)\n    expected_msg = \"Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!\"\n    assert expected_msg in warning_msg, f\"Unexpected warning: {warning_msg}\"\n    assert expected_msg in warning_msg, f\"Unexpected warning: {warning_msg}\"\n\nprint(\"✅ Correct warning detected for non-PeftModel merge attempt!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 3: Using save_pretrained (Should Succeed)\")\nprint(f\"{'='*80}\")\n\n\ntry:\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"error\")  # Treat warnings as errors here\n        model.save_pretrained(\"test_output\")\n        print(\"✅ Standard save_pretrained completed successfully!\")\nexcept Exception as e:\n    assert False, f\"Phase 3 failed: {e}\"\n\nsafe_remove_directory(\"./test_output\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\n"
  },
  {
    "path": "tests/saving/non_peft/test_whisper_non_peft.py",
    "content": "from unsloth import FastLanguageModel, FastModel\nfrom transformers import AutoModelForCausalLM, WhisperForConditionalGeneration\nfrom peft import PeftModel\nfrom pathlib import Path\nimport sys\nimport warnings\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 1: Loading Base Model\")\nprint(f\"{'='*80}\")\n\nmodel, tokenizer = FastModel.from_pretrained(\n    model_name = \"unsloth/whisper-large-v3\",\n    dtype = None,  # Leave as None for auto detection\n    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory\n    auto_model = WhisperForConditionalGeneration,\n    whisper_language = \"English\",\n    whisper_task = \"transcribe\",\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\nprint(\"✅ Base model loaded successfully!\")\n\n### Attemtping save merge\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings(record = True) as w:\n    warnings.simplefilter(\"always\")\n    model.save_pretrained_merged(\"test_output\", tokenizer)\n\n    # Verify warning\n    assert len(w) >= 1, \"Expected warning but none raised\"\n    warning_msg = str(w[0].message)\n    expected_msg = \"Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!\"\n    assert expected_msg in warning_msg, f\"Unexpected warning: {warning_msg}\"\n    assert expected_msg in warning_msg, f\"Unexpected warning: {warning_msg}\"\n\nprint(\"✅ Correct warning detected for non-PeftModel merge attempt!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 PHASE 3: Using save_pretrained (Should Succeed)\")\nprint(f\"{'='*80}\")\n\n\ntry:\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"error\")  # Treat warnings as errors here\n        model.save_pretrained(\"test_output\")\n        print(\"✅ Standard save_pretrained completed successfully!\")\nexcept Exception as e:\n    assert False, f\"Phase 3 failed: {e}\"\n\nsafe_remove_directory(\"./test_output\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\n"
  },
  {
    "path": "tests/saving/test_unsloth_save.py",
    "content": "import json\nimport os\nimport shutil\nimport tempfile\nimport pytest\nimport importlib\n\nfrom unsloth import FastLanguageModel, FastModel\n\nmodel_to_test = [\n    # Text Models\n    \"unsloth/tinyllama\",\n    \"unsloth/tinyllama-bnb-4bit\",\n    \"unsloth/Qwen2.5-0.5B-Instruct\",\n    \"unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit\",\n    \"unsloth/Phi-4-mini-instruct\",\n    \"unsloth/Phi-4-mini-instruct-bnb-4bit\",\n    \"unsloth/Qwen2.5-0.5B\",\n    # Vision Models\n    \"unsloth/gemma-3-4b-it\",\n    \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n    \"unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit\",\n]\n\ntorchao_models = [\n    \"unsloth/tinyllama\",\n    \"unsloth/Qwen2.5-0.5B-Instruct\",\n    # \"unsloth/Phi-4-mini-instruct\",\n    # \"unsloth/Qwen2.5-0.5B\",\n    # Skip the -bnb-4bit variants since they're already quantized\n]\n\n\n# Variables\nsave_file_sizes = {}\nsave_file_sizes[\"merged_16bit\"] = {}\nsave_file_sizes[\"merged_4bit\"] = {}\nsave_file_sizes[\"torchao\"] = {}\n\ntokenizer_files = [\n    \"tokenizer_config.json\",\n    \"special_tokens_map.json\",\n]\n\n\n@pytest.fixture(scope = \"session\", params = model_to_test)\ndef loaded_model_tokenizer(request):\n    model_name = request.param\n    print(\"Loading model and tokenizer...\")\n\n    model, tokenizer = FastModel.from_pretrained(\n        model_name,  # use small model\n        max_seq_length = 128,\n        dtype = None,\n        load_in_4bit = True,\n    )\n\n    # Apply LoRA\n    model = FastModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n        lora_alpha = 16,\n        use_gradient_checkpointing = \"unsloth\",\n    )\n\n    return model, tokenizer\n\n\n@pytest.fixture(scope = \"session\", params = torchao_models)\ndef fp16_model_tokenizer(request):\n    \"\"\"Load model in FP16 for TorchAO quantization\"\"\"\n    model_name = request.param\n    print(f\"Loading model in FP16 for TorchAO: {model_name}\")\n\n    model, tokenizer = FastModel.from_pretrained(\n        model_name,\n        max_seq_length = 128,\n        dtype = None,\n        load_in_4bit = False,  # No BnB quantization\n    )\n\n    # Apply LoRA\n    model = FastModel.get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n        lora_alpha = 16,\n        use_gradient_checkpointing = \"unsloth\",\n    )\n\n    return model, tokenizer\n\n\n@pytest.fixture(scope = \"session\")\ndef model(loaded_model_tokenizer):\n    return loaded_model_tokenizer[0]\n\n\n@pytest.fixture(scope = \"session\")\ndef tokenizer(loaded_model_tokenizer):\n    return loaded_model_tokenizer[1]\n\n\n@pytest.fixture\ndef temp_save_dir():\n    dir = tempfile.mkdtemp()\n    print(f\"Temporary directory created at: {dir}\")\n    yield dir\n    print(f\"Temporary directory deleted: {dir}\")\n    shutil.rmtree(dir)\n\n\ndef delete_quantization_config(model):\n    # Since merged, edit quantization_config\n    old_config = model.config\n    new_config = model.config.to_dict()\n    if \"quantization_config\" in new_config:\n        del new_config[\"quantization_config\"]\n    original_model = model\n    new_config = type(model.config).from_dict(new_config)\n    while hasattr(original_model, \"model\"):\n        original_model = original_model.model\n        original_model.config = new_config\n    model.config = new_config\n\n\ndef test_save_merged_16bit(model, tokenizer, temp_save_dir: str):\n    save_path = 
os.path.join(\n        temp_save_dir,\n        \"unsloth_merged_16bit\",\n        model.config._name_or_path.replace(\"/\", \"_\"),\n    )\n\n    model.save_pretrained_merged(\n        save_path, tokenizer = tokenizer, save_method = \"merged_16bit\"\n    )\n\n    # Check model files\n    assert os.path.isdir(save_path), f\"Directory {save_path} does not exist.\"\n    assert os.path.isfile(\n        os.path.join(save_path, \"config.json\")\n    ), \"config.json not found.\"\n\n    weight_files = [\n        f\n        for f in os.listdir(save_path)\n        if f.endswith(\".bin\") or f.endswith(\".safetensors\")\n    ]\n    assert len(weight_files) > 0, \"No weight files found in the save directory.\"\n\n    # Check tokenizer files\n    for file in tokenizer_files:\n        assert os.path.isfile(\n            os.path.join(save_path, file)\n        ), f\"{file} not found in the save directory.\"\n\n    # Check config to see if it is 16bit by checking for quantization config\n    config_path = os.path.join(save_path, \"config.json\")\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    assert (\n        \"quantization_config\" not in config\n    ), \"Quantization config not found in the model config.\"\n\n    # Store the size of the model files\n    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)\n    save_file_sizes[\"merged_16bit\"][model.config._name_or_path] = total_size\n    print(f\"Total size of merged_16bit files: {total_size} bytes\")\n\n    # Test loading the model from the saved path\n    loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(\n        save_path,\n        max_seq_length = 128,\n        dtype = None,\n        load_in_4bit = True,\n    )\n\n\ndef test_save_merged_4bit(model, tokenizer, temp_save_dir: str):\n    save_path = os.path.join(\n        temp_save_dir,\n        \"unsloth_merged_4bit\",\n        model.config._name_or_path.replace(\"/\", \"_\"),\n    )\n\n    model.save_pretrained_merged(\n        save_path, tokenizer = tokenizer, save_method = \"merged_4bit_forced\"\n    )\n\n    # Check model files\n    assert os.path.isdir(save_path), f\"Directory {save_path} does not exist.\"\n    assert os.path.isfile(\n        os.path.join(save_path, \"config.json\")\n    ), \"config.json not found.\"\n\n    weight_files = [\n        f\n        for f in os.listdir(save_path)\n        if f.endswith(\".bin\") or f.endswith(\".safetensors\")\n    ]\n    assert len(weight_files) > 0, \"No weight files found in the save directory.\"\n\n    # Check tokenizer files\n    for file in tokenizer_files:\n        assert os.path.isfile(\n            os.path.join(save_path, file)\n        ), f\"{file} not found in the save directory.\"\n\n    # Store the size of the model files\n    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)\n    save_file_sizes[\"merged_4bit\"][model.config._name_or_path] = total_size\n\n    print(f\"Total size of merged_4bit files: {total_size} bytes\")\n\n    assert (\n        total_size < save_file_sizes[\"merged_16bit\"][model.config._name_or_path]\n    ), \"Merged 4bit files are larger than merged 16bit files.\"\n\n    # Check config to see if it is 4bit\n    config_path = os.path.join(save_path, \"config.json\")\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    assert (\n        \"quantization_config\" in config\n    ), \"Quantization config not found in the model config.\"\n\n    # Test loading the model from the saved 
path\n    loaded_model, loaded_tokenizer = FastModel.from_pretrained(\n        save_path,\n        max_seq_length = 128,\n        dtype = None,\n        load_in_4bit = True,\n    )\n\n\n@pytest.mark.skipif(\n    importlib.util.find_spec(\"torchao\") is None,\n    reason = \"require torchao to be installed\",\n)\ndef test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):\n    model, tokenizer = fp16_model_tokenizer\n    save_path = os.path.join(\n        temp_save_dir, \"unsloth_torchao\", model.config._name_or_path.replace(\"/\", \"_\")\n    )\n\n    from torchao.quantization import Int8DynamicActivationInt8WeightConfig\n\n    torchao_config = Int8DynamicActivationInt8WeightConfig()\n    model.save_pretrained_torchao(\n        save_path,\n        tokenizer = tokenizer,\n        torchao_config = torchao_config,\n        push_to_hub = False,\n    )\n\n    weight_files_16bit = [\n        f\n        for f in os.listdir(save_path)\n        if f.endswith(\".bin\") or f.endswith(\".safetensors\")\n    ]\n    total_16bit_size = sum(\n        os.path.getsize(os.path.join(save_path, f)) for f in weight_files_16bit\n    )\n    save_file_sizes[\"merged_16bit\"][model.config._name_or_path] = total_16bit_size\n\n    torchao_save_path = save_path + \"-torchao\"\n\n    # Check model files\n    assert os.path.isdir(\n        torchao_save_path\n    ), f\"Directory {torchao_save_path} does not exist.\"\n    assert os.path.isfile(\n        os.path.join(torchao_save_path, \"config.json\")\n    ), \"config.json not found.\"\n\n    weight_files = [\n        f\n        for f in os.listdir(torchao_save_path)\n        if f.endswith(\".bin\") or f.endswith(\".safetensors\")\n    ]\n    assert len(weight_files) > 0, \"No weight files found in the save directory.\"\n\n    # Check tokenizer files\n    for file in tokenizer_files:\n        assert os.path.isfile(\n            os.path.join(torchao_save_path, file)\n        ), f\"{file} not found in the save directory.\"\n\n    # Store the size of the model files\n    total_size = sum(\n        os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files\n    )\n    save_file_sizes[\"torchao\"][model.config._name_or_path] = total_size\n\n    assert (\n        total_size < save_file_sizes[\"merged_16bit\"][model.config._name_or_path]\n    ), \"torchao files are larger than merged 16bit files.\"\n\n    # Check config to see if it is quantized with torchao\n    config_path = os.path.join(torchao_save_path, \"config.json\")\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    assert (\n        \"quantization_config\" in config\n    ), \"Quantization config not found in the model config.\"\n\n    # Test loading the model from the saved path\n    # can't set `load_in_4bit` to True because the model is torchao quantized\n    # can't quantize again with bitsandbytes\n    import torch.serialization\n\n    with torch.serialization.safe_globals([getattr]):\n        loaded_model, loaded_tokenizer = FastModel.from_pretrained(\n            torchao_save_path,\n            max_seq_length = 128,\n            dtype = None,\n            load_in_4bit = False,\n        )\n\n\n@pytest.mark.skipif(\n    importlib.util.find_spec(\"torchao\") is None,\n    reason = \"require torchao to be installed\",\n)\ndef test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):\n    model, tokenizer = fp16_model_tokenizer\n    model_name = model.config._name_or_path\n\n    print(f\"Testing TorchAO save and inference for: {model_name}\")\n\n    
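# As in test_save_torchao above, save_pretrained_torchao applies torchao_config to the\n    # weights and writes the quantized checkpoint to save_path + \"-torchao\".\n    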
save_path = os.path.join(\n        temp_save_dir, \"torchao_models\", model_name.replace(\"/\", \"_\")\n    )\n\n    from torchao.quantization import Int8DynamicActivationInt8WeightConfig\n\n    torchao_config = Int8DynamicActivationInt8WeightConfig()\n\n    # Save with TorchAO\n    model.save_pretrained_torchao(\n        save_path,\n        tokenizer = tokenizer,\n        torchao_config = torchao_config,\n        push_to_hub = False,\n    )\n\n    torchao_save_path = save_path + \"-torchao\"\n\n    # Verify files exist\n    assert os.path.isdir(\n        torchao_save_path\n    ), f\"TorchAO directory {torchao_save_path} does not exist.\"\n\n    # Load with safe globals\n    import torch.serialization\n\n    with torch.serialization.safe_globals([getattr]):\n        loaded_model, loaded_tokenizer = FastModel.from_pretrained(\n            torchao_save_path,\n            max_seq_length = 128,\n            dtype = None,\n            load_in_4bit = False,\n        )\n\n    FastModel.for_inference(loaded_model)  # Enable native 2x faster inference\n\n    messages = [\n        {\n            \"role\": \"user\",\n            \"content\": \"Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,\",\n        },\n    ]\n    inputs = loaded_tokenizer.apply_chat_template(\n        messages,\n        tokenize = True,\n        add_generation_prompt = True,  # Must add for generation\n        return_tensors = \"pt\",\n    ).to(\"cuda\")\n\n    outputs = loaded_model.generate(  # ← Use loaded_model, not model\n        input_ids = inputs,\n        max_new_tokens = 64,\n        use_cache = False,  # Avoid cache issues\n        temperature = 1.5,\n        min_p = 0.1,\n        do_sample = True,\n        pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,\n    )\n\n    # Decode with the LOADED tokenizer\n    generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)\n    input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)\n    response_part = generated_text[len(input_text) :].strip()\n\n    print(f\"Input: {input_text}\")\n    print(f\"Full output: {generated_text}\")\n    print(f\"Response only: {response_part}\")\n"
  },
  {
    "path": "tests/saving/text_to_speech_models/test_csm.py",
    "content": "from unsloth import FastLanguageModel, FastModel\nfrom transformers import CsmForConditionalGeneration\nimport torch\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\nfrom peft import PeftModel\nimport warnings\nimport requests\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.os_utils import require_package, require_python_package\n\nrequire_package(\"ffmpeg\", \"ffmpeg\")\nrequire_python_package(\"soundfile\")\n\nimport soundfile as sf\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 1: Loading Model and LoRA Adapters\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastModel.from_pretrained(\n    model_name = \"unsloth/csm-1b\",\n    max_seq_length = 2048,  # Choose any for long context!\n    dtype = None,  # Leave as None for auto-detection\n    auto_model = CsmForConditionalGeneration,\n    load_in_4bit = False,  # Select True for 4bit - reduces memory usage\n)\n\n\nbase_model_class = model.__class__.__name__\n\n\nmodel = FastModel.get_peft_model(\n    model,\n    r = 32,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n    target_modules = [\n        \"q_proj\",\n        \"k_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"up_proj\",\n        \"down_proj\",\n    ],\n    lora_alpha = 32,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n)\n\nprint(\"✅ Model and LoRA adapters loaded successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 2: Checking Model Class Type\")\nprint(f\"{'='*80}\")\n\nassert isinstance(model, PeftModel), \"Model should be an instance of PeftModel\"\nprint(\"✅ Model is an instance of PeftModel!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 3: Checking Config Model Class Type\")\nprint(f\"{'='*80}\")\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nconfig_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model\n\nassert (\n    config_model.__class__.__name__ == base_model_class\n), f\"Expected config_model class to be {base_model_class}\"\nprint(\"✅ config_model returns correct Base Model class:\", str(base_model_class))\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 4: Saving and Merging Model\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings():\n    warnings.simplefilter(\"error\")  # Treat warnings as errors\n    try:\n        model.save_pretrained_merged(\"csm\", tokenizer)\n        print(\"✅ Model saved and merged successfully without warnings!\")\n    except Exception as e:\n        assert False, f\"Model saving/merging failed with exception: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 5: Loading Model for Inference\")\nprint(f\"{'='*80}\")\n\n\nmodel, processor = FastModel.from_pretrained(\n    model_name = \"./csm\",\n    max_seq_length = 2048,  # Choose any for long context!\n    dtype = None,  # Leave as None for auto-detection\n    auto_model = 
CsmForConditionalGeneration,\n    load_in_4bit = False,  # Select True for 4bit - reduces memory usage\n)\n\nfrom transformers import AutoProcessor\n\nprocessor = AutoProcessor.from_pretrained(\"unsloth/csm-1b\")\n\nprint(\"✅ Model loaded for inference successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 6: Running Inference\")\nprint(f\"{'='*80}\")\n\n\nfrom transformers import pipeline\nimport torch\n\noutput_audio_path = \"csm_audio.wav\"\ntry:\n    text = (\n        \"We just finished fine tuning a text to speech model... and it's pretty good!\"\n    )\n    speaker_id = 0\n    inputs = processor(f\"[{speaker_id}]{text}\", add_special_tokens = True).to(\"cuda\")\n    audio_values = model.generate(\n        **inputs,\n        max_new_tokens = 125,  # 125 tokens is 10 seconds of audio, for longer speech increase this\n        # play with these parameters to get the best results\n        depth_decoder_temperature = 0.6,\n        depth_decoder_top_k = 0,\n        depth_decoder_top_p = 0.9,\n        temperature = 0.8,\n        top_k = 50,\n        top_p = 1.0,\n        #########################################################\n        output_audio = True,\n    )\n    audio = audio_values[0].to(torch.float32).cpu().numpy()\n    sf.write(output_audio_path, audio, 24000)\n    print(f\"✅ Audio generated and saved to {output_audio_path}!\")\nexcept Exception as e:\n    assert False, f\"Inference failed with exception: {e}\"\n\n\nprint(\"✅ All sections passed successfully!\")\n\n\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./csm\")\n"
  },
  {
    "path": "tests/saving/text_to_speech_models/test_lasa.py",
    "content": "from unsloth import FastLanguageModel, FastModel\nfrom transformers import CsmForConditionalGeneration\nimport torch\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\nfrom peft import PeftModel\nimport warnings\nimport requests\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.os_utils import require_package, require_python_package\n\nrequire_package(\"ffmpeg\", \"ffmpeg\")\nrequire_python_package(\"soundfile\")\nrequire_python_package(\"xcodec2\")\n\nimport soundfile as sf\nfrom xcodec2.modeling_xcodec2 import XCodec2Model\n\nXCODEC2_MODEL_NAME = \"HKUST-Audio/xcodec2\"\nSAMPLE_RATE = 16000\nDEVICE = \"cuda\"\n\ntry:\n    codec_model = XCodec2Model.from_pretrained(XCODEC2_MODEL_NAME)\n\nexcept Exception as e:\n    raise f\"ERROR loading XCodec2 model: {e}.\"\n\ncodec_model.to(\"cpu\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 1: Loading Model and LoRA Adapters\")\nprint(f\"{'='*80}\")\n\nmax_seq_length = 2048\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/Llasa-1B\",\n    max_seq_length = max_seq_length,\n    dtype = None,  # Select None for auto detection\n    load_in_4bit = False,  # Choose True for 4bit which reduces memory\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\nbase_model_class = model.__class__.__name__\n\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 128,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n    target_modules = [\"q_proj\", \"v_proj\"],\n    lora_alpha = 128,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n)\n\nprint(\"✅ Model and LoRA adapters loaded successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 2: Checking Model Class Type\")\nprint(f\"{'='*80}\")\n\nassert isinstance(model, PeftModel), \"Model should be an instance of PeftModel\"\nprint(\"✅ Model is an instance of PeftModel!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 3: Checking Config Model Class Type\")\nprint(f\"{'='*80}\")\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nconfig_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model\n\nassert (\n    config_model.__class__.__name__ == base_model_class\n), f\"Expected config_model class to be {base_model_class}\"\nprint(\"✅ config_model returns correct Base Model class:\", str(base_model_class))\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 4: Saving and Merging Model\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings():\n    warnings.simplefilter(\"error\")  # Treat warnings as errors\n    try:\n        model.save_pretrained_merged(\"lasa\", tokenizer)\n        print(\"✅ Model saved and merged successfully without warnings!\")\n    except Exception as e:\n        assert False, f\"Model saving/merging failed with exception: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 5: Loading 
Model for Inference\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"./lasa\",\n    max_seq_length = max_seq_length,\n    dtype = None,  # Select None for auto detection\n    load_in_4bit = False,  # Choose True for 4bit which reduces memory\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\n# from transformers import AutoProcessor\n# processor = AutoProcessor.from_pretrained(\"unsloth/csm-1b\")\n\nprint(\"✅ Model loaded for inference successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 6: Running Inference\")\nprint(f\"{'='*80}\")\n\n\nfrom transformers import pipeline\nimport torch\n\noutput_audio_path = \"lasa_audio.wav\"\ninput_text = \"Hey there my name is Elise, <giggles> and I'm a speech generation model that can sound like a person.\"\n\nFastLanguageModel.for_inference(model)\n\n\ndef ids_to_speech_tokens(speech_ids):\n    speech_tokens_str = []\n    for speech_id in speech_ids:\n        speech_tokens_str.append(f\"<|s_{speech_id}|>\")\n    return speech_tokens_str\n\n\ndef extract_speech_ids(speech_tokens_str):\n    speech_ids = []\n    for token_str in speech_tokens_str:\n        if token_str.startswith(\"<|s_\") and token_str.endswith(\"|>\"):\n            num_str = token_str[4:-2]\n\n            num = int(num_str)\n            speech_ids.append(num)\n        else:\n            print(f\"Unexpected token: {token_str}\")\n    return speech_ids\n\n\n# TTS start!\nwith torch.inference_mode():\n    with torch.amp.autocast(\"cuda\", dtype = model.dtype):\n        formatted_text = (\n            f\"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>\"\n        )\n\n        # Tokenize the text\n        chat = [\n            {\"role\": \"user\", \"content\": \"Convert the text to speech:\" + formatted_text},\n            {\"role\": \"assistant\", \"content\": \"<|SPEECH_GENERATION_START|>\"},\n        ]\n\n        input_ids = tokenizer.apply_chat_template(\n            chat, tokenize = True, return_tensors = \"pt\", continue_final_message = True\n        )\n        input_ids = input_ids.to(\"cuda\")\n\n        speech_end_id = tokenizer.convert_tokens_to_ids(\"<|SPEECH_GENERATION_END|>\")\n\n        # Generate the speech autoregressively\n        outputs = model.generate(\n            input_ids,\n            max_length = 2048,  # We trained our model with a max length of 2048\n            eos_token_id = speech_end_id,\n            do_sample = True,\n            top_p = 1.2,  #  Adjusts the diversity of generated content\n            temperature = 1.2,  #  Controls randomness in output\n        )\n    # Extract the speech tokens\n    generated_ids = outputs[0][input_ids.shape[1] : -1]\n\n    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)\n\n    # Convert  token <|s_23456|> to int 23456\n    speech_tokens = extract_speech_ids(speech_tokens)\n\n    speech_tokens = torch.tensor(speech_tokens).cpu().unsqueeze(0).unsqueeze(0)\n\n    # Decode the speech tokens to speech waveform\n    gen_wav = codec_model.decode_code(speech_tokens)\ntry:\n    sf.write(output_audio_path, gen_wav[0, 0, :].cpu().numpy(), 16000)\nexcept Exception as e:\n    assert False, f\"Inference failed with exception: {e}\"\n\n\n## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. 
Four hours of steady work faced us.\n\nprint(\"✅ All sections passed successfully!\")\n\n\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./lasa\")\n"
  },
  {
    "path": "tests/saving/text_to_speech_models/test_orpheus.py",
    "content": "from unsloth import FastLanguageModel, FastModel\nfrom transformers import CsmForConditionalGeneration\nimport torch\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\nfrom peft import PeftModel\nimport warnings\nimport requests\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.os_utils import require_package, require_python_package\n\nrequire_package(\"ffmpeg\", \"ffmpeg\")\nrequire_python_package(\"soundfile\")\nrequire_python_package(\"snac\")\n\nimport soundfile as sf\nfrom snac import SNAC\n\nsnac_model = SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\")\nsnac_model = snac_model.to(\"cuda\")\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 1: Loading Model and LoRA Adapters\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/orpheus-3b-0.1-ft\",\n    max_seq_length = 2048,  # Choose any for long context!\n    dtype = None,  # Select None for auto detection\n    load_in_4bit = False,  # Select True for 4bit which reduces memory usage\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\nbase_model_class = model.__class__.__name__\n\n\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 64,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n    target_modules = [\n        \"q_proj\",\n        \"k_proj\",\n        \"v_proj\",\n        \"o_proj\",\n        \"gate_proj\",\n        \"up_proj\",\n        \"down_proj\",\n    ],\n    lora_alpha = 64,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n)\nprint(\"✅ Model and LoRA adapters loaded successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 2: Checking Model Class Type\")\nprint(f\"{'='*80}\")\n\nassert isinstance(model, PeftModel), \"Model should be an instance of PeftModel\"\nprint(\"✅ Model is an instance of PeftModel!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 3: Checking Config Model Class Type\")\nprint(f\"{'='*80}\")\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nconfig_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model\n\nassert (\n    config_model.__class__.__name__ == base_model_class\n), f\"Expected config_model class to be {base_model_class}\"\nprint(\"✅ config_model returns correct Base Model class:\", str(base_model_class))\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 4: Saving and Merging Model\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings():\n    warnings.simplefilter(\"error\")  # Treat warnings as errors\n    try:\n        model.save_pretrained_merged(\"orpheus\", tokenizer)\n        print(\"✅ Model saved and merged successfully without warnings!\")\n    except Exception as e:\n        assert False, f\"Model saving/merging failed with exception: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 5: Loading Model for 
Inference\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/orpheus-3b-0.1-ft\",\n    max_seq_length = 2048,  # Choose any for long context!\n    dtype = None,  # Select None for auto detection\n    load_in_4bit = False,  # Select True for 4bit which reduces memory usage\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\n# from transformers import AutoProcessor\n# processor = AutoProcessor.from_pretrained(\"unsloth/csm-1b\")\n\nprint(\"✅ Model loaded for inference successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 6: Running Inference\")\nprint(f\"{'='*80}\")\n\n\n# @title Run Inference\n\n\nFastLanguageModel.for_inference(model)  # Enable native 2x faster inference\n\n# Moving snac_model cuda to cpu\nsnac_model.to(\"cpu\")\nprompts = [\n    \"Hey there my name is Elise, <giggles> and I'm a speech generation model that can sound like a person.\",\n]\n\nchosen_voice = None  # None for single-speaker\n\nprompts_ = [(f\"{chosen_voice}: \" + p) if chosen_voice else p for p in prompts]\n\nall_input_ids = []\n\nfor prompt in prompts_:\n    input_ids = tokenizer(prompt, return_tensors = \"pt\").input_ids\n    all_input_ids.append(input_ids)\n\nstart_token = torch.tensor([[128259]], dtype = torch.int64)  # Start of human\nend_tokens = torch.tensor(\n    [[128009, 128260]], dtype = torch.int64\n)  # End of text, End of human\n\nall_modified_input_ids = []\nfor input_ids in all_input_ids:\n    modified_input_ids = torch.cat(\n        [start_token, input_ids, end_tokens], dim = 1\n    )  # SOH SOT Text EOT EOH\n    all_modified_input_ids.append(modified_input_ids)\n\nall_padded_tensors = []\nall_attention_masks = []\nmax_length = max(\n    [modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids]\n)\nfor modified_input_ids in all_modified_input_ids:\n    padding = max_length - modified_input_ids.shape[1]\n    padded_tensor = torch.cat(\n        [torch.full((1, padding), 128263, dtype = torch.int64), modified_input_ids], dim = 1\n    )\n    attention_mask = torch.cat(\n        [\n            torch.zeros((1, padding), dtype = torch.int64),\n            torch.ones((1, modified_input_ids.shape[1]), dtype = torch.int64),\n        ],\n        dim = 1,\n    )\n    all_padded_tensors.append(padded_tensor)\n    all_attention_masks.append(attention_mask)\n\nall_padded_tensors = torch.cat(all_padded_tensors, dim = 0)\nall_attention_masks = torch.cat(all_attention_masks, dim = 0)\n\ninput_ids = all_padded_tensors.to(\"cuda\")\nattention_mask = all_attention_masks.to(\"cuda\")\ngenerated_ids = model.generate(\n    input_ids = input_ids,\n    attention_mask = attention_mask,\n    max_new_tokens = 1200,\n    do_sample = True,\n    temperature = 0.6,\n    top_p = 0.95,\n    repetition_penalty = 1.1,\n    num_return_sequences = 1,\n    eos_token_id = 128258,\n    use_cache = True,\n)\ntoken_to_find = 128257\ntoken_to_remove = 128258\n\ntoken_indices = (generated_ids == token_to_find).nonzero(as_tuple = True)\n\nif len(token_indices[1]) > 0:\n    last_occurrence_idx = token_indices[1][-1].item()\n    cropped_tensor = generated_ids[:, last_occurrence_idx + 1 :]\nelse:\n    cropped_tensor = generated_ids\n\nmask = cropped_tensor != token_to_remove\n\nprocessed_rows = []\n\nfor row in cropped_tensor:\n    masked_row = row[row != token_to_remove]\n    processed_rows.append(masked_row)\n\ncode_lists = []\n\nfor row in processed_rows:\n    row_length = row.size(0)\n    new_length = 
(row_length // 7) * 7\n    trimmed_row = row[:new_length]\n    trimmed_row = [t - 128266 for t in trimmed_row]\n    code_lists.append(trimmed_row)\n\n\ndef redistribute_codes(code_list):\n    layer_1 = []\n    layer_2 = []\n    layer_3 = []\n    for i in range((len(code_list) + 1) // 7):\n        layer_1.append(code_list[7 * i])\n        layer_2.append(code_list[7 * i + 1] - 4096)\n        layer_3.append(code_list[7 * i + 2] - (2 * 4096))\n        layer_3.append(code_list[7 * i + 3] - (3 * 4096))\n        layer_2.append(code_list[7 * i + 4] - (4 * 4096))\n        layer_3.append(code_list[7 * i + 5] - (5 * 4096))\n        layer_3.append(code_list[7 * i + 6] - (6 * 4096))\n    codes = [\n        torch.tensor(layer_1).unsqueeze(0),\n        torch.tensor(layer_2).unsqueeze(0),\n        torch.tensor(layer_3).unsqueeze(0),\n    ]\n\n    # codes = [c.to(\"cuda\") for c in codes]\n    audio_hat = snac_model.decode(codes)\n    return audio_hat\n\n\nmy_samples = []\nfor code_list in code_lists:\n    samples = redistribute_codes(code_list)\n    my_samples.append(samples)\noutput_path = \"orpheus_audio.wav\"\ntry:\n    for i, samples in enumerate(my_samples):\n        audio_data = samples.detach().squeeze().cpu().numpy()\n        import soundfile as sf\n\n        sf.write(output_path, audio_data, 24000)  # Explicitly pass sample rate\n        print(f\"✅ Audio saved to {output_path}!\")\nexcept Exception as e:\n    assert False, f\"Inference failed with exception: {e}\"\n\n# Verify the file exists\nimport os\n\nassert os.path.exists(output_path), f\"Audio file not found at {output_path}\"\nprint(\"✅ Audio file exists on disk!\")\ndel my_samples, samples\n## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.\n\nprint(\"✅ All sections passed successfully!\")\n\n\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./orpheus\")\n"
  },
  {
    "path": "tests/saving/text_to_speech_models/test_whisper.py",
    "content": "from unsloth import FastLanguageModel, FastModel\nfrom transformers import WhisperForConditionalGeneration, WhisperProcessor\nimport torch\n\n# ruff: noqa\nimport sys\nfrom pathlib import Path\nfrom peft import PeftModel\nimport warnings\nimport requests\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.os_utils import require_package, require_python_package\n\nrequire_package(\"ffmpeg\", \"ffmpeg\")\nrequire_python_package(\"soundfile\")\n\nimport soundfile as sf\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 1: Loading Model and LoRA Adapters\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastModel.from_pretrained(\n    model_name = \"unsloth/whisper-large-v3\",\n    dtype = None,  # Leave as None for auto detection\n    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory\n    auto_model = WhisperForConditionalGeneration,\n    whisper_language = \"English\",\n    whisper_task = \"transcribe\",\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\n\nbase_model_class = model.__class__.__name__\n# https://github.com/huggingface/transformers/issues/37172\nmodel.generation_config.input_ids = model.generation_config.forced_decoder_ids\nmodel.generation_config.forced_decoder_ids = None\n\n\nmodel = FastModel.get_peft_model(\n    model,\n    r = 64,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n    target_modules = [\"q_proj\", \"v_proj\"],\n    lora_alpha = 64,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n    task_type = None,  # ** MUST set this for Whisper **\n)\n\nprint(\"✅ Model and LoRA adapters loaded successfully!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 2: Checking Model Class Type\")\nprint(f\"{'='*80}\")\n\nassert isinstance(model, PeftModel), \"Model should be an instance of PeftModel\"\nprint(\"✅ Model is an instance of PeftModel!\")\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 3: Checking Config Model Class Type\")\nprint(f\"{'='*80}\")\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nconfig_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model\n\nassert (\n    config_model.__class__.__name__ == base_model_class\n), f\"Expected config_model class to be {base_model_class}\"\nprint(\"✅ config_model returns correct Base Model class:\", str(base_model_class))\n\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 4: Saving and Merging Model\")\nprint(f\"{'='*80}\")\n\nwith warnings.catch_warnings():\n    warnings.simplefilter(\"error\")  # Treat warnings as errors\n    try:\n        model.save_pretrained_merged(\"whisper\", tokenizer)\n        print(\"✅ Model saved and merged successfully without warnings!\")\n    except Exception as e:\n        assert False, f\"Model saving/merging failed with exception: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 5: Loading Model for 
Inference\")\nprint(f\"{'='*80}\")\n\n\nmodel, tokenizer = FastModel.from_pretrained(\n    model_name = \"./whisper\",\n    dtype = None,  # Leave as None for auto detection\n    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory\n    auto_model = WhisperForConditionalGeneration,\n    whisper_language = \"English\",\n    whisper_task = \"transcribe\",\n    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n)\n\n# model = WhisperForConditionalGeneration.from_pretrained(\"./whisper\")\n# processor = WhisperProcessor.from_pretrained(\"./whisper\")\n\nprint(\"✅ Model loaded for inference successfully!\")\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 6: Downloading Sample Audio File\")\nprint(f\"{'='*80}\")\n\naudio_url = \"https://upload.wikimedia.org/wikipedia/commons/5/5b/Speech_12dB_s16.flac\"\naudio_file = \"Speech_12dB_s16.flac\"\n\ntry:\n    headers = {\n        \"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36\"\n    }\n    response = requests.get(audio_url, headers = headers)\n    response.raise_for_status()\n    with open(audio_file, \"wb\") as f:\n        f.write(response.content)\n    print(\"✅ Audio file downloaded successfully!\")\nexcept Exception as e:\n    assert False, f\"Failed to download audio file: {e}\"\n\nprint(f\"\\n{'='*80}\")\nprint(\"🔍 SECTION 7: Running Inference\")\nprint(f\"{'='*80}\")\n\n\nfrom transformers import pipeline\nimport torch\n\nFastModel.for_inference(model)\nmodel.eval()\n# Create pipeline without specifying the device\nwhisper = pipeline(\n    \"automatic-speech-recognition\",\n    model = model,\n    tokenizer = tokenizer.tokenizer,\n    feature_extractor = tokenizer.feature_extractor,\n    processor = tokenizer,\n    return_language = True,\n    torch_dtype = torch.float16,  # Remove the device parameter\n)\n# Example usage\naudio_file = \"Speech_12dB_s16.flac\"\ntranscribed_text = whisper(audio_file)\n# audio, sr = sf.read(audio_file)\n# input_features = processor(audio, return_tensors=\"pt\").input_features\n# transcribed_text = model.generate(input_features=input_features)\nprint(f\"📝 Transcribed Text: {transcribed_text['text']}\")\n\n## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.\n\nexpected_phrases = [\n    \"birch canoe slid on the smooth planks\",\n    \"sheet to the dark blue background\",\n    \"easy to tell the depth of a well\",\n    \"Four hours of steady work faced us\",\n]\n\ntranscribed_lower = transcribed_text[\"text\"].lower()\nall_phrases_found = all(\n    phrase.lower() in transcribed_lower for phrase in expected_phrases\n)\n\nassert (\n    all_phrases_found\n), f\"Expected phrases not found in transcription: {transcribed_text['text']}\"\nprint(\"✅ Transcription contains all expected phrases!\")\n\n\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./whisper\")\n"
  },
  {
    "path": "tests/saving/vision_models/test_index_file_sharded_model.py",
    "content": "## Import required libraries\n\nfrom unsloth import FastVisionModel, is_bf16_supported\nfrom unsloth.trainer import UnslothVisionDataCollator\n\nimport torch\nimport os\nfrom datasets import load_dataset\nfrom trl import SFTTrainer, SFTConfig\nfrom huggingface_hub import HfFileSystem\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\n\n\n## Dataset Preparation\"\"\"\n\nprint(\"\\n📊 Loading and preparing dataset...\")\ndataset = load_dataset(\"lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean\", \"en\", split = \"train\")\n# To select the first 2000 examples\ntrain_dataset = dataset.select(range(2000))\n\n# To select the next 200 examples for evaluation\neval_dataset = dataset.select(range(2000, 2200))\n\nprint(f\"✅ Dataset loaded successfully!\")\nprint(f\"   📈 Training samples: {len(train_dataset)}\")\nprint(f\"   📊 Evaluation samples: {len(eval_dataset)}\")\n\n\n# Convert dataset to OAI messages\ndef format_data(sample):\n    return {\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": system_message}],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"text\": sample[\"question\"],\n                    },\n                    {\n                        \"type\": \"image\",\n                        \"image\": sample[\"image\"],\n                    },\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"text\", \"text\": sample[\"answer\"]}],\n            },\n        ],\n    }\n\n\nprint(\"\\n🔄 Formatting dataset for vision training...\")\nsystem_message = \"You are an expert french ocr system.\"\n# Convert dataset to OAI messages\n# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\ntrain_dataset = [format_data(sample) for sample in train_dataset]\neval_dataset = [format_data(sample) for sample in eval_dataset]\nprint(\"✅ Dataset formatting completed!\")\n\n\"\"\"## Finetuning Setup and Run\"\"\"\n\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== MODEL LOADING AND SETUP ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n# Load Base Model\nprint(\"🤖 Loading base vision model...\")\ntry:\n    model, tokenizer = FastVisionModel.from_pretrained(\n        # model_name = \"unsloth/Qwen2-VL-7B-Instruct\",\n        model_name = \"unsloth/Qwen2-VL-7B-Instruct\",\n        max_seq_length = 2048,  # Choose any for long context!\n        load_in_4bit = True,  # 4 bit quantization to reduce memory\n        load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory\n        full_finetuning = False,  # [NEW!] We have full finetuning now!\n    )\nexcept Exception as e:\n    print(f\"❌ Failed to load base model: {e}\")\n    raise\n\nprint(\"\\n🔧 Setting up LoRA configuration...\")\n## Lora Finetuning\ntry:\n    model = FastVisionModel.get_peft_model(\n        model,\n        finetune_vision_layers = True,  # Turn off for just text!\n        finetune_language_layers = True,  # Should leave on!\n        finetune_attention_modules = True,  # Attention good for GRPO\n        finetune_mlp_modules = True,  # SHould leave on always!\n        r = 16,  # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128\n        lora_alpha = 32,\n        lora_dropout = 0,  # Supports any, but = 0 is optimized\n        bias = \"none\",  # Supports any, but = \"none\" is optimized\n        use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n        random_state = 3407,\n        use_rslora = False,  # We support rank stabilized LoRA\n        loftq_config = None,  # And LoftQ\n    )\n    print(\"✅ LoRA configuration applied successfully!\")\n    print(f\"   🎯 LoRA rank (r): 16\")\n    print(f\"   📊 LoRA alpha: 32\")\n    print(f\"   🔍 Vision layers: Enabled\")\n    print(f\"   💬 Language layers: Enabled\")\nexcept Exception as e:\n    print(f\"❌ Failed to apply LoRA configuration: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== TRAINING SETUP ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n\n\nprint(\"🏋️ Preparing trainer...\")\nFastVisionModel.for_training(model)  # Enable for training!\n\ntry:\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        data_collator = UnslothVisionDataCollator(model, tokenizer),\n        train_dataset = train_dataset,\n        args = SFTConfig(\n            # per_device_train_batch_size = 4,\n            # gradient_accumulation_steps = 8,\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            gradient_checkpointing = True,\n            gradient_checkpointing_kwargs = {\n                \"use_reentrant\": False\n            },  # use reentrant checkpointing\n            max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper\n            warmup_ratio = 0.03,\n            # num_train_epochs = 2, # Set this instead of max_steps for full training runs\n            max_steps = 10,\n            learning_rate = 2e-4,\n            fp16 = not is_bf16_supported(),\n            bf16 = is_bf16_supported(),\n            logging_steps = 5,\n            save_strategy = \"epoch\",\n            optim = \"adamw_torch_fused\",\n            weight_decay = 0.01,\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"checkpoints\",\n            report_to = \"none\",  # For Weights and Biases\n            # You MUST put the below items for vision finetuning:\n            remove_unused_columns = False,\n            dataset_text_field = \"\",\n            dataset_kwargs = {\"skip_prepare_dataset\": True},\n            dataset_num_proc = 4,\n            max_seq_length = 2048,\n        ),\n    )\n    print(\"✅ Trainer setup completed!\")\n    print(f\"   📦 Batch size: 2\")\n    print(f\"   🔄 Gradient accumulation steps: 4\")\n    print(f\"   📈 Max training steps: 10\")\n    print(f\"   🎯 Learning rate: 2e-4\")\n    print(f\"   💾 Precision: {'BF16' if is_bf16_supported() else 'FP16'}\")\nexcept Exception as e:\n    print(f\"❌ Failed to setup trainer: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== STARTING TRAINING ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n# run training\ntry:\n    print(\"🚀 Starting training process...\")\n    trainer_stats = trainer.train()\nexcept Exception as e:\n    print(f\"❌ Training failed: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== SAVING MODEL ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n\nprint(\"💾 Saving adapter model and tokenizer locally...\")\ntry:\n    model.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\", tokenizer)\n    tokenizer.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\")\n    print(\"✅ Model 
saved locally!\")\nexcept Exception as e:\n    print(f\"❌ Failed to save model locally: {e}\")\n    raise\n\n\nhf_username = os.environ.get(\"HF_USER\", \"\")\nif not hf_username:\n    hf_username = input(\"Please enter your Hugging Face username: \").strip()\n    os.environ[\"HF_USER\"] = hf_username\n\nhf_token = os.environ.get(\"HF_TOKEN\", \"\")\nif not hf_token:\n    hf_token = input(\"Please enter your Hugging Face token: \").strip()\n    os.environ[\"HF_TOKEN\"] = hf_token\n\nrepo_name = f\"{hf_username}/qwen2-7b-ocr-merged\"\nsuccess = {\n    \"upload\": False,\n    \"safetensors_check\": False,\n    \"download\": False,\n}\n# Stage 1: Upload model to Hub\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== UPLOADING MODEL TO HUB ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    print(f\"🚀 Uploading to repository: {repo_name}\")\n    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)\n    success[\"upload\"] = True\n    print(\"✅ Model uploaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Failed to upload model: {e}\")\n    raise Exception(\"Model upload failed.\")\n\n# Stage 2: Verify safetensors.index.json exists\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== VERIFYING REPO CONTENTS ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    fs = HfFileSystem(token = hf_token)\n    file_list = fs.ls(repo_name, detail = True)\n    safetensors_found = any(\n        file[\"name\"].endswith(\"model.safetensors.index.json\") for file in file_list\n    )\n    if safetensors_found:\n        success[\"safetensors_check\"] = True\n        print(\"✅ model.safetensors.index.json found in repo!\")\n    else:\n        raise Exception(\"model.safetensors.index.json not found in repo.\")\nexcept Exception as e:\n    print(f\"❌ Verification failed: {e}\")\n    raise Exception(\"Repo verification failed.\")\n\n# test downloading model even if cached\nsafe_remove_directory(f\"./{hf_username}\")\n\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== TESTING MODEL DOWNLOAD ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    print(\"📥 Testing model download...\")\n    # Force download even if cached\n    test_model, test_tokenizer = FastVisionModel.from_pretrained(repo_name)\n    success[\"download\"] = True\n    print(\"✅ Model downloaded successfully!\")\n\n    # Clean up test model\n    del test_model, test_tokenizer\n    torch.cuda.empty_cache()\nexcept Exception as e:\n    print(f\"❌ Download failed: {e}\")\n    raise Exception(\"Model download failed.\")\n\n# Final report\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== VALIDATION REPORT ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\nfor stage, passed in success.items():\n    status = \"✅\" if passed else \"❌\"\n    print(f\"{status} {stage.replace('_', ' ').title()}\")\nprint(\"\\n\" + \"=\" * 80)\n\nif all(success.values()):\n    print(\"\\n🎉 All stages completed successfully!\")\n    print(f\"🌐 Your model is available at: https://huggingface.co/{repo_name}\")\nelse:\n    raise Exception(\"Validation failed for one or more stages.\")\n\n\n# Final cleanup\nprint(\"\\n🧹 Cleaning up temporary files...\")\nsafe_remove_directory(\"./checkpoints\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./unsloth-qwen2-7vl-french-ocr-adapter\")\n\nprint(\"\\n🎯 Pipeline completed successfully!\")\nprint(\"=\" * 80)\n"
  },
  {
    "path": "tests/saving/vision_models/test_push_to_hub_merged.py",
    "content": "## Import required libraries\n\nfrom unsloth import FastVisionModel, is_bf16_supported\nfrom unsloth.trainer import UnslothVisionDataCollator\n\nimport torch\nimport os\nfrom datasets import load_dataset\nfrom trl import SFTTrainer, SFTConfig\n\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\n\n\n## Dataset Preparation\"\"\"\n\nprint(\"\\n📊 Loading and preparing dataset...\")\ndataset = load_dataset(\"lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean\", \"en\", split = \"train\")\n# To select the first 2000 examples\ntrain_dataset = dataset.select(range(2000))\n\n# To select the next 200 examples for evaluation\neval_dataset = dataset.select(range(2000, 2200))\n\nprint(f\"✅ Dataset loaded successfully!\")\nprint(f\"   📈 Training samples: {len(train_dataset)}\")\nprint(f\"   📊 Evaluation samples: {len(eval_dataset)}\")\n\n\n# Convert dataset to OAI messages\ndef format_data(sample):\n    return {\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": system_message}],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"text\": sample[\"question\"],\n                    },\n                    {\n                        \"type\": \"image\",\n                        \"image\": sample[\"image\"],\n                    },\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"text\", \"text\": sample[\"answer\"]}],\n            },\n        ],\n    }\n\n\nprint(\"\\n🔄 Formatting dataset for vision training...\")\nsystem_message = \"You are an expert french ocr system.\"\n# Convert dataset to OAI messages\n# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\ntrain_dataset = [format_data(sample) for sample in train_dataset]\neval_dataset = [format_data(sample) for sample in eval_dataset]\nprint(\"✅ Dataset formatting completed!\")\n\n\"\"\"## Finetuning Setup and Run\"\"\"\n\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== MODEL LOADING AND SETUP ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n# Load Base Model\nprint(\"🤖 Loading base vision model...\")\ntry:\n    model, tokenizer = FastVisionModel.from_pretrained(\n        # model_name = \"unsloth/Qwen2-VL-7B-Instruct\",\n        model_name = \"unsloth/Qwen2-VL-2B-Instruct\",\n        max_seq_length = 2048,  # Choose any for long context!\n        load_in_4bit = True,  # 4 bit quantization to reduce memory\n        load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory\n        full_finetuning = False,  # [NEW!] We have full finetuning now!\n    )\nexcept Exception as e:\n    print(f\"❌ Failed to load base model: {e}\")\n    raise\n\nprint(\"\\n🔧 Setting up LoRA configuration...\")\n## Lora Finetuning\ntry:\n    model = FastVisionModel.get_peft_model(\n        model,\n        finetune_vision_layers = True,  # Turn off for just text!\n        finetune_language_layers = True,  # Should leave on!\n        finetune_attention_modules = True,  # Attention good for GRPO\n        finetune_mlp_modules = True,  # SHould leave on always!\n        r = 16,  # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128\n        lora_alpha = 32,\n        lora_dropout = 0,  # Supports any, but = 0 is optimized\n        bias = \"none\",  # Supports any, but = \"none\" is optimized\n        use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n        random_state = 3407,\n        use_rslora = False,  # We support rank stabilized LoRA\n        loftq_config = None,  # And LoftQ\n    )\n    print(\"✅ LoRA configuration applied successfully!\")\n    print(f\"   🎯 LoRA rank (r): 16\")\n    print(f\"   📊 LoRA alpha: 32\")\n    print(f\"   🔍 Vision layers: Enabled\")\n    print(f\"   💬 Language layers: Enabled\")\nexcept Exception as e:\n    print(f\"❌ Failed to apply LoRA configuration: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== TRAINING SETUP ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n\n\nprint(\"🏋️ Preparing trainer...\")\nFastVisionModel.for_training(model)  # Enable for training!\n\ntry:\n    trainer = SFTTrainer(\n        model = model,\n        tokenizer = tokenizer,\n        data_collator = UnslothVisionDataCollator(model, tokenizer),\n        train_dataset = train_dataset,\n        args = SFTConfig(\n            # per_device_train_batch_size = 4,\n            # gradient_accumulation_steps = 8,\n            per_device_train_batch_size = 2,\n            gradient_accumulation_steps = 4,\n            gradient_checkpointing = True,\n            gradient_checkpointing_kwargs = {\n                \"use_reentrant\": False\n            },  # use reentrant checkpointing\n            max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper\n            warmup_ratio = 0.03,\n            # num_train_epochs = 2, # Set this instead of max_steps for full training runs\n            max_steps = 10,\n            learning_rate = 2e-4,\n            fp16 = not is_bf16_supported(),\n            bf16 = is_bf16_supported(),\n            logging_steps = 5,\n            save_strategy = \"epoch\",\n            optim = \"adamw_torch_fused\",\n            weight_decay = 0.01,\n            lr_scheduler_type = \"linear\",\n            seed = 3407,\n            output_dir = \"checkpoints\",\n            report_to = \"none\",  # For Weights and Biases\n            # You MUST put the below items for vision finetuning:\n            remove_unused_columns = False,\n            dataset_text_field = \"\",\n            dataset_kwargs = {\"skip_prepare_dataset\": True},\n            dataset_num_proc = 4,\n            max_seq_length = 2048,\n        ),\n    )\n    print(\"✅ Trainer setup completed!\")\n    print(f\"   📦 Batch size: 2\")\n    print(f\"   🔄 Gradient accumulation steps: 4\")\n    print(f\"   📈 Max training steps: 10\")\n    print(f\"   🎯 Learning rate: 2e-4\")\n    print(f\"   💾 Precision: {'BF16' if is_bf16_supported() else 'FP16'}\")\nexcept Exception as e:\n    print(f\"❌ Failed to setup trainer: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== STARTING TRAINING ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n# run training\ntry:\n    print(\"🚀 Starting training process...\")\n    trainer_stats = trainer.train()\nexcept Exception as e:\n    print(f\"❌ Training failed: {e}\")\n    raise\n\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== SAVING MODEL ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\n\nprint(\"💾 Saving adapter model and tokenizer locally...\")\ntry:\n    model.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\", tokenizer)\n    tokenizer.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\")\n    print(\"✅ Model 
saved locally!\")\nexcept Exception as e:\n    print(f\"❌ Failed to save model locally: {e}\")\n    raise\n\n\nhf_username = os.environ.get(\"HF_USER\", \"\")\nif not hf_username:\n    hf_username = input(\"Please enter your Hugging Face username: \").strip()\n    os.environ[\"HF_USER\"] = hf_username\n\nhf_token = os.environ.get(\"HF_TOKEN\", \"\")\nif not hf_token:\n    hf_token = input(\"Please enter your Hugging Face token: \").strip()\n    os.environ[\"HF_TOKEN\"] = hf_token\n\nrepo_name = f\"{hf_username}/qwen2-ocr-merged\"\nsuccess = {\n    \"upload\": False,\n    \"download\": False,\n}\n# Stage 1: Upload model to Hub\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== UPLOADING MODEL TO HUB ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    print(f\"🚀 Uploading to repository: {repo_name}\")\n    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)\n    success[\"upload\"] = True\n    print(\"✅ Model uploaded successfully!\")\nexcept Exception as e:\n    print(f\"❌ Failed to upload model: {e}\")\n    raise Exception(\"Model upload failed.\")\n\n\ntry:\n    print(\"\\n\" + \"=\" * 80)\n    print(\"=== TESTING MODEL DOWNLOAD ===\".center(80))\n    print(\"=\" * 80 + \"\\n\")\n    print(\"📥 Testing model download...\")\n    # Force download even if cached\n    test_model, test_tokenizer = FastVisionModel.from_pretrained(repo_name)\n    success[\"download\"] = True\n    print(\"✅ Model downloaded successfully!\")\n\n    # Clean up test model\n    del test_model, test_tokenizer\n    torch.cuda.empty_cache()\nexcept Exception as e:\n    print(f\"❌ Download failed: {e}\")\n    raise Exception(\"Model download failed.\")\n\n# Final report\nprint(\"\\n\" + \"=\" * 80)\nprint(\"=== VALIDATION REPORT ===\".center(80))\nprint(\"=\" * 80 + \"\\n\")\nfor stage, passed in success.items():\n    status = \"✅\" if passed else \"❌\"\n    print(f\"{status} {stage.replace('_', ' ').title()}\")\nprint(\"\\n\" + \"=\" * 80)\n\nif all(success.values()):\n    print(\"\\n🎉 All stages completed successfully!\")\n    print(f\"🌐 Your model is available at: https://huggingface.co/{repo_name}\")\nelse:\n    raise Exception(\"Validation failed for one or more stages.\")\n\n\n# Final cleanup\nprint(\"\\n🧹 Cleaning up temporary files...\")\nsafe_remove_directory(\"./checkpoints\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./unsloth-qwen2-7vl-french-ocr-adapter\")\nsafe_remove_directory(f\"./{hf_username}\")\n\nprint(\"\\n🎯 Pipeline completed successfully!\")\nprint(\"=\" * 80)\n"
  },
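  {
    "path": "tests/saving/vision_models/example_push_and_verify.py",
    "content": "\"\"\"Illustrative sketch only, not part of the test suite.\n\nCondenses the upload -> verify -> re-download flow used by the two\npush_to_hub tests above into one helper. The repo id, the HF_USER /\nHF_TOKEN environment variables and the file layout are assumptions\nborrowed from those tests.\n\"\"\"\n\nfrom huggingface_hub import HfFileSystem\n\nfrom unsloth import FastVisionModel\n\n\ndef push_and_verify(model, tokenizer, repo_name, hf_token):\n    \"\"\"Push a merged model, check the safetensors index, then re-download it.\"\"\"\n    # Stage 1: upload the merged weights (same call the tests above use)\n    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)\n\n    # Stage 2: make sure the sharded index file made it to the Hub\n    fs = HfFileSystem(token = hf_token)\n    file_list = fs.ls(repo_name, detail = True)\n    if not any(\n        file[\"name\"].endswith(\"model.safetensors.index.json\") for file in file_list\n    ):\n        raise RuntimeError(\"model.safetensors.index.json not found in repo.\")\n\n    # Stage 3: confirm the uploaded repo loads back cleanly\n    test_model, test_tokenizer = FastVisionModel.from_pretrained(repo_name)\n    del test_model, test_tokenizer\n    return True\n\n\n# Example wiring, mirroring the tests above (hf_username / hf_token come from\n# the HF_USER and HF_TOKEN environment variables in those tests):\n#   push_and_verify(model, tokenizer, f\"{hf_username}/qwen2-ocr-merged\", hf_token)\n"
  },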
  {
    "path": "tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom unsloth import FastVisionModel\n\nimport torch\nfrom qwen_vl_utils import process_vision_info\nimport os\nfrom datasets import load_dataset\nfrom trl import SFTTrainer, SFTConfig\n\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.ocr_eval import OCRModelEvaluator\n\n\n## Dataset Preparation\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean\", \"en\", split = \"train\")\n# To select the first 2000 examples\ntrain_dataset = dataset.select(range(2000))\n\n# To select the next 200 examples for evaluation\neval_dataset = dataset.select(range(2000, 2200))\n\n\n# Convert dataset to OAI messages\ndef format_data(sample):\n    return {\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": system_message}],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"text\": sample[\"question\"],\n                    },\n                    {\n                        \"type\": \"image\",\n                        \"image\": sample[\"image\"],\n                    },\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"text\", \"text\": sample[\"answer\"]}],\n            },\n        ],\n    }\n\n\nsystem_message = \"You are an expert french ocr system.\"\n# Convert dataset to OAI messages\n# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\ntrain_dataset = [format_data(sample) for sample in train_dataset]\neval_dataset = [format_data(sample) for sample in eval_dataset]\n\n## Setup OCR main evaluation function and helpers\nimport os\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nfrom jiwer import wer, cer\nfrom qwen_vl_utils import process_vision_info\n\n#\nocr_evaluator = OCRModelEvaluator()\nmodel_comparison_results = {}\n\n## Finetuning Setup and Run\n# Load Base Model\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    model_name = \"unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit\",\n    max_seq_length = 2048,  # Choose any for long context!\n    load_in_4bit = True,  # 4 bit quantization to reduce memory\n    load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory\n    full_finetuning = False,  # [NEW!] We have full finetuning now!\n)\n\n# benchmark base model performance\nmodel_name = \"Unsloth Base model\"\nFastVisionModel.for_inference(model)\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model, tokenizer, eval_dataset, output_dir = \"unsloth_base_model_results\"\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n## Lora Finetuning\nmodel = FastVisionModel.get_peft_model(\n    model,\n    finetune_vision_layers = True,  # Turn off for just text!\n    finetune_language_layers = True,  # Should leave on!\n    finetune_attention_modules = True,  # Attention good for GRPO\n    finetune_mlp_modules = True,  # SHould leave on always!\n    r = 16,  # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128\n    # target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n    # \"gate_proj\", \"up_proj\", \"down_proj\",],\n    lora_alpha = 32,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n)\n\nfrom unsloth import is_bf16_supported\nfrom unsloth.trainer import UnslothVisionDataCollator\n\nFastVisionModel.for_training(model)  # Enable for training!\nmodel.config.use_cache = False\n\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    data_collator = UnslothVisionDataCollator(model, tokenizer),\n    train_dataset = train_dataset,\n    args = SFTConfig(\n        # per_device_train_batch_size = 4,\n        # gradient_accumulation_steps = 8,\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        gradient_checkpointing = True,\n        gradient_checkpointing_kwargs = {\n            \"use_reentrant\": False\n        },  # use reentrant checkpointing\n        max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper\n        warmup_ratio = 0.03,\n        # num_train_epochs = 2, # Set this instead of max_steps for full training runs\n        max_steps = 60,\n        learning_rate = 2e-4,\n        fp16 = not is_bf16_supported(),\n        bf16 = is_bf16_supported(),\n        logging_steps = 5,\n        save_strategy = \"epoch\",\n        optim = \"adamw_torch_fused\",\n        weight_decay = 0.01,\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"unsloth-qwen2.5-vl-32b-french-ocr-checkpoints\",\n        report_to = \"none\",  # For Weights and Biases\n        # You MUST put the below items for vision finetuning:\n        remove_unused_columns = False,\n        dataset_text_field = \"\",\n        dataset_kwargs = {\"skip_prepare_dataset\": True},\n        dataset_num_proc = 4,\n        max_seq_length = 2048,\n    ),\n)\n\n# run training\ntrainer_stats = trainer.train()\n\nmodel.save_pretrained(\"unsloth-qwen2.5-vl-32b-french-ocr-adapter\", tokenizer)\ntokenizer.save_pretrained(\"unsloth-qwen2.5-vl-32b-french-ocr-adapter\")\n\n## Measure Adapter Performance\n\n# benchmark lora model performance\nmodel_name = \"Unsloth lora adapter model\"\nFastVisionModel.for_inference(model)\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model, tokenizer, eval_dataset, output_dir = \"unsloth_lora_model_results\"\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n## Merge Model\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nbase = find_lora_base_model(model)\n\nprint((base.__class__.__name__))\n\n# merge default 16 bits\nmodel.save_pretrained_merged(\n    save_directory = \"qwen2.5-ocr-merged-finetune-merge-16bit\", tokenizer = tokenizer\n)\n\n\n## Benchmark merged model performance\n\n### 16 bits merged model\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2.5-ocr-merged-finetune-merge-16bit\", load_in_4bit = False, load_in_8bit = False\n)\n\n# benchmark 4bit loaded, 
16bits merged model performance\nmodel_name = \"Unsloth 16bits-merged model load-16bits\"\nmodel.config.use_cache = True\n\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_16bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# load 16bits-merged model in 4 bits\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2.5-ocr-merged-finetune-merge-16bit\", load_in_4bit = True, load_in_8bit = False\n)\n\n# benchmark 4bit loaded, 16bits merged model performance\nmodel_name = \"Unsloth 16bits-merged model load-4bits\"\nmodel.config.use_cache = True\n\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_4bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# load model in 8 bits\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2.5-ocr-merged-finetune-merge-16bit\", load_in_4bit = False, load_in_8bit = True\n)\n\n# benchmark 4bit loaded, 16bits merged model performance\nmodel_name = \"Unsloth 16bits-merged model load-8bits\"\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_8bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# \"\"\"### 4 bits merged model\"\"\"\n#\n# # load 4bits-merged model in 4 bits\n# model, tokenizer = FastVisionModel.from_pretrained(\"./qwen2-ocr-merged-finetune-merge-4bit\",load_in_4bit=True, load_in_8bit=False)\n#\n# # benchmark 4bit loaded, 4bits merged model performance\n# model_name = \"Unsloth 4bits-merged model load-4bits\"\n#\n# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir=\"unsloth_4bits_merged_model_load_4bits_results\")\n# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n#\n# # load model in 8 bits\n# model, tokenizer = FastVisionModel.from_pretrained(\"./qwen2-ocr-merged-finetune-merge-4bit\",load_in_4bit=False, load_in_8bit=True)\n#\n# # benchmark 8bit loaded, 4bits merged model performance\n# model_name = \"Unsloth 4bits-merged model load-8bits\"\n#\n# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir=\"unsloth_4bits_merged_model_load_8bits_results\")\n# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# Model comparison report\n# print model comparison\nocr_evaluator.print_model_comparison()\n\n\n# Final cleanup\nprint(\"\\n🧹 Cleaning up temporary files...\")\nsafe_remove_directory(\"./unsloth-qwen2.5-vl-32b-french-ocr-adapter\")\nsafe_remove_directory(\"./unsloth-qwen2.5-vl-32b-french-ocr-checkpoints\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./qwen2.5-ocr-merged-finetune-merge-16bit\")\n\nprint(\"\\n🎯 Pipeline completed successfully!\")\nprint(\"=\" * 80)\n"
  },
  {
    "path": "tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom unsloth import FastVisionModel\n\nimport torch\nfrom qwen_vl_utils import process_vision_info\nimport os\nfrom datasets import load_dataset\nfrom trl import SFTTrainer, SFTConfig\n\nimport sys\nfrom pathlib import Path\n\n\nREPO_ROOT = Path(__file__).parents[3]\nsys.path.insert(0, str(REPO_ROOT))\n\nfrom tests.utils.cleanup_utils import safe_remove_directory\nfrom tests.utils.ocr_eval import OCRModelEvaluator\n\n\n## Dataset Preparation\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean\", \"en\", split = \"train\")\n# To select the first 2000 examples\ntrain_dataset = dataset.select(range(2000))\n\n# To select the next 200 examples for evaluation\neval_dataset = dataset.select(range(2000, 2200))\n\n\n# Convert dataset to OAI messages\ndef format_data(sample):\n    return {\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": system_message}],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"text\": sample[\"question\"],\n                    },\n                    {\n                        \"type\": \"image\",\n                        \"image\": sample[\"image\"],\n                    },\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"text\", \"text\": sample[\"answer\"]}],\n            },\n        ],\n    }\n\n\nsystem_message = \"You are an expert french ocr system.\"\n# Convert dataset to OAI messages\n# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\ntrain_dataset = [format_data(sample) for sample in train_dataset]\neval_dataset = [format_data(sample) for sample in eval_dataset]\n\n## Setup OCR main evaluation function and helpers\nimport os\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nfrom jiwer import wer, cer\nfrom qwen_vl_utils import process_vision_info\n\n#\nocr_evaluator = OCRModelEvaluator()\nmodel_comparison_results = {}\n\n## Finetuning Setup and Run\n# Load Base Model\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    model_name = \"unsloth/Qwen2-VL-7B-Instruct\",\n    max_seq_length = 2048,  # Choose any for long context!\n    load_in_4bit = True,  # 4 bit quantization to reduce memory\n    load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory\n    full_finetuning = False,  # [NEW!] We have full finetuning now!\n)\n\n# benchmark base model performance\nmodel_name = \"Unsloth Base model\"\nFastVisionModel.for_inference(model)\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model, tokenizer, eval_dataset, output_dir = \"unsloth_base_model_results\"\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n## Lora Finetuning\nmodel = FastVisionModel.get_peft_model(\n    model,\n    finetune_vision_layers = True,  # Turn off for just text!\n    finetune_language_layers = True,  # Should leave on!\n    finetune_attention_modules = True,  # Attention good for GRPO\n    finetune_mlp_modules = True,  # SHould leave on always!\n    r = 16,  # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128\n    # target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n    # \"gate_proj\", \"up_proj\", \"down_proj\",],\n    lora_alpha = 32,\n    lora_dropout = 0,  # Supports any, but = 0 is optimized\n    bias = \"none\",  # Supports any, but = \"none\" is optimized\n    # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n    use_gradient_checkpointing = \"unsloth\",  # True or \"unsloth\" for very long context\n    random_state = 3407,\n    use_rslora = False,  # We support rank stabilized LoRA\n    loftq_config = None,  # And LoftQ\n)\n\nfrom unsloth import is_bf16_supported\nfrom unsloth.trainer import UnslothVisionDataCollator\n\nFastVisionModel.for_training(model)  # Enable for training!\nmodel.config.use_cache = False\n\n\ntrainer = SFTTrainer(\n    model = model,\n    tokenizer = tokenizer,\n    data_collator = UnslothVisionDataCollator(model, tokenizer),\n    train_dataset = train_dataset,\n    args = SFTConfig(\n        # per_device_train_batch_size = 4,\n        # gradient_accumulation_steps = 8,\n        per_device_train_batch_size = 2,\n        gradient_accumulation_steps = 4,\n        gradient_checkpointing = True,\n        gradient_checkpointing_kwargs = {\n            \"use_reentrant\": False\n        },  # use reentrant checkpointing\n        max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper\n        warmup_ratio = 0.03,\n        # num_train_epochs = 2, # Set this instead of max_steps for full training runs\n        max_steps = 60,\n        learning_rate = 2e-4,\n        fp16 = not is_bf16_supported(),\n        bf16 = is_bf16_supported(),\n        logging_steps = 5,\n        save_strategy = \"epoch\",\n        optim = \"adamw_torch_fused\",\n        weight_decay = 0.01,\n        lr_scheduler_type = \"linear\",\n        seed = 3407,\n        output_dir = \"unsloth-qwen2-7vl-french-ocr-checkpoints\",\n        report_to = \"none\",  # For Weights and Biases\n        # You MUST put the below items for vision finetuning:\n        remove_unused_columns = False,\n        dataset_text_field = \"\",\n        dataset_kwargs = {\"skip_prepare_dataset\": True},\n        dataset_num_proc = 4,\n        max_seq_length = 2048,\n    ),\n)\n\n# run training\ntrainer_stats = trainer.train()\n\nmodel.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\", tokenizer)\ntokenizer.save_pretrained(\"unsloth-qwen2-7vl-french-ocr-adapter\")\n\n## Measure Adapter Performance\n\n# benchmark lora model performance\nmodel_name = \"Unsloth lora adapter model\"\nFastVisionModel.for_inference(model)\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model, tokenizer, eval_dataset, output_dir = \"unsloth_lora_model_results\"\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n## Merge Model\n\n\ndef find_lora_base_model(model_to_inspect):\n    current = model_to_inspect\n    if hasattr(current, \"base_model\"):\n        current = current.base_model\n    if hasattr(current, \"model\"):\n        current = current.model\n    return current\n\n\nbase = find_lora_base_model(model)\n\nprint((base.__class__.__name__))\n\n# merge default 16 bits\nmodel.save_pretrained_merged(\n    save_directory = \"qwen2-ocr-merged-finetune-merge-16bit\", tokenizer = tokenizer\n)\n\n\n## Benchmark merged model performance\n\n### 16 bits merged model\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2-ocr-merged-finetune-merge-16bit\", load_in_4bit = False, load_in_8bit = False\n)\n\n# benchmark 4bit loaded, 16bits merged 
model performance\nmodel_name = \"Unsloth 16bits-merged model load-16bits\"\nmodel.config.use_cache = True\n\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_16bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# load 16bits-merged model in 4 bits\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2-ocr-merged-finetune-merge-16bit\", load_in_4bit = True, load_in_8bit = False\n)\n\n# benchmark 4bit loaded, 16bits merged model performance\nmodel_name = \"Unsloth 16bits-merged model load-4bits\"\nmodel.config.use_cache = True\n\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_4bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# load model in 8 bits\nmodel, tokenizer = FastVisionModel.from_pretrained(\n    \"./qwen2-ocr-merged-finetune-merge-16bit\", load_in_4bit = False, load_in_8bit = True\n)\n\n# benchmark 4bit loaded, 16bits merged model performance\nmodel_name = \"Unsloth 16bits-merged model load-8bits\"\navg_wer, avg_cer = ocr_evaluator.evaluate_model(\n    model,\n    tokenizer,\n    eval_dataset,\n    output_dir = \"unsloth_16bits_merged_model_load_8bits_results\",\n)\nocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# \"\"\"### 4 bits merged model\"\"\"\n#\n# # load 4bits-merged model in 4 bits\n# model, tokenizer = FastVisionModel.from_pretrained(\"./qwen2-ocr-merged-finetune-merge-4bit\",load_in_4bit=True, load_in_8bit=False)\n#\n# # benchmark 4bit loaded, 4bits merged model performance\n# model_name = \"Unsloth 4bits-merged model load-4bits\"\n#\n# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir=\"unsloth_4bits_merged_model_load_4bits_results\")\n# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n#\n# # load model in 8 bits\n# model, tokenizer = FastVisionModel.from_pretrained(\"./qwen2-ocr-merged-finetune-merge-4bit\",load_in_4bit=False, load_in_8bit=True)\n#\n# # benchmark 8bit loaded, 4bits merged model performance\n# model_name = \"Unsloth 4bits-merged model load-8bits\"\n#\n# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir=\"unsloth_4bits_merged_model_load_8bits_results\")\n# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)\n\n# Model comparison report\n# print model comparison\nocr_evaluator.print_model_comparison()\n\n\n# Final cleanup\nprint(\"\\n🧹 Cleaning up temporary files...\")\nsafe_remove_directory(\"./unsloth-qwen2-7vl-french-ocr-adapter\")\nsafe_remove_directory(\"./unsloth-qwen2-7vl-french-ocr-checkpoints\")\nsafe_remove_directory(\"./unsloth_compiled_cache\")\nsafe_remove_directory(\"./qwen2-ocr-merged-finetune-merge-16bit\")\n\nprint(\"\\n🎯 Pipeline completed successfully!\")\nprint(\"=\" * 80)\n"
  },
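  {
    "path": "tests/saving/vision_models/example_merged_precision_sweep.py",
    "content": "\"\"\"Illustrative sketch only, not part of the test suite.\n\nShows the merged-model precision sweep used by the two OCR benchmark\nscripts above as a single loop. The merged directory, output directory\nnames and model labels are assumptions borrowed from those scripts;\nOCRModelEvaluator comes from tests/utils/ocr_eval.py.\n\"\"\"\n\nimport torch\n\nfrom unsloth import FastVisionModel\n\nfrom tests.utils.ocr_eval import OCRModelEvaluator\n\n\ndef benchmark_merged_precisions(merged_dir, eval_dataset, ocr_evaluator = None):\n    \"\"\"Reload a 16-bit merged model in 16/4/8 bit and record WER/CER for each.\"\"\"\n    if ocr_evaluator is None:\n        ocr_evaluator = OCRModelEvaluator()\n    load_configs = {\n        \"load-16bits\": {\"load_in_4bit\": False, \"load_in_8bit\": False},\n        \"load-4bits\": {\"load_in_4bit\": True, \"load_in_8bit\": False},\n        \"load-8bits\": {\"load_in_4bit\": False, \"load_in_8bit\": True},\n    }\n    for label, load_config in load_configs.items():\n        model, tokenizer = FastVisionModel.from_pretrained(merged_dir, **load_config)\n        model.config.use_cache = True\n        avg_wer, avg_cer = ocr_evaluator.evaluate_model(\n            model,\n            tokenizer,\n            eval_dataset,\n            output_dir = f\"unsloth_16bits_merged_model_{label.replace('-', '_')}_results\",\n        )\n        ocr_evaluator.add_to_comparison(\n            f\"Unsloth 16bits-merged model {label}\", avg_wer, avg_cer\n        )\n        # Free GPU memory before loading the next precision\n        del model, tokenizer\n        torch.cuda.empty_cache()\n    ocr_evaluator.print_model_comparison()\n\n\n# Example wiring, mirroring the benchmark scripts above:\n#   benchmark_merged_precisions(\n#       \"./qwen2-ocr-merged-finetune-merge-16bit\", eval_dataset\n#   )\n"
  },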
  {
    "path": "tests/test_get_model_name.py",
    "content": "import unittest\nfrom unittest.mock import patch\nfrom unsloth.models.loader_utils import get_model_name\nfrom unsloth.models import loader_utils\nfrom unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit\n\n\ndef _no_remote_mapper():\n    return {}, {}, {}\n\n\nclass TestGetModelName(unittest.TestCase):\n    def _assert_mapping(self, model_name, load_in_4bit, expected, should_change):\n        mapped = get_model_name(model_name, load_in_4bit = load_in_4bit)\n        self.assertEqual(mapped.lower(), expected.lower())\n        if should_change:\n            self.assertNotEqual(mapped.lower(), model_name.lower())\n        else:\n            self.assertEqual(mapped.lower(), model_name.lower())\n\n    @patch.object(loader_utils, \"_get_new_mapper\", _no_remote_mapper)\n    def test_resolution_matrix(self):\n        cases = [\n            # Core mappings\n            (\"meta-llama/Llama-2-7b-hf\", True, \"unsloth/llama-2-7b-bnb-4bit\", True),\n            (\"meta-llama/Llama-2-7b-hf\", False, \"unsloth/llama-2-7b\", True),\n            (\n                \"mistralai/Ministral-8B-Instruct-2410\",\n                True,\n                \"mistralai/Ministral-8B-Instruct-2410\",\n                False,\n            ),\n            (\n                \"meta-llama/Llama-3.2-1B-Instruct\",\n                False,\n                \"unsloth/Llama-3.2-1B-Instruct\",\n                True,\n            ),\n            (\n                \"meta-llama/Llama-2-7b-chat-hf\",\n                True,\n                \"unsloth/llama-2-7b-chat-bnb-4bit\",\n                True,\n            ),\n            (\n                \"meta-llama/Llama-3.3-70B-Instruct\",\n                True,\n                \"unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit\",\n                True,\n            ),\n            (\"Qwen/Qwen3-8B\", True, \"unsloth/Qwen3-8B-unsloth-bnb-4bit\", True),\n            (\"Qwen/Qwen3-8B\", False, \"unsloth/Qwen3-8B\", True),\n            (\"Qwen/Qwen3-8B-FP8\", False, \"unsloth/Qwen3-8B-FP8\", True),\n            (\"Qwen/Qwen3-8B-FP8\", True, \"unsloth/Qwen3-8B-unsloth-bnb-4bit\", True),\n            (\n                \"mistralai/Ministral-3-3B-Instruct-2512\",\n                True,\n                \"unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit\",\n                True,\n            ),\n            (\n                \"mistralai/Ministral-3-3B-Instruct-2512\",\n                False,\n                \"unsloth/Ministral-3-3B-Instruct-2512\",\n                True,\n            ),\n            (\"unsloth/Kimi-K2-Instruct\", True, \"unsloth/Kimi-K2-Instruct-BF16\", True),\n            (\"unsloth/Kimi-K2-Instruct\", False, \"unsloth/Kimi-K2-Instruct\", False),\n            # Fallback-to-original behavior\n            \"nonexistent-user/nonexistent-model-123\",\n            \"google/gemma-3-random-prototype-123\",\n            \"imdatta0/nanoqwen-fp8\",\n            \"imdatta0/nanoqwen-bf16\",\n            # Backward compatibility for legacy 4bit names\n            (\"unsloth/llama-2-7b-bnb-4bit\", True, \"unsloth/llama-2-7b-bnb-4bit\", False),\n            (\"unsloth/llama-2-7b-bnb-4bit\", False, \"unsloth/llama-2-7b\", True),\n            (\"google/gemma-2-9b\", True, \"unsloth/gemma-2-9b-bnb-4bit\", True),\n            # GPT-OSS behavior\n            (\"openai/gpt-oss-20b\", False, \"unsloth/gpt-oss-20b\", True),\n            (\"openai/gpt-oss-20b\", True, \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\", True),\n            (\"unsloth/gpt-oss-20b\", True, 
\"unsloth/gpt-oss-20b-unsloth-bnb-4bit\", True),\n            (\"unsloth/gpt-oss-20b-bf16\", True, \"unsloth/gpt-oss-20b-bf16\", False),\n            (\n                \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n                False,\n                \"unsloth/gpt-oss-20b\",\n                True,\n            ),\n            (\n                \"unsloth/gpt-oss-20b-bnb-4bit\",\n                True,\n                \"unsloth/gpt-oss-20b-bnb-4bit\",\n                False,\n            ),\n        ]\n        for case in cases:\n            if isinstance(case, str):\n                model_name = case\n                with self.subTest(model_name = model_name, load_in_4bit = True):\n                    self._assert_mapping(model_name, True, model_name, False)\n            else:\n                model_name, load_in_4bit, expected, should_change = case\n                with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit):\n                    self._assert_mapping(\n                        model_name, load_in_4bit, expected, should_change\n                    )\n\n    def test_static_mapper_contract(self):\n        contracts = [\n            (\"qwen/qwen3-8b\", \"unsloth/qwen3-8b-unsloth-bnb-4bit\"),\n            (\"qwen/qwen3-8b-fp8\", \"unsloth/qwen3-8b-unsloth-bnb-4bit\"),\n            (\n                \"mistralai/ministral-3-3b-instruct-2512\",\n                \"unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit\",\n            ),\n            (\"unsloth/kimi-k2-instruct\", \"unsloth/kimi-k2-instruct-bf16\"),\n        ]\n        for src, expected in contracts:\n            with self.subTest(src = src):\n                self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected)\n        self.assertEqual(\n            MAP_TO_UNSLOTH_16bit[\"qwen/qwen3-8b-fp8\"], \"unsloth/Qwen3-8B-FP8\"\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_model_registry.py",
    "content": "\"\"\"\n\nTest model registration methods\nChecks that model registration methods work for respective models as well as all models\nThe check is performed\n- by registering the models\n- checking that the instantiated models can be found on huggingface hub by querying for the model id\n\n\"\"\"\n\nfrom dataclasses import dataclass\n\nimport pytest\nfrom huggingface_hub import ModelInfo as HfModelInfo\n\nfrom unsloth.registry import register_models, search_models\nfrom unsloth.registry._deepseek import register_deepseek_models\nfrom unsloth.registry._gemma import register_gemma_models\nfrom unsloth.registry._llama import register_llama_models\nfrom unsloth.registry._mistral import register_mistral_models\nfrom unsloth.registry._phi import register_phi_models\nfrom unsloth.registry._qwen import register_qwen_models\nfrom unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType\nfrom unsloth.utils.hf_hub import get_model_info\n\nMODEL_NAMES = [\n    \"llama\",\n    \"qwen\",\n    \"mistral\",\n    \"phi\",\n    \"gemma\",\n    \"deepseek\",\n]\nMODEL_REGISTRATION_METHODS = [\n    register_llama_models,\n    register_qwen_models,\n    register_mistral_models,\n    register_phi_models,\n    register_gemma_models,\n    register_deepseek_models,\n]\n\n\n@dataclass\nclass ModelTestParam:\n    name: str\n    register_models: callable\n\n\ndef _test_model_uploaded(model_ids: list[str]):\n    missing_models = []\n    for _id in model_ids:\n        model_info: HfModelInfo = get_model_info(_id)\n        if not model_info:\n            missing_models.append(_id)\n\n    return missing_models\n\n\nTestParams = [\n    ModelTestParam(name, models)\n    for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS)\n]\n\n\n# Test that model registration methods register respective models\n@pytest.mark.parametrize(\"model_test_param\", TestParams, ids = lambda param: param.name)\ndef test_model_registration(model_test_param: ModelTestParam):\n    MODEL_REGISTRY.clear()\n    registration_method = model_test_param.register_models\n    registration_method()\n    registered_models = MODEL_REGISTRY.keys()\n    missing_models = _test_model_uploaded(registered_models)\n    assert (\n        not missing_models\n    ), f\"{model_test_param.name} missing following models: {missing_models}\"\n\n\ndef test_all_model_registration():\n    register_models()\n    registered_models = MODEL_REGISTRY.keys()\n    missing_models = _test_model_uploaded(registered_models)\n    assert not missing_models, f\"Missing following models: {missing_models}\"\n\n\ndef test_quant_type():\n    # Test that the quant_type is correctly set for model paths\n    # NOTE: for models registered under org=\"unsloth\" with QuantType.NONE aliases QuantType.UNSLOTH\n    dynamic_quant_models = search_models(quant_types = [QuantType.UNSLOTH])\n    assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)\n    quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]\n    assert all(quant_tag in m.model_path for m in dynamic_quant_models)\n"
  },
  {
    "path": "tests/test_raw_text.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nMinimal test for raw text training implementation.\nTests basic functionality without heavy dependencies.\n\"\"\"\n\nimport sys\nimport os\nimport tempfile\nfrom pathlib import Path\nimport importlib.util\n\n\n# Mock the datasets module since it's not installed\nclass MockDataset:\n    def __init__(self, data_dict):\n        self.data = data_dict\n        self.column_names = list(data_dict.keys())\n\n    def __len__(self):\n        return len(next(iter(self.data.values())))\n\n    def __getitem__(self, idx):\n        if isinstance(idx, str):\n            # Allow accessing columns by name like dataset['text']\n            return self.data[idx]\n        elif isinstance(idx, int):\n            # Allow accessing individual rows by index\n            return {key: values[idx] for key, values in self.data.items()}\n        else:\n            raise TypeError(f\"Invalid index type: {type(idx)}\")\n\n    @classmethod\n    def from_dict(cls, data_dict):\n        return cls(data_dict)\n\n\n# Mock datasets module\ndatasets_mock = type(sys)(\"datasets\")\ndatasets_mock.Dataset = MockDataset\nsys.modules[\"datasets\"] = datasets_mock\n\n# Import the raw_text module directly to avoid unsloth/__init__.py dependencies\ncurrent_dir = os.path.dirname(__file__)\nraw_text_path = os.path.join(\n    os.path.dirname(current_dir), \"unsloth\", \"dataprep\", \"raw_text.py\"\n)\n\nspec = importlib.util.spec_from_file_location(\"raw_text\", raw_text_path)\nraw_text_module = importlib.util.module_from_spec(spec)\nspec.loader.exec_module(raw_text_module)\n\nRawTextDataLoader = raw_text_module.RawTextDataLoader\nTextPreprocessor = raw_text_module.TextPreprocessor\n\n\ndef test_raw_text_loader():\n    \"\"\"Test basic RawTextDataLoader functionality.\"\"\"\n\n    # Mock tokenizer for testing\n    class MockTokenizer:\n        def __init__(self):\n            self.eos_token = \"</s>\"\n            self.eos_token_id = 2  # Mock EOS token ID\n\n        def __call__(self, text, return_tensors = None, add_special_tokens = False):\n            words = text.split()\n            token_ids = list(range(len(words)))\n\n            if return_tensors == \"pt\":\n                # Mock tensor-like object\n                class MockTensor:\n                    def __init__(self, data):\n                        self.data = data\n\n                    def __getitem__(self, idx):\n                        return self.data\n\n                    def __len__(self):\n                        return len(self.data)\n\n                    def tolist(self):\n                        return self.data\n\n                return {\"input_ids\": [MockTensor(token_ids)]}\n            return {\"input_ids\": token_ids}\n\n        def decode(self, token_ids, skip_special_tokens = False):\n            return \" \".join([f\"word_{i}\" for i in token_ids])\n\n    # Create test file\n    test_content = \"This is a test file for raw text training. 
\" * 10\n    with tempfile.NamedTemporaryFile(mode = \"w\", suffix = \".txt\", delete = False) as f:\n        f.write(test_content)\n        test_file = f.name\n\n    try:\n        # Test loader\n        tokenizer = MockTokenizer()\n        loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)\n\n        # Test loading with text output (legacy mode)\n        text_dataset = loader.load_from_file(test_file, return_tokenized = False)\n        assert len(text_dataset) > 0, \"Should create at least one chunk\"\n        assert \"text\" in text_dataset.column_names, \"Dataset should have 'text' column\"\n\n        # Test loading with tokenized output (new efficient mode)\n        tokenized_dataset = loader.load_from_file(test_file, return_tokenized = True)\n        assert len(tokenized_dataset) > 0, \"Should create at least one tokenized chunk\"\n        assert (\n            \"input_ids\" in tokenized_dataset.column_names\n        ), \"Dataset should have 'input_ids' column\"\n        assert (\n            \"attention_mask\" in tokenized_dataset.column_names\n        ), \"Dataset should have 'attention_mask' column\"\n\n        # Verify tokenized data structure\n        first_sample = tokenized_dataset[0]\n        assert isinstance(first_sample[\"input_ids\"], list), \"input_ids should be a list\"\n        assert isinstance(\n            first_sample[\"attention_mask\"], list\n        ), \"attention_mask should be a list\"\n        assert len(first_sample[\"input_ids\"]) == len(\n            first_sample[\"attention_mask\"]\n        ), \"input_ids and attention_mask should have same length\"\n\n        # Verify labels field exists (for causal LM training)\n        assert (\n            \"labels\" in tokenized_dataset.column_names\n        ), \"Dataset should have 'labels' column\"\n        assert (\n            first_sample[\"labels\"] == first_sample[\"input_ids\"]\n        ), \"labels should match input_ids\"\n\n        # Test constructor validation\n        try:\n            bad_loader = RawTextDataLoader(tokenizer, chunk_size = 0, stride = 2)\n            assert False, \"Should raise ValueError for chunk_size=0\"\n        except ValueError as e:\n            assert \"chunk_size must be positive\" in str(e)\n\n        try:\n            bad_loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 10)\n            assert False, \"Should raise ValueError for stride >= chunk_size\"\n        except ValueError as e:\n            assert \"stride\" in str(e) and \"chunk_size\" in str(e)\n\n        # Test preprocessor\n        preprocessor = TextPreprocessor()\n        clean_text = preprocessor.clean_text(\"  messy   text  \\n\\n\\n  \")\n        assert \"messy text\" in clean_text, \"Should clean text properly\"\n\n        # Test validation\n        stats = preprocessor.validate_dataset(text_dataset)\n        assert stats[\"total_samples\"] > 0, \"Should count samples\"\n        assert \"warnings\" in stats, \"Should include warnings\"\n\n        print(\"✅ All tests passed!\")\n        return True\n\n    except Exception as e:\n        print(f\"❌ Test failed: {e}\")\n        return False\n\n    finally:\n        # Cleanup\n        os.unlink(test_file)\n\n\nif __name__ == \"__main__\":\n    success = test_raw_text_loader()\n    sys.exit(0 if success else 1)\n"
  },
  {
    "path": "tests/utils/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef timer(name):\n    start = time.time()\n    yield\n    end = time.time()\n    print(f\"{name} took {end - start:.2f} seconds\")\n\n\n@contextmanager\ndef header_footer_context(title: str, char = \"-\"):\n    print()\n    print(f\"{char}\" * 50 + f\" {title} \" + f\"{char}\" * 50)\n    yield\n    print(f\"{char}\" * (100 + len(title) + 2))\n    print()\n"
  },
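  {
    "path": "tests/utils/example_timing.py",
    "content": "\"\"\"Illustrative sketch only, not part of the test suite.\n\nShows how the timer and header_footer_context helpers defined in\ntests/utils/__init__.py are meant to be used; the work inside the\nblocks is a placeholder.\n\"\"\"\n\nfrom tests.utils import header_footer_context, timer\n\n\nif __name__ == \"__main__\":\n    with header_footer_context(\"Timing demo\"):\n        with timer(\"dummy workload\"):\n            # Placeholder for real work, e.g. loading a model or a dataset\n            total = sum(range(1_000_000))\n            print(f\"sum = {total}\")\n"
  },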
  {
    "path": "tests/utils/aime_eval.md",
    "content": "# AIME Dataset Evaluator\n\nA Python module for evaluating language models on the AIME (American Invitational Mathematics Examination) dataset. This evaluator automatically downloads and combines multiple AIME test datasets and provides comprehensive mathematical reasoning assessment.\n\n\n## Basic Usage\n\n```python\nfrom aime_utils import evaluate_model_aime\n\n# Simple AIME evaluation\nresults = evaluate_model_aime(\n    model=your_model,\n    tokenizer=your_tokenizer,\n    model_type=\"base_model\",\n    temperature=0.3,\n    n_sampling=8,\n    max_tokens=32768\n)\n\nprint(f\"AIME Accuracy: {results['accuracy']:.1f}%\")\nprint(f\"Pass@8: {results['pass_at_k']:.1f}%\")\n```\n\n## Advanced Usage\n\n```python\nfrom aime_utils import evaluate_model_aime, compare_aime_results\n\n# Evaluate multiple model configurations\nall_results = []\n\n# Base model\nbase_results = evaluate_model_aime(\n    model=base_model,\n    tokenizer=tokenizer,\n    model_type=\"base\",\n    temperature=0.3,\n    n_sampling=8\n)\nall_results.append(base_results)\n\n# Fine-tuned model\nft_results = evaluate_model_aime(\n    model=finetuned_model,\n    tokenizer=tokenizer,\n    model_type=\"finetuned\",\n    temperature=0.3,\n    n_sampling=8\n)\nall_results.append(ft_results)\n\n# Generate comprehensive comparison\ncompare_aime_results(all_results)\n```\n\n## Dataset Format\n\nThe evaluator automatically handles AIME dataset format with problems containing:\n\n- **Problem**: Mathematical question text\n- **Answer**: Numerical answer (0-999 range for AIME)\n- **Solution**: Step-by-step solution (when available)\n- **Source**: Original dataset identifier (test2024, test2025-I, test2025-II)\n\n```python\n# Automatic dataset download and formatting\n{\n    \"global_id\": 0,\n    \"original_id\": \"problem_1\",\n    \"source_dataset\": \"test2024\",\n    \"problem\": \"Find the number of...\",\n    \"answer\": \"123\",\n    \"solution\": \"Step-by-step solution...\",\n    \"prompt\": [\n        {\"role\": \"system\", \"content\": \"You are a mathematical problem solver...\"},\n        {\"role\": \"user\", \"content\": \"Problem: Find the number of...\"}\n    ]\n}\n```\n\n\n## Configuration Examples\n\n### Conservative Evaluation\n```python\n# Lower temperature for more consistent answers\nresults = evaluate_model_aime(\n    model=model,\n    tokenizer=tokenizer,\n    model_type=\"conservative\",\n    temperature=0.1,\n    n_sampling=4,\n    top_p=0.9\n)\n```\n\n### High-Sample Evaluation\n```python\n# More samples for better Pass@K estimation\nresults = evaluate_model_aime(\n    model=model,\n    tokenizer=tokenizer,\n    model_type=\"high_sample\",\n    temperature=0.5,\n    n_sampling=16,\n    max_tokens=16384\n)\n```\n\n### Memory-Optimized\n```python\n# Reduced parameters for limited resources\nresults = evaluate_model_aime(\n    model=model,\n    tokenizer=tokenizer,\n    model_type=\"lite\",\n    temperature=0.3,\n    n_sampling=4,\n    max_tokens=8192\n)\n```\n\n## Examples\n\n### Complete Model Pipeline Evaluation\n```python\nfrom aime_utils import evaluate_model_aime, compare_aime_results\n\ndef evaluate_training_pipeline(base_model, finetuned_model, merged_model, tokenizer):\n    \"\"\"Evaluate complete training pipeline on AIME\"\"\"\n\n    all_results = []\n\n    # Standard evaluation configuration\n    eval_config = {\n        \"temperature\": 0.3,\n        \"n_sampling\": 8,\n        \"max_tokens\": 32768,\n        \"top_p\": 0.95,\n        \"seed\": 0\n    }\n\n    # Evaluate base model\n   
 print(\"Evaluating base model...\")\n    base_results = evaluate_model_aime(\n        model=base_model,\n        tokenizer=tokenizer,\n        model_type=\"base\",\n        **eval_config\n    )\n    all_results.append(base_results)\n\n    # Evaluate fine-tuned model\n    print(\"Evaluating fine-tuned model...\")\n    ft_results = evaluate_model_aime(\n        model=finetuned_model,\n        tokenizer=tokenizer,\n        model_type=\"finetuned\",\n        **eval_config\n    )\n    all_results.append(ft_results)\n\n    # Evaluate merged model\n    print(\"Evaluating merged model...\")\n    merged_results = evaluate_model_aime(\n        model=merged_model,\n        tokenizer=tokenizer,\n        model_type=\"merged\",\n        **eval_config\n    )\n    all_results.append(merged_results)\n\n    # Generate comparison report\n    compare_aime_results(all_results)\n\n    return all_results\n```\n\n### Quantization Impact Analysis\n```python\ndef analyze_quantization_impact(model_paths, tokenizer):\n    \"\"\"Analyze impact of different quantization levels\"\"\"\n\n    quantization_configs = {\n        \"fp16\": {\"load_in_4bit\": False, \"load_in_8bit\": False},\n        \"8bit\": {\"load_in_4bit\": False, \"load_in_8bit\": True},\n        \"4bit\": {\"load_in_4bit\": True, \"load_in_8bit\": False}\n    }\n\n    all_results = []\n\n    for quant_name, load_config in quantization_configs.items():\n        print(f\"Evaluating {quant_name} quantization...\")\n\n        # Load model with specific quantization\n        model = load_model_with_config(model_paths[\"merged\"], **load_config)\n\n        results = evaluate_model_aime(\n            model=model,\n            tokenizer=tokenizer,\n            model_type=f\"merged_{quant_name}\",\n            temperature=0.3,\n            n_sampling=8,\n            max_tokens=32768\n        )\n        all_results.append(results)\n\n        # Cleanup\n        del model\n        torch.cuda.empty_cache()\n\n    compare_aime_results(all_results)\n    return all_results\n```\n\n## Output Format\n\n### Individual Evaluation Results\n```\n🧮 AIME EVALUATION - BASE MODEL\nCombined Dataset: test2024 + test2025-I + test2025-II\n====================================================================\n\n🎯 Overall Performance:\n   Total problems:           45\n   Correct answers:         12/45 (26.7%)\n   Pass@8:                  31.1%\n\n📈 Performance by Dataset:\n    test2024:   4/15 (26.7%)\n  test2025-I:   5/15 (33.3%)\n test2025-II:   3/15 (20.0%)\n\n🎖️  AIME Performance:     ✅ EXCELLENT (26.7%)\n```\n\n### Comparison Report\n```\nCOMPREHENSIVE AIME MODEL COMPARISON\n================================================================================\nModel           Accuracy %   Pass@K %   Correct  Total\n--------------------------------------------------------------------------------\nfinetuned       31.1         35.6       14       45\nbase            26.7         31.1       12       45\nmerged_4bit     24.4         28.9       11       45\n\nIMPROVEMENT ANALYSIS\n==================================================\nfinetuned vs base:\n  Accuracy improvement:  +4.4%\n  Pass@K improvement:    +4.5%\n```\n\n## Performance Tiers\n\nThe evaluator provides performance assessment based on AIME difficulty:\n\n- **🏆 EXCEPTIONAL**: ≥50% accuracy\n- **✅ EXCELLENT**: ≥30% accuracy\n- **🎯 VERY GOOD**: ≥20% accuracy\n- **⚠️ GOOD**: ≥10% accuracy\n- **📈 FAIR**: ≥5% accuracy\n- **❌ NEEDS IMPROVEMENT**: <5% accuracy\n"
  },
  {
    "path": "tests/utils/aime_eval.py",
    "content": "\"\"\"\nAIME Dataset Evaluation Module\n\nThis module provides functions to evaluate language models on the combined AIME dataset\n(test2024 + test2025-I + test2025-II).\n\"\"\"\n\nimport json\nimport requests\nimport os\nimport re\nimport logging\nfrom typing import List, Dict, Any\nfrom tqdm import tqdm\nfrom vllm import SamplingParams\n\n\ndef download_and_combine_aime_datasets(data_dir: str = \"./data/aime\") -> str:\n    \"\"\"Download all AIME datasets and combine them into a single file\"\"\"\n\n    datasets = {\n        \"test2024\": \"https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2024.jsonl\",\n        \"test2025-I\": \"https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-I.jsonl\",\n        \"test2025-II\": \"https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl\",\n    }\n\n    os.makedirs(data_dir, exist_ok = True)\n    combined_filepath = os.path.join(data_dir, \"aime.jsonl\")\n\n    # Check if combined file already exists\n    if os.path.exists(combined_filepath):\n        print(f\"Combined AIME dataset already exists at {combined_filepath}\")\n        return combined_filepath\n\n    print(\"Downloading and combining AIME datasets...\")\n\n    all_problems = []\n    global_id = 0\n\n    for dataset_name, url in datasets.items():\n        print(f\"  Downloading {dataset_name}...\")\n\n        try:\n            response = requests.get(url)\n            response.raise_for_status()\n\n            # Parse each line and add source information\n            for line_num, line in enumerate(response.text.strip().split(\"\\n\")):\n                if line.strip():\n                    try:\n                        data = json.loads(line)\n                        # Add source dataset information and global ID\n                        data[\"source_dataset\"] = dataset_name\n                        data[\"original_id\"] = data.get(\"id\", line_num)\n                        data[\"global_id\"] = global_id\n                        global_id += 1\n                        all_problems.append(data)\n                    except json.JSONDecodeError as e:\n                        print(\n                            f\"    Warning: Error parsing line {line_num + 1} in {dataset_name}: {e}\"\n                        )\n                        continue\n\n        except requests.RequestException as e:\n            print(f\"    Error downloading {dataset_name}: {e}\")\n            continue\n\n    # Write combined dataset\n    if all_problems:\n        with open(combined_filepath, \"w\", encoding = \"utf-8\") as f:\n            for problem in all_problems:\n                f.write(json.dumps(problem, ensure_ascii = False) + \"\\n\")\n\n        print(f\"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets\")\n        print(f\"   Saved to: {combined_filepath}\")\n\n        # Print summary by dataset\n        for dataset_name in datasets.keys():\n            count = sum(1 for p in all_problems if p[\"source_dataset\"] == dataset_name)\n            print(f\"   {dataset_name}: {count} problems\")\n\n    else:\n        raise RuntimeError(\"No problems were successfully downloaded\")\n\n    return combined_filepath\n\n\ndef load_aime_dataset(data_dir: str = \"./data/aime\") -> List[Dict[str, Any]]:\n    \"\"\"Load combined AIME dataset and format for evaluation\"\"\"\n\n    # Download and combine if needed\n    filepath = download_and_combine_aime_datasets(data_dir)\n\n    
examples = []\n    with open(filepath, \"r\", encoding = \"utf-8\") as f:\n        for line_num, line in enumerate(f):\n            line = line.strip()\n            if line:\n                try:\n                    data = json.loads(line)\n\n                    # Format as expected by our evaluation\n                    formatted_example = {\n                        \"global_id\": data.get(\"global_id\", line_num),\n                        \"original_id\": data.get(\n                            \"original_id\", data.get(\"id\", line_num)\n                        ),\n                        \"source_dataset\": data.get(\"source_dataset\", \"unknown\"),\n                        \"problem\": data[\"problem\"],\n                        \"answer\": str(data[\"answer\"]),  # Ensure answer is string\n                        \"solution\": data.get(\"solution\", \"\"),\n                        \"url\": data.get(\"url\", \"\"),\n                        # Format as chat messages for the model\n                        \"prompt\": [\n                            {\n                                \"role\": \"system\",\n                                \"content\": \"You are a mathematical problem solver. Solve the given problem step by step and provide your final answer clearly.\",\n                            },\n                            {\n                                \"role\": \"user\",\n                                \"content\": f\"Problem: {data['problem']}\\n\\nSolve this step by step and provide your final numerical answer.\",\n                            },\n                        ],\n                    }\n                    examples.append(formatted_example)\n\n                except json.JSONDecodeError as e:\n                    print(f\"Error parsing line {line_num + 1}: {e}\")\n                    continue\n\n    print(f\"Loaded {len(examples)} problems from combined AIME dataset\")\n\n    # Print breakdown by source\n    source_counts = {}\n    for example in examples:\n        source = example[\"source_dataset\"]\n        source_counts[source] = source_counts.get(source, 0) + 1\n\n    for source, count in source_counts.items():\n        print(f\"  {source}: {count} problems\")\n\n    return examples\n\n\ndef extract_aime_answer(response: str) -> str:\n    \"\"\"Extract numerical answer from AIME response\"\"\"\n\n    # AIME answers are integers from 0-999\n    # Look for patterns like \"The answer is 123\" or just standalone numbers\n    patterns = [\n        r\"(?:the )?(?:final )?answer is (\\d{1,3})\",\n        r\"(?:therefore|thus|so),?\\s*(?:the )?(?:final )?answer is (\\d{1,3})\",\n        r\"\\\\boxed\\{(\\d{1,3})\\}\",\n        r\"\\$\\\\boxed\\{(\\d{1,3})\\}\\$\",\n        r\"(?:answer|result):\\s*(\\d{1,3})\",\n        r\"(?:^|\\n)\\s*(\\d{1,3})\\s*(?:\\n|$)\",  # Standalone number\n    ]\n\n    response_lower = response.lower().strip()\n\n    for pattern in patterns:\n        matches = re.findall(pattern, response_lower, re.MULTILINE | re.IGNORECASE)\n        if matches:\n            # Get the last match (most likely to be final answer)\n            answer = matches[-1]\n            try:\n                num = int(answer)\n                if 0 <= num <= 999:  # AIME answers are in range 0-999\n                    return str(num)\n            except ValueError:\n                continue\n\n    # If no clear pattern found, try to extract any 1-3 digit number\n    numbers = re.findall(r\"\\b(\\d{1,3})\\b\", response)\n    if numbers:\n        for num_str in 
reversed(numbers):  # Check from end\n            try:\n                num = int(num_str)\n                if 0 <= num <= 999:\n                    return str(num)\n            except ValueError:\n                continue\n\n    return \"\"\n\n\ndef get_num_tokens(text, tokenizer_instance):\n    \"\"\"Count tokens in text\"\"\"\n    if not text:\n        return 0\n    encoding = tokenizer_instance(text, return_tensors = \"pt\")\n    return len(encoding[\"input_ids\"][0])\n\n\ndef evaluate_model_aime(\n    model,\n    tokenizer,\n    model_type = \"base\",\n    lora_request = None,\n    temperature = 0.3,\n    n_sampling = 8,\n    max_tokens = 32768,\n    top_p = 0.95,\n    seed = 0,\n):\n    \"\"\"Evaluate model on combined AIME dataset with official configuration\"\"\"\n\n    print(f\"\\n{'='*70}\")\n    print(f\"🧮 AIME EVALUATION - {model_type.upper()} MODEL\")\n    print(f\"Combined Dataset: test2024 + test2025-I + test2025-II\")\n    print(f\"{'='*70}\")\n\n    # Load combined AIME dataset\n    try:\n        eval_dataset = load_aime_dataset()\n    except Exception as e:\n        print(f\"Error loading dataset: {e}\")\n        return None\n\n    if not eval_dataset:\n        print(\"No examples found in dataset\")\n        return None\n\n    # Initialize tracking variables\n    records = {}\n    input_tokens = []\n    output_tokens = []\n    correct_answers = 0\n\n    # Track performance by source dataset\n    source_stats = {}\n    for example in eval_dataset:\n        source = example[\"source_dataset\"]\n        if source not in source_stats:\n            source_stats[source] = {\"total\": 0, \"correct\": 0}\n        source_stats[source][\"total\"] += 1\n\n    # Setup sampling parameters (AIME configuration)\n    sampling_params = SamplingParams(\n        temperature = temperature,\n        top_p = top_p,\n        max_tokens = max_tokens,\n        n = n_sampling,  # Multiple samples per question\n        seed = seed,\n    )\n\n    print(f\"\\n🔧 Configuration:\")\n    print(f\"   Temperature: {temperature}\")\n    print(f\"   Samples per question: {n_sampling}\")\n    print(f\"   Max tokens: {max_tokens}\")\n    print(f\"   Top-p: {top_p}\")\n    print(f\"   Seed: {seed}\")\n\n    # Temporarily suppress verbose logging\n    original_levels = {}\n    loggers_to_suppress = [\n        \"vllm\",\n        \"vllm.engine\",\n        \"vllm.worker\",\n        \"vllm.model_executor\",\n        \"vllm.executor\",\n        \"ray\",\n    ]\n\n    for logger_name in loggers_to_suppress:\n        logger = logging.getLogger(logger_name)\n        original_levels[logger_name] = logger.level\n        logger.setLevel(logging.WARNING)\n\n    try:\n        print(f\"\\n🚀 Evaluating {len(eval_dataset)} problems...\")\n\n        # Main evaluation loop\n        with tqdm(\n            total = len(eval_dataset), desc = \"Processing AIME problems\", unit = \"problem\"\n        ) as pbar:\n            for task_id, item in enumerate(eval_dataset):\n                try:\n                    # Prepare prompt\n                    prompt_text = tokenizer.apply_chat_template(\n                        item[\"prompt\"], add_generation_prompt = True, tokenize = False\n                    )\n\n                    input_tokens.append(get_num_tokens(prompt_text, tokenizer))\n\n                    # Generate multiple responses\n                    outputs = model.fast_generate(\n                        [prompt_text],\n                        sampling_params = sampling_params,\n                        lora_request = 
lora_request,\n                        use_tqdm = False,\n                    )[0].outputs\n\n                    # Process all generated responses\n                    responses = [output.text for output in outputs]\n                    extracted_answers = [\n                        extract_aime_answer(response) for response in responses\n                    ]\n\n                    # Calculate total output tokens\n                    total_output_tokens = sum(\n                        get_num_tokens(response, tokenizer) for response in responses\n                    )\n                    output_tokens.append(total_output_tokens)\n\n                    # Check if any answer is correct\n                    ground_truth = item[\"answer\"]\n                    correct_responses = [\n                        ans == ground_truth for ans in extracted_answers\n                    ]\n                    is_correct = any(correct_responses)\n\n                    if is_correct:\n                        correct_answers += 1\n                        source_stats[item[\"source_dataset\"]][\"correct\"] += 1\n\n                    # Store detailed record\n                    records[task_id] = {\n                        \"global_id\": item[\"global_id\"],\n                        \"original_id\": item[\"original_id\"],\n                        \"source_dataset\": item[\"source_dataset\"],\n                        \"problem\": item[\"problem\"],\n                        \"ground_truth\": ground_truth,\n                        \"responses\": responses,\n                        \"extracted_answers\": extracted_answers,\n                        \"correct_responses\": correct_responses,\n                        \"is_correct\": is_correct,\n                        \"input_tokens\": input_tokens[-1],\n                        \"output_tokens\": total_output_tokens,\n                        \"n_correct\": sum(correct_responses),\n                        \"n_total\": len(responses),\n                        \"solution\": item.get(\"solution\", \"\"),\n                        \"url\": item.get(\"url\", \"\"),\n                    }\n\n                    # Update progress\n                    current_accuracy = correct_answers / (task_id + 1) * 100\n                    pbar.set_postfix(\n                        {\n                            \"accuracy\": f\"{current_accuracy:.1f}%\",\n                            \"correct\": correct_answers,\n                            \"total\": task_id + 1,\n                        }\n                    )\n                    pbar.update(1)\n\n                except Exception as e:\n                    print(f\"\\nError processing problem {task_id}: {str(e)}\")\n                    records[task_id] = {\n                        \"global_id\": item.get(\"global_id\", task_id),\n                        \"original_id\": item.get(\"original_id\", task_id),\n                        \"source_dataset\": item.get(\"source_dataset\", \"unknown\"),\n                        \"problem\": item[\"problem\"],\n                        \"ground_truth\": item[\"answer\"],\n                        \"error\": str(e),\n                        \"is_correct\": False,\n                    }\n                    pbar.update(1)\n                    continue\n\n    finally:\n        # Restore logging levels\n        for logger_name, level in original_levels.items():\n            logging.getLogger(logger_name).setLevel(level)\n\n    # Calculate metrics\n    total_problems = len(eval_dataset)\n    
accuracy = correct_answers / total_problems * 100\n\n    # Calculate Pass@k (probability that at least one of k samples is correct)\n    pass_at_k_scores = []\n    for record in records.values():\n        if \"n_correct\" in record and \"n_total\" in record:\n            n_correct = record[\"n_correct\"]\n            n_total = record[\"n_total\"]\n            if n_correct > 0:\n                pass_at_k_scores.append(1.0)\n            else:\n                pass_at_k_scores.append(0.0)\n\n    pass_at_k = sum(pass_at_k_scores) / len(pass_at_k_scores) if pass_at_k_scores else 0\n\n    # Calculate per-source accuracies\n    source_accuracies = {}\n    for source, stats in source_stats.items():\n        source_accuracies[source] = (\n            (stats[\"correct\"] / stats[\"total\"] * 100) if stats[\"total\"] > 0 else 0\n        )\n\n    results = {\n        \"model_type\": model_type,\n        \"dataset\": \"aime_combined\",\n        \"total_problems\": total_problems,\n        \"correct_answers\": correct_answers,\n        \"accuracy\": accuracy,\n        \"pass_at_k\": pass_at_k * 100,\n        \"source_stats\": source_stats,\n        \"source_accuracies\": source_accuracies,\n        \"temperature\": temperature,\n        \"n_sampling\": n_sampling,\n        \"max_tokens\": max_tokens,\n        \"top_p\": top_p,\n        \"seed\": seed,\n        \"avg_input_tokens\": sum(input_tokens) / len(input_tokens)\n        if input_tokens\n        else 0,\n        \"avg_output_tokens\": sum(output_tokens) / len(output_tokens)\n        if output_tokens\n        else 0,\n        \"max_input_tokens\": max(input_tokens) if input_tokens else 0,\n        \"max_output_tokens\": max(output_tokens) if output_tokens else 0,\n    }\n\n    # Save results\n    filename = f\"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json\"\n    with open(filename, \"w\", encoding = \"utf-8\") as f:\n        json.dump({\"results\": results, \"records\": records}, f, indent = 4)\n\n    # Print comprehensive summary\n    print(f\"\\n{'='*70}\")\n    print(f\"📊 AIME EVALUATION RESULTS - {model_type.upper()}\")\n    print(f\"{'='*70}\")\n\n    print(f\"\\n🎯 Overall Performance:\")\n    print(f\"   Total problems:       {total_problems:>6}\")\n    print(\n        f\"   Correct answers:      {correct_answers:>6}/{total_problems} ({accuracy:>5.1f}%)\"\n    )\n    print(f\"   Pass@{n_sampling}:              {pass_at_k * 100:>10.1f}%\")\n\n    print(f\"\\n📈 Performance by Dataset:\")\n    for source, stats in source_stats.items():\n        source_acc = source_accuracies[source]\n        print(\n            f\"   {source:>12}: {stats['correct']:>3}/{stats['total']:>3} ({source_acc:>5.1f}%)\"\n        )\n\n    print(f\"\\n🔧 Configuration:\")\n    print(f\"   Temperature:          {temperature}\")\n    print(f\"   Samples per problem:  {n_sampling}\")\n    print(f\"   Max tokens:           {max_tokens}\")\n    print(f\"   Top-p:                {top_p}\")\n    print(f\"   Seed:                 {seed}\")\n\n    print(f\"\\n📝 Token Statistics:\")\n    print(f\"   Avg input tokens:     {results['avg_input_tokens']:>10.1f}\")\n    print(f\"   Avg output tokens:    {results['avg_output_tokens']:>10.1f}\")\n    print(f\"   Max input tokens:     {results['max_input_tokens']:>10}\")\n    print(f\"   Max output tokens:    {results['max_output_tokens']:>10}\")\n\n    # Performance assessment for AIME\n    if accuracy >= 50:\n        tier = \"🏆 EXCEPTIONAL\"\n    elif accuracy >= 30:\n        tier = \"✅ EXCELLENT\"\n    elif accuracy >= 
20:\n        tier = \"🎯 VERY GOOD\"\n    elif accuracy >= 10:\n        tier = \"⚠️  GOOD\"\n    elif accuracy >= 5:\n        tier = \"📈 FAIR\"\n    else:\n        tier = \"❌ NEEDS IMPROVEMENT\"\n\n    print(f\"\\n🎖️  AIME Performance:     {tier} ({accuracy:.1f}%)\")\n    print(f\"\\n💾 Detailed results saved to: {filename}\")\n    print(f\"\\n{'='*70}\")\n\n    return results\n\n\n# Comparison functions for multiple model results\ndef compare_aime_results(all_results):\n    \"\"\"Generate comprehensive comparison for AIME evaluation results\"\"\"\n    print(f\"\\n{'='*80}\")\n    print(\"COMPREHENSIVE AIME MODEL COMPARISON\")\n    print(f\"{'='*80}\")\n\n    # Main comparison table\n    print(\n        f\"{'Model':<15} {'Accuracy %':<12} {'Pass@K %':<10} {'Correct':<8} {'Total':<8}\"\n    )\n    print(\"-\" * 80)\n\n    for result in all_results:\n        print(\n            f\"{result['model_type']:<15} \"\n            f\"{result['accuracy']:<12.1f} \"\n            f\"{result['pass_at_k']:<10.1f} \"\n            f\"{result['correct_answers']:<8} \"\n            f\"{result['total_problems']:<8}\"\n        )\n\n    # Performance improvement analysis\n    if len(all_results) > 1:\n        print(f\"\\n{'='*50}\")\n        print(\"IMPROVEMENT ANALYSIS\")\n        print(f\"{'='*50}\")\n\n        base_result = all_results[0]  # Assume first is base model\n\n        for i, result in enumerate(all_results[1:], 1):\n            print(f\"\\n{result['model_type']} vs {base_result['model_type']}:\")\n\n            accuracy_improvement = result[\"accuracy\"] - base_result[\"accuracy\"]\n            pass_k_improvement = result[\"pass_at_k\"] - base_result[\"pass_at_k\"]\n\n            print(f\"  Accuracy improvement:  {accuracy_improvement:+.1f}%\")\n            print(f\"  Pass@K improvement:    {pass_k_improvement:+.1f}%\")\n\n    # Dataset breakdown\n    print(f\"\\n{'='*50}\")\n    print(\"PERFORMANCE BY DATASET\")\n    print(f\"{'='*50}\")\n\n    # Get all unique datasets from the first result\n    if all_results and \"source_accuracies\" in all_results[0]:\n        datasets = list(all_results[0][\"source_accuracies\"].keys())\n\n        print(f\"{'Model':<15}\", end = \"\")\n        for dataset in datasets:\n            print(f\"{dataset:<15}\", end = \"\")\n        print()\n        print(\"-\" * (15 + 15 * len(datasets)))\n\n        for result in all_results:\n            print(f\"{result['model_type']:<15}\", end = \"\")\n            for dataset in datasets:\n                accuracy = result[\"source_accuracies\"].get(dataset, 0)\n                print(f\"{accuracy:<15.1f}\", end = \"\")\n            print()\n\n    # Save comparison\n    comparison_data = {\n        \"summary\": all_results,\n        \"best_model\": max(all_results, key = lambda x: x[\"accuracy\"]),\n    }\n\n    with open(\"aime_model_comparison.json\", \"w\") as f:\n        json.dump(comparison_data, f, indent = 4)\n\n    print(\n        f\"\\nBest performing model: {comparison_data['best_model']['model_type']} \"\n        f\"({comparison_data['best_model']['accuracy']:.1f}% accuracy)\"\n    )\n"
  },
  {
    "path": "tests/utils/cleanup_utils.py",
    "content": "import gc\nimport logging\nimport os\nimport shutil\nimport torch\nimport sys\nimport warnings\n\n\ndef clear_memory(variables_to_clear = None, verbose = False, clear_all_caches = True):\n    \"\"\"\n    Comprehensive memory clearing for persistent memory leaks.\n\n    Args:\n        variables_to_clear: List of variable names to clear\n        verbose: Print memory status\n        clear_all_caches: Clear all types of caches (recommended for memory leaks)\n    \"\"\"\n\n    # Save current logging levels\n    saved_log_levels = {}\n    for name, logger in logging.Logger.manager.loggerDict.items():\n        if isinstance(logger, logging.Logger):\n            saved_log_levels[name] = logger.level\n    root_level = logging.getLogger().level\n\n    if variables_to_clear is None:\n        variables_to_clear = [\n            \"inputs\",\n            \"model\",\n            \"base_model\",\n            \"processor\",\n            \"tokenizer\",\n            \"base_processor\",\n            \"base_tokenizer\",\n            \"trainer\",\n            \"peft_model\",\n            \"bnb_config\",\n        ]\n\n    # 1. Clear LRU caches FIRST (very important for memory leaks)\n    if clear_all_caches:\n        clear_all_lru_caches(verbose)\n\n    # 2. Delete specified variables\n    g = globals()\n    deleted_vars = []\n    for var in variables_to_clear:\n        if var in g:\n            del g[var]\n            deleted_vars.append(var)\n\n    if verbose and deleted_vars:\n        print(f\"Deleted variables: {deleted_vars}\")\n\n    # 3. Multiple garbage collection passes (important for circular references)\n    for i in range(3):\n        collected = gc.collect()\n        if verbose and collected > 0:\n            print(f\"GC pass {i+1}: collected {collected} objects\")\n\n    # 4. 
CUDA cleanup\n    if torch.cuda.is_available():\n        # Get memory before cleanup\n        if verbose:\n            mem_before = torch.cuda.memory_allocated() / 1024**3\n\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n\n        # Additional CUDA cleanup for persistent leaks\n        if clear_all_caches:\n            # Reset memory stats\n            torch.cuda.reset_peak_memory_stats()\n            torch.cuda.reset_accumulated_memory_stats()\n\n            # Clear JIT cache\n            if hasattr(torch.jit, \"_state\") and hasattr(\n                torch.jit._state, \"_clear_class_state\"\n            ):\n                torch.jit._state._clear_class_state()\n\n            # Force another CUDA cache clear\n            torch.cuda.empty_cache()\n\n        # Final garbage collection\n        gc.collect()\n\n        if verbose:\n            mem_after = torch.cuda.memory_allocated() / 1024**3\n            mem_reserved = torch.cuda.memory_reserved() / 1024**3\n            print(\n                f\"GPU memory - Before: {mem_before:.2f} GB, After: {mem_after:.2f} GB\"\n            )\n            print(f\"GPU reserved memory: {mem_reserved:.2f} GB\")\n            if mem_before > 0:\n                print(f\"Memory freed: {mem_before - mem_after:.2f} GB\")\n\n    # restore original logging levels\n    logging.getLogger().setLevel(root_level)\n    for name, level in saved_log_levels.items():\n        if name in logging.Logger.manager.loggerDict:\n            logger = logging.getLogger(name)\n            logger.setLevel(level)\n\n\ndef clear_all_lru_caches(verbose = True):\n    \"\"\"Clear all LRU caches in loaded modules.\"\"\"\n    cleared_caches = []\n\n    # Modules to skip to avoid warnings\n    skip_modules = {\n        \"torch.distributed\",\n        \"torchaudio\",\n        \"torch._C\",\n        \"torch.distributed.reduce_op\",\n        \"torchaudio.backend\",\n    }\n\n    # Create a static list of modules to avoid RuntimeError\n    modules = list(sys.modules.items())\n\n    # Method 1: Clear caches in all loaded modules\n    for module_name, module in modules:\n        if module is None:\n            continue\n\n        # Skip problematic modules\n        if any(module_name.startswith(skip) for skip in skip_modules):\n            continue\n\n        try:\n            # Look for functions with lru_cache\n            for attr_name in dir(module):\n                try:\n                    # Suppress warnings when checking attributes\n                    with warnings.catch_warnings():\n                        warnings.simplefilter(\"ignore\", FutureWarning)\n                        warnings.simplefilter(\"ignore\", UserWarning)\n                        warnings.simplefilter(\"ignore\", DeprecationWarning)\n\n                    attr = getattr(module, attr_name)\n                    if hasattr(attr, \"cache_clear\"):\n                        attr.cache_clear()\n                        cleared_caches.append(f\"{module_name}.{attr_name}\")\n                except Exception:\n                    continue  # Skip problematic attributes\n        except Exception:\n            continue  # Skip problematic modules\n\n    # Method 2: Clear specific known caches\n    known_caches = [\n        \"transformers.utils.hub.cached_file\",\n        \"transformers.tokenization_utils_base.get_tokenizer\",\n        \"torch._dynamo.utils.counters\",\n    ]\n\n    for cache_path in known_caches:\n        try:\n            parts = cache_path.split(\".\")\n            module = 
sys.modules.get(parts[0])\n            if module:\n                obj = module\n                for part in parts[1:]:\n                    obj = getattr(obj, part, None)\n                    if obj is None:\n                        break\n                if obj and hasattr(obj, \"cache_clear\"):\n                    obj.cache_clear()\n                    cleared_caches.append(cache_path)\n        except Exception:\n            continue  # Skip problematic caches\n\n    if verbose and cleared_caches:\n        print(f\"Cleared {len(cleared_caches)} LRU caches\")\n\n\ndef clear_specific_lru_cache(func):\n    \"\"\"Clear cache for a specific function.\"\"\"\n    if hasattr(func, \"cache_clear\"):\n        func.cache_clear()\n        return True\n    return False\n\n\n# Additional utility for monitoring cache sizes\ndef monitor_cache_sizes():\n    \"\"\"Monitor LRU cache sizes across modules.\"\"\"\n    cache_info = []\n\n    for module_name, module in sys.modules.items():\n        if module is None:\n            continue\n        try:\n            for attr_name in dir(module):\n                try:\n                    attr = getattr(module, attr_name)\n                    if hasattr(attr, \"cache_info\"):\n                        info = attr.cache_info()\n                        cache_info.append(\n                            {\n                                \"function\": f\"{module_name}.{attr_name}\",\n                                \"size\": info.currsize,\n                                \"hits\": info.hits,\n                                \"misses\": info.misses,\n                            }\n                        )\n                except:\n                    pass\n        except:\n            pass\n\n    return sorted(cache_info, key = lambda x: x[\"size\"], reverse = True)\n\n\ndef safe_remove_directory(path):\n    try:\n        if os.path.exists(path) and os.path.isdir(path):\n            shutil.rmtree(path)\n            return True\n        else:\n            print(f\"Path {path} is not a valid directory\")\n            return False\n    except Exception as e:\n        print(f\"Failed to remove directory {path}: {e}\")\n        return False\n"
  },
  {
    "path": "tests/utils/data_utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom datasets import Dataset\n\nQUESTION = \"What day was I born?\"\nANSWER = \"January 1, 2058\"\nUSER_MESSAGE = {\"role\": \"user\", \"content\": QUESTION}\nASSISTANT_MESSAGE = {\"role\": \"assistant\", \"content\": ANSWER}\nDTYPE = torch.bfloat16\nDEFAULT_MESSAGES = [[USER_MESSAGE, ASSISTANT_MESSAGE]]\n\n\ndef create_instruction_dataset(messages: list[dict] = DEFAULT_MESSAGES):\n    dataset = Dataset.from_dict({\"messages\": messages})\n    return dataset\n\n\ndef create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = None):\n    dataset = create_instruction_dataset(messages)\n\n    def _apply_chat_template(example):\n        chat = tokenizer.apply_chat_template(example[\"messages\"], tokenize = False)\n        return {\"text\": chat}\n\n    dataset = dataset.map(_apply_chat_template, remove_columns = \"messages\")\n    if num_examples is not None:\n        if len(dataset) < num_examples:\n            num_repeats = num_examples // len(dataset) + 1\n            dataset = dataset.repeat(num_repeats)\n        dataset = dataset.select(range(num_examples))\n\n    return dataset\n\n\ndef describe_param(\n    param: torch.Tensor,\n    include_l1: bool = False,\n    include_l2: bool = False,\n    include_infinity: bool = False,\n    as_str: bool = True,\n) -> dict:\n    \"\"\"\n    Provide a statistical summary of a 2D weight matrix or tensor.\n    If as_str is True, the summary is returned as a formatted string.\n    Parameters:\n        param: torch.Tensor\n        include_l1 (bool): Whether to include the L1 norm (sum of absolute values).\n        include_l2 (bool): Whether to include the L2 norm (Frobenius norm).\n        include_infinity (bool): Whether to include the infinity norm (max absolute value).\n        as_str (bool): Whether to return the summary as a formatted string.\n\n    Returns:\n        dict: A dictionary with the following statistics:\n              - shape: Dimensions of the matrix.\n              - mean: Average value.\n              - median: Median value.\n              - std: Standard deviation.\n              - min: Minimum value.\n              - max: Maximum value.\n              - percentile_25: 25th percentile.\n              - percentile_75: 75th percentile.\n              Additionally, if enabled:\n              - L1_norm: Sum of absolute values.\n              - L2_norm: Euclidean (Frobenius) norm.\n              - infinity_norm: Maximum absolute value.\n    \"\"\"\n\n    param = param.float()\n    summary = {\n        \"shape\": param.shape,\n        \"mean\": param.mean().cpu().item(),\n        \"std\": param.std().cpu().item(),\n        \"min\": param.min().cpu().item(),\n        \"max\": param.max().cpu().item(),\n        \"percentile_25\": param.quantile(0.25).cpu().item(),\n        \"percentile_50\": param.quantile(0.5).cpu().item(),\n        \"percentile_75\": 
param.quantile(0.75).cpu().item(),\n    }\n\n    if include_l1:\n        summary[\"L1_norm\"] = param.abs().sum().cpu().item()\n    if include_l2:\n        summary[\"L2_norm\"] = param.norm().cpu().item()\n    if include_infinity:\n        summary[\"infinity_norm\"] = param.abs().max().cpu().item()\n\n    return format_summary(summary) if as_str else summary\n\n\ndef format_summary(stats: dict, precision: int = 6) -> str:\n    \"\"\"\n    Format the statistical summary dictionary for printing.\n\n    Parameters:\n        stats (dict): The dictionary returned by describe_param.\n        precision (int): Number of decimal places for floating point numbers.\n\n    Returns:\n        str: A formatted string representing the summary.\n    \"\"\"\n    lines = []\n    for key, value in stats.items():\n        if isinstance(value, float):\n            formatted_value = f\"{value:.{precision}f}\"\n        elif isinstance(value, (tuple, list)):\n            # Format each element in tuples or lists (e.g., the shape)\n            formatted_value = \", \".join(str(v) for v in value)\n            formatted_value = (\n                f\"({formatted_value})\"\n                if isinstance(value, tuple)\n                else f\"[{formatted_value}]\"\n            )\n        else:\n            formatted_value = str(value)\n        lines.append(f\"{key}: {formatted_value}\")\n    return \"\\n\".join(lines)\n\n\ndef get_peft_weights(model):\n    # ruff: noqa\n    is_lora_weight = lambda name: any(s in name for s in [\"lora_A\", \"lora_B\"])\n    return {\n        name: param for name, param in model.named_parameters() if is_lora_weight(name)\n    }\n\n\ndef describe_peft_weights(model):\n    for name, param in get_peft_weights(model).items():\n        yield name, describe_param(param, as_str = True)\n\n\ndef check_responses(responses: list[str], answer: str, prompt: str = None) -> bool:\n    for i, response in enumerate(responses, start = 1):\n        if answer in response:\n            print(f\"\\u2713 response {i} contains answer\")\n        else:\n            print(f\"\\u2717 response {i} does not contain answer\")\n            if prompt is not None:\n                response = response.replace(prompt, \"\")\n            print(f\" -> response: {response}\")\n"
  },
  {
    "path": "tests/utils/hf_utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom contextlib import contextmanager, nullcontext\nfrom typing import Callable, Optional\n\nimport bitsandbytes as bnb\nimport torch\nfrom bitsandbytes.functional import dequantize_4bit\nfrom peft import get_peft_model, prepare_model_for_kbit_training\nfrom peft.tuners.lora import LoraConfig, LoraLayer\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    BitsAndBytesConfig,\n)\nfrom transformers.trainer_callback import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\nfrom trl import SFTTrainer\n\n\nclass PeftWeightCallback(TrainerCallback):\n    def on_log(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        logs,\n        **kwargs,\n    ):\n        print(f\"DEBUG::CALLBACK::on_log::{state.log_history}\")\n\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        model = kwargs.get(\"model\")\n        assert model is not None\n        print(f\"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}\")\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        print(f\"DEBUG::CALLBACK::on_step_end::{state.global_step}\")\n\n\n@torch.inference_mode()\ndef generate_responses(\n    model,\n    tokenizer,\n    prompt,\n    max_new_tokens: int = 100,\n    temperature: float = 0.8,\n    do_sample: bool = True,\n    num_generations: int = 1,\n    skip_special_tokens: bool = True,\n    dtype: torch.dtype = None,\n):\n    inputs = [tokenizer(prompt, return_tensors = \"pt\") for _ in range(num_generations)]\n    keys = inputs[0].keys()\n    batched_inputs = {\n        key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)\n        for key in keys\n    }\n\n    if dtype is not None:\n        inference_context = torch.autocast(device_type = \"cuda\", dtype = dtype)\n    else:\n        inference_context = nullcontext()\n\n    with inference_context:\n        outputs = model.generate(\n            **batched_inputs,\n            max_new_tokens = max_new_tokens,\n            do_sample = do_sample,\n            temperature = temperature,\n        )\n\n    responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)\n    return responses\n\n\ndef sample_responses(\n    model,\n    tokenizer,\n    prompt,\n    temperature: float = 0.8,\n    num_generations: int = 1,\n    max_new_tokens: int = 100,\n    skip_special_tokens: bool = True,\n    dtype: torch.dtype = None,\n):\n    responses = generate_responses(\n        model,\n        tokenizer,\n        prompt,\n        temperature = temperature,\n        num_generations = num_generations,\n       
 max_new_tokens = max_new_tokens,\n        skip_special_tokens = skip_special_tokens,\n        dtype = dtype,\n    )\n    return responses\n\n\ndef setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    for fixup_func in fixup_funcs:\n        tokenizer = fixup_func(tokenizer)\n    return tokenizer\n\n\ndef setup_model(\n    model_name,\n    quantize: bool = True,\n    dtype = torch.bfloat16,\n    peft_config = None,\n    autocast_adapter: bool = True,\n):\n    if quantize:\n        bnb_config = BitsAndBytesConfig(\n            load_in_4bit = True,\n            bnb_4bit_use_double_quant = True,\n            bnb_4bit_quant_type = \"nf4\",\n            bnb_4bit_compute_dtype = dtype,\n        )\n    else:\n        bnb_config = None\n\n    model = AutoModelForCausalLM.from_pretrained(\n        model_name,\n        device_map = \"cuda:0\",\n        attn_implementation = \"sdpa\",\n        quantization_config = bnb_config,\n        torch_dtype = dtype,\n    )\n    model = prepare_model_for_kbit_training(model) if quantize else model\n\n    if peft_config is not None:\n        model = get_peft_model(\n            model, peft_config, autocast_adapter_dtype = autocast_adapter\n        )\n\n    return model\n\n\ndef get_peft_config(\n    lora_rank,\n    lora_alpha = None,\n    lora_dropout = 0.0,\n    bias = \"none\",\n    target_modules = \"all-linear\",\n):\n    lora_alpha = lora_alpha or 2 * lora_rank\n    peft_config = LoraConfig(\n        lora_alpha = lora_alpha,\n        lora_dropout = lora_dropout,\n        r = lora_rank,\n        bias = bias,\n        target_modules = target_modules,\n        task_type = \"CAUSAL_LM\",\n    )\n    return peft_config\n\n\ndef setup_trainer(\n    model,\n    tokenizer,\n    dataset,\n    train_args,\n    peft_config = None,\n    formatting_func = None,\n    collator = None,\n):\n    return SFTTrainer(\n        model = model,\n        peft_config = peft_config,\n        train_dataset = dataset,\n        processing_class = tokenizer,\n        formatting_func = formatting_func,\n        data_collator = collator,\n        args = train_args,\n    )\n\n\ndef setup_lora(\n    model,\n    tokenizer,\n    dataset,\n    peft_config,\n    train_args,\n    formatting_func = None,\n    collator = None,\n):\n    # Same as setup_trainer, but a LoRA peft_config is required\n    return SFTTrainer(\n        model = model,\n        peft_config = peft_config,\n        train_dataset = dataset,\n        processing_class = tokenizer,\n        formatting_func = formatting_func,\n        data_collator = collator,\n        args = train_args,\n    )\n\n\ndef convert_weights_back_to_dtype(model, dtype):\n    \"\"\"\n    SFTTrainer calls get_peft_model and prepare_model_for_kbit_training, which convert all weights to float32.\n    This function converts the non-LoRA weights back to the original dtype.\n    \"\"\"\n    for name, param in model.named_parameters():\n        if any(s in name for s in [\"norm\", \"embed\"]):\n            param.data = param.data.to(dtype)\n\n\ndef fix_llama3_tokenizer(tokenizer, padding_side = \"right\"):\n    tokenizer.padding_side = padding_side\n    added_vocab = tokenizer.get_added_vocab()\n    pad_token = [w for w in added_vocab if \"pad\" in w]\n    assert len(pad_token) == 1\n    tokenizer.pad_token = pad_token[0]\n    return tokenizer\n\n\n
def replace_module(\n    module: torch.nn.Module,\n    target_module_type: torch.nn.Module,\n    conversion_func: Callable,\n):\n    for child_name, child_module in module.named_children():\n        if isinstance(child_module, target_module_type):\n            new_module = conversion_func(child_module)\n            setattr(module, child_name, new_module)\n        else:\n            replace_module(child_module, target_module_type, conversion_func)\n\n\ndef _convert_lora_to_linear(module: LoraLayer, adapter_name: str = \"default\"):\n    base_layer = module.get_base_layer()\n    weight = base_layer.weight\n\n    assert isinstance(weight, bnb.nn.Params4bit)\n    quant_state = weight.quant_state\n    original_dtype = quant_state.dtype\n\n    w_dq = dequantize_4bit(weight.data, quant_state).float()\n    lora_delta = (\n        module.lora_B[adapter_name].weight\n        @ module.lora_A[adapter_name].weight\n        * module.scaling[adapter_name]\n    )\n    w_dq += lora_delta.float()\n    w_dq = w_dq.to(original_dtype)\n\n    new_module = torch.nn.Linear(\n        w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None\n    )\n    new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)\n    if module.lora_bias[adapter_name]:\n        bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias\n        new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)\n    return new_module\n\n\ndef convert_lora_to_linear(model: torch.nn.Module):\n    replace_module(model, LoraLayer, _convert_lora_to_linear)\n    assert not any(isinstance(module, LoraLayer) for module in model.modules())\n    return model\n"
  },
  {
    "path": "tests/utils/ocr_eval.md",
    "content": "\n# OCR Model Evaluator\nA comprehensive Python module for evaluating Optical Character Recognition (OCR) models using Word Error Rate (WER) and Character Error Rate (CER) metrics. This evaluator supports vision-language models and provides detailed analysis with comparison capabilities across multiple models\n\n## Basic Usage\n\n```python\nfrom ocr_evaluator import evaluate_ocr_model\n\n# Simple evaluation\navg_wer, avg_cer = evaluate_ocr_model(\n    model=your_model,\n    processor=your_processor,\n    dataset=your_dataset,\n    output_dir=\"evaluation_results\"\n)\n\nprint(f\"Average WER: {avg_wer:.4f}\")\nprint(f\"Average CER: {avg_cer:.4f}\")\n```\n\n\n### Dataset Format\n\nThe evaluator expects datasets in a chatml conversational format with the following structure:\n```\ndataset = [\n    {\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": [{\"type\": \"text\", \"text\": \"You are an OCR system.\"}]\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"text\": \"Extract text from this image\"},\n                    {\"type\": \"image\", \"image\": PIL_Image_object}\n                ]\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"text\", \"text\": \"Ground truth text\"}]\n            }\n        ]\n    },\n    # ... more samples\n]\n```\n\n\n## Examples\n\n### Document OCR evaluation\n\n```python\nfrom ocr_evaluator import OCRModelEvaluator\nfrom datasets import load_dataset\n\n# Load document OCR dataset\ndataset = load_dataset(\"your-ocr-dataset\", split=\"test\")\n\n# Convert to required format\neval_data = [format_document_sample(sample) for sample in dataset]\n\n# Evaluate models\nevaluator = OCRModelEvaluator()\n\n# Compare different model configurations\nconfigs = {\n    \"Standard Model\": {\"temperature\": 1.0, \"max_new_tokens\": 512},\n    \"Conservative Model\": {\"temperature\": 0.7, \"max_new_tokens\": 256},\n    \"Creative Model\": {\"temperature\": 1.5, \"max_new_tokens\": 1024}\n}\n\nfor config_name, params in configs.items():\n    wer, cer = evaluator.evaluate_model(\n        model=base_model,\n        processor=processor,\n        dataset=eval_data,\n        output_dir=f\"document_ocr_{config_name.lower().replace(' ', '_')}\",\n        **params\n    )\n    evaluator.add_to_comparison(config_name, wer, cer)\n\n# Generate final report\nevaluator.print_model_comparison()\n```\n\n### Handwriting Recognition\n```python\n# Specialized evaluation for handwriting\ndef evaluate_handwriting_models(models, handwriting_dataset):\n    evaluator = OCRModelEvaluator()\n\n    for model_name, (model, processor) in models.items():\n        # Adjust parameters for handwriting recognition\n        wer, cer = evaluator.evaluate_model(\n            model=model,\n            processor=processor,\n            dataset=handwriting_dataset,\n            temperature=1.2,  # Slightly higher for handwriting variety\n            max_new_tokens=128,  # Usually shorter text\n            output_dir=f\"handwriting_{model_name}\"\n        )\n        evaluator.add_to_comparison(f\"Handwriting - {model_name}\", wer, cer)\n\n    return evaluator.print_model_comparison()\n```\n"
  },
  {
    "path": "tests/utils/ocr_eval.py",
    "content": "\"\"\"\nOCR Model Evaluation Module\n\nThis module provides functionality to evaluate OCR models on datasets with\nword error rate (WER) and character error rate (CER) metrics.\n\"\"\"\n\nimport os\nimport torch\nfrom tqdm import tqdm\nimport pandas as pd\nfrom jiwer import wer, cer\nfrom qwen_vl_utils import process_vision_info\nimport matplotlib.pyplot as plt\nfrom typing import List, Dict, Tuple, Optional, Any\nimport traceback\n\n\nclass OCRModelEvaluator:\n    \"\"\"\n    A comprehensive OCR model evaluator that supports multiple models and provides\n    detailed analysis with WER and CER metrics.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize the OCR evaluator.\"\"\"\n        self.model_comparison_results = {}\n\n    def evaluate_model(\n        self,\n        model: Any,\n        processor: Any,\n        dataset: List[Dict],\n        output_dir: str = \"ocr_evaluation_results\",\n        max_new_tokens: int = 1024,\n        temperature: float = 1.5,\n        min_p: float = 0.1,\n        verbose: bool = True,\n    ) -> Tuple[Optional[float], Optional[float]]:\n        \"\"\"\n        Evaluate a model on an OCR dataset.\n        \"\"\"\n        # Create output directory if it doesn't exist\n        os.makedirs(output_dir, exist_ok = True)\n\n        # Initialize results storage\n        results = []\n\n        # Process each sample in the dataset\n        for i, sample in enumerate(\n            tqdm(dataset, desc = \"Evaluating OCR performance\", disable = not verbose)\n        ):\n            try:\n                # Extract components from sample\n                messages = sample[\"messages\"]\n\n                # Get ground truth, image, and question\n                ground_truth, image, question, input_messages = (\n                    self._extract_sample_components(messages, i, verbose)\n                )\n\n                if ground_truth is None or image is None or question is None:\n                    continue\n\n                # Generate model response\n                generated_response = self._generate_response(\n                    model, processor, input_messages, max_new_tokens, temperature, min_p\n                )\n\n                # Calculate metrics\n                word_error = wer(ground_truth, generated_response)\n                char_error = cer(ground_truth, generated_response)\n\n                # Save individual result\n                self._save_individual_result(\n                    output_dir,\n                    i,\n                    question,\n                    generated_response,\n                    ground_truth,\n                    word_error,\n                    char_error,\n                )\n\n                # Store results for summary\n                results.append(\n                    {\n                        \"sample_id\": i,\n                        \"wer\": word_error,\n                        \"cer\": char_error,\n                        \"model_output\": generated_response.strip(),\n                        \"ground_truth\": ground_truth,\n                        \"question\": question,\n                    }\n                )\n\n            except Exception as e:\n                if verbose:\n                    print(f\"Error processing sample {i}: {str(e)}\")\n                    traceback.print_exc()\n\n        # Generate summary report\n        return self._generate_summary_report(results, output_dir, verbose)\n\n    def _extract_sample_components(\n        self, messages: List[Dict], 
sample_idx: int, verbose: bool\n    ) -> Tuple[Optional[str], Optional[Any], Optional[str], List[Dict]]:\n        \"\"\"Extract ground truth, image, question, and input messages from sample.\"\"\"\n\n        # Extract system message (if present)\n        system_message = next(\n            (msg for msg in messages if msg[\"role\"] == \"system\"), None\n        )\n\n        # Extract user message with the image and question\n        user_message = next((msg for msg in messages if msg[\"role\"] == \"user\"), None)\n        if not user_message:\n            if verbose:\n                print(f\"Skipping sample {sample_idx}: No user message found\")\n            return None, None, None, []\n\n        # Extract assistant message with ground truth\n        assistant_message = next(\n            (msg for msg in messages if msg[\"role\"] == \"assistant\"), None\n        )\n        if not assistant_message:\n            if verbose:\n                print(\n                    f\"Skipping sample {sample_idx}: No assistant message (ground truth) found\"\n                )\n            return None, None, None, []\n\n        # Extract ground truth text\n        ground_truth = None\n        for content_item in assistant_message[\"content\"]:\n            if content_item[\"type\"] == \"text\":\n                ground_truth = content_item[\"text\"]\n                break\n\n        if not ground_truth:\n            if verbose:\n                print(\n                    f\"Skipping sample {sample_idx}: No text found in assistant message\"\n                )\n            return None, None, None, []\n\n        # Extract image and question from user message\n        image = None\n        question = None\n\n        for content_item in user_message[\"content\"]:\n            if content_item[\"type\"] == \"image\":\n                image = content_item[\"image\"]\n            elif content_item[\"type\"] == \"text\":\n                question = content_item[\"text\"]\n\n        if not image:\n            if verbose:\n                print(f\"Skipping sample {sample_idx}: No image found in user message\")\n            return None, None, None, []\n\n        if not question:\n            if verbose:\n                print(\n                    f\"Skipping sample {sample_idx}: No question found in user message\"\n                )\n            return None, None, None, []\n\n        # Construct messages for the model input (excluding assistant message)\n        input_messages = []\n        if system_message:\n            input_messages.append(system_message)\n        input_messages.append(user_message)\n\n        return ground_truth, image, question, input_messages\n\n    def _generate_response(\n        self,\n        model: Any,\n        processor: Any,\n        input_messages: List[Dict],\n        max_new_tokens: int,\n        temperature: float,\n        min_p: float,\n    ) -> str:\n        \"\"\"Generate response from the model.\"\"\"\n\n        # Preparation for inference using Qwen's specific processing\n        text = processor.apply_chat_template(\n            input_messages, tokenize = False, add_generation_prompt = True\n        )\n\n        # Process vision info (images/videos) from messages\n        image_inputs, video_inputs = process_vision_info(input_messages)\n\n        # Create model inputs\n        inputs = processor(\n            text = [text],\n            images = image_inputs,\n            videos = video_inputs,\n            padding = True,\n            return_tensors = \"pt\",\n        )\n   
     inputs = inputs.to(model.device)\n\n        # Generate response\n        with torch.no_grad():\n            generated_ids = model.generate(\n                **inputs,\n                max_new_tokens = max_new_tokens,\n                temperature = temperature,\n                min_p = min_p,\n                use_cache = True,\n            )\n\n        # Extract only the generated part (not the input)\n        generated_ids_trimmed = [\n            out_ids[len(in_ids) :]\n            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n        ]\n\n        # Decode the generated text\n        generated_response = processor.batch_decode(\n            generated_ids_trimmed,\n            skip_special_tokens = True,\n            clean_up_tokenization_spaces = False,\n        )[0]\n\n        return generated_response\n\n    def _save_individual_result(\n        self,\n        output_dir: str,\n        sample_idx: int,\n        question: str,\n        generated_response: str,\n        ground_truth: str,\n        word_error: float,\n        char_error: float,\n    ):\n        \"\"\"Save individual sample result to file.\"\"\"\n        output_file = os.path.join(output_dir, f\"sample_{sample_idx}.txt\")\n        with open(output_file, \"w\", encoding = \"utf-8\") as f:\n            f.write(f\"Sample {sample_idx}\\n\")\n            f.write(f\"Question: {question}\\n\\n\")\n            f.write(f\"Model output:\\n{generated_response.strip()}\\n\\n\")\n            f.write(f\"Ground truth:\\n{ground_truth}\\n\\n\")\n            f.write(f\"WER: {word_error:.4f}, CER: {char_error:.4f}\")\n\n    def _generate_summary_report(\n        self, results: List[Dict], output_dir: str, verbose: bool\n    ) -> Tuple[Optional[float], Optional[float]]:\n        \"\"\"Generate and save summary report.\"\"\"\n        if not results:\n            if verbose:\n                print(\"No results to summarize.\")\n            return None, None\n\n        df = pd.DataFrame(results)\n\n        # Calculate overall averages\n        avg_wer = df[\"wer\"].mean()\n        avg_cer = df[\"cer\"].mean()\n\n        # Save average metrics\n        with open(os.path.join(output_dir, \"avg_metrics.txt\"), \"w\") as f:\n            f.write(f\"Average WER: {avg_wer:.4f}\\n\")\n            f.write(f\"Average CER: {avg_cer:.4f}\\n\")\n\n        # Save detailed results\n        df.to_csv(os.path.join(output_dir, \"detailed_results.csv\"), index = False)\n\n        if verbose:\n            print(\"\\nResults Summary:\")\n            print(f\"Average WER: {avg_wer:.4f}\")\n            print(f\"Average CER: {avg_cer:.4f}\")\n            print(f\"\\nDetailed results saved to {output_dir}/\")\n\n        return avg_wer, avg_cer\n\n    def add_to_comparison(self, model_name: str, wer: float, cer: float):\n        \"\"\"Add model results to the comparison tracker.\"\"\"\n        self.model_comparison_results[model_name] = {\"wer\": wer, \"cer\": cer}\n\n    def print_model_comparison(\n        self, save_csv: bool = True, save_plot: bool = True\n    ) -> Optional[pd.DataFrame]:\n        \"\"\"Print a comparison of all models evaluated so far.\"\"\"\n        if not self.model_comparison_results:\n            print(\"No model results available for comparison\")\n            return None\n\n        print(\"\\n==== MODEL COMPARISON REPORT ====\")\n\n        # Create a comparison dataframe\n        comparison_df = pd.DataFrame(\n            {\n                \"Model\": list(self.model_comparison_results.keys()),\n                \"WER\": [\n 
                   results[\"wer\"] for results in self.model_comparison_results.values()\n                ],\n                \"CER\": [\n                    results[\"cer\"] for results in self.model_comparison_results.values()\n                ],\n            }\n        )\n\n        # Sort by WER (best performance first)\n        comparison_df = comparison_df.sort_values(\"WER\")\n\n        # Display the comparison table\n        print(\"\\nComparison Table (sorted by WER):\")\n        print(comparison_df.to_string(index = False))\n\n        # Save the comparison table\n        if save_csv:\n            comparison_file = \"model_comparison_results.csv\"\n            comparison_df.to_csv(comparison_file, index = False)\n            print(f\"\\nComparison table saved to {comparison_file}\")\n\n        # Generate a bar chart visualization\n        if save_plot:\n            self._create_comparison_plot(comparison_df)\n\n        return comparison_df\n\n    def _create_comparison_plot(self, comparison_df: pd.DataFrame):\n        \"\"\"Create and save comparison plot.\"\"\"\n        plt.figure(figsize = (12, 6))\n\n        # Plot WER\n        plt.subplot(1, 2, 1)\n        plt.bar(comparison_df[\"Model\"], comparison_df[\"WER\"], color = \"skyblue\")\n        plt.title(\"Word Error Rate Comparison\")\n        plt.ylabel(\"WER (lower is better)\")\n        plt.ylim(bottom = 0)\n        plt.xticks(rotation = 45, ha = \"right\")\n\n        # Plot CER\n        plt.subplot(1, 2, 2)\n        plt.bar(comparison_df[\"Model\"], comparison_df[\"CER\"], color = \"lightgreen\")\n        plt.title(\"Character Error Rate Comparison\")\n        plt.ylabel(\"CER (lower is better)\")\n        plt.ylim(bottom = 0)\n        plt.xticks(rotation = 45, ha = \"right\")\n\n        plt.tight_layout()\n        plt.savefig(\"ocr_model_comparison.png\")\n        plt.show()\n\n        print(f\"\\nVisualization saved to ocr_model_comparison.png\")\n\n    def get_comparison_results(self) -> Dict[str, Dict[str, float]]:\n        \"\"\"Get the current comparison results.\"\"\"\n        return self.model_comparison_results.copy()\n\n    def clear_comparison_results(self):\n        \"\"\"Clear all comparison results.\"\"\"\n        self.model_comparison_results.clear()\n\n\ndef evaluate_ocr_model(\n    model, processor, dataset, output_dir = \"ocr_evaluation_results\", **kwargs\n):\n    \"\"\"\n    Convenience function that maintains backward compatibility with the original function.\n    \"\"\"\n    evaluator = OCRModelEvaluator()\n    return evaluator.evaluate_model(model, processor, dataset, output_dir, **kwargs)\n\n\ndef create_evaluator():\n    \"\"\"Create a new OCR evaluator instance.\"\"\"\n    return OCRModelEvaluator()\n"
  },
  {
    "path": "tests/utils/os_utils.py",
    "content": "import subprocess\nimport sys\nimport os\nimport shutil\nimport importlib\n\n\ndef detect_package_manager():\n    \"\"\"Detect the available package manager\"\"\"\n    package_managers = {\n        \"apt\": \"/usr/bin/apt\",\n        \"yum\": \"/usr/bin/yum\",\n        \"dnf\": \"/usr/bin/dnf\",\n        \"pacman\": \"/usr/bin/pacman\",\n        \"zypper\": \"/usr/bin/zypper\",\n    }\n\n    for pm, path in package_managers.items():\n        if os.path.exists(path):\n            return pm\n    return None\n\n\ndef check_package_installed(package_name, package_manager = None):\n    \"\"\"Check if a package is installed using the system package manager\"\"\"\n\n    if package_manager is None:\n        package_manager = detect_package_manager()\n\n    if package_manager is None:\n        print(\"Warning: Could not detect package manager\")\n        return None\n\n    try:\n        if package_manager == \"apt\":\n            # Check with dpkg\n            result = subprocess.run(\n                [\"dpkg\", \"-l\", package_name], capture_output = True, text = True\n            )\n            return result.returncode == 0\n\n        elif package_manager in [\"yum\", \"dnf\"]:\n            # Check with rpm\n            result = subprocess.run(\n                [\"rpm\", \"-q\", package_name], capture_output = True, text = True\n            )\n            return result.returncode == 0\n\n        elif package_manager == \"pacman\":\n            result = subprocess.run(\n                [\"pacman\", \"-Q\", package_name], capture_output = True, text = True\n            )\n            return result.returncode == 0\n\n        elif package_manager == \"zypper\":\n            result = subprocess.run(\n                [\"zypper\", \"se\", \"-i\", package_name], capture_output = True, text = True\n            )\n            return package_name in result.stdout\n\n    except Exception as e:\n        print(f\"Error checking package: {e}\")\n        return None\n\n\ndef require_package(package_name, executable_name = None):\n    \"\"\"Require a package to be installed, exit if not found\"\"\"\n\n    # First check if executable is in PATH (most reliable)\n    if executable_name:\n        if shutil.which(executable_name):\n            print(f\"✓ {executable_name} is available\")\n            return\n\n    # Then check with package manager\n    pm = detect_package_manager()\n    is_installed = check_package_installed(package_name, pm)\n\n    if is_installed:\n        print(f\"✓ Package {package_name} is installed\")\n        return\n\n    # Package not found - show installation instructions\n    print(f\"❌ Error: {package_name} is not installed\")\n    print(f\"\\nPlease install {package_name} using your system package manager:\")\n\n    install_commands = {\n        \"apt\": f\"sudo apt update && sudo apt install {package_name}\",\n        \"yum\": f\"sudo yum install {package_name}\",\n        \"dnf\": f\"sudo dnf install {package_name}\",\n        \"pacman\": f\"sudo pacman -S {package_name}\",\n        \"zypper\": f\"sudo zypper install {package_name}\",\n    }\n\n    if pm and pm in install_commands:\n        print(f\"  {install_commands[pm]}\")\n    else:\n        for pm_name, cmd in install_commands.items():\n            print(f\"  {pm_name}: {cmd}\")\n\n    print(f\"\\nAlternatively, install with conda:\")\n    print(f\"  conda install -c conda-forge {package_name}\")\n\n    print(f\"\\nPlease install the required package and run the script again.\")\n    sys.exit(1)\n\n\n# Usage\n# 
require_package(\"ffmpeg\", \"ffmpeg\")\n\n\ndef require_python_package(package_name, import_name = None, pip_name = None):\n    \"\"\"Require a Python package to be installed, exit if not found\"\"\"\n    if import_name is None:\n        import_name = package_name\n    if pip_name is None:\n        pip_name = package_name\n\n    if importlib.util.find_spec(import_name) is None:\n        print(f\"❌ Error: Python package '{package_name}' is not installed\")\n        print(f\"\\nPlease install {package_name} using pip:\")\n        print(f\"  pip install {pip_name}\")\n        print(f\"  # or with conda:\")\n        print(f\"  conda install {pip_name}\")\n        print(f\"\\nAfter installation, run this script again.\")\n        sys.exit(1)\n    else:\n        print(f\"✓ Python package '{package_name}' is installed\")\n"
  },
  {
    "path": "tests/utils/perplexity_eval.md",
    "content": "# Language Model Perplexity Evaluator\n\nA Python module for evaluating language models using perplexity metrics with sliding window approach for long sequences. This evaluator provides efficient computation of perplexity scores across datasets with model comparison capabilities.\n\n## Basic Usage\n\n```python\nfrom perplexity_evaluator import ppl_model, add_to_comparison, print_model_comparison\n\n# Simple perplexity evaluation\ndataset = {\"text\": [\"Your text samples here...\", \"Another text sample...\"]}\nperplexity = ppl_model(model, tokenizer, dataset)\n\nprint(f\"Model Perplexity: {perplexity:.4f}\")\n\n# Add to comparison tracker\nadd_to_comparison(\"My Model\", perplexity)\nprint_model_comparison()\n```\n\n"
  },
  {
    "path": "tests/utils/perplexity_eval.py",
    "content": "from tqdm import tqdm\nimport torch\nimport pandas as pd\n\nmodel_comparison_results = {}\n# return the perplexity of the model on the dataset\n# The perplexity is computed on each example, individually, with a sliding window for examples longer than 512 tokens.\n\n\ndef ppl_model(model, tokenizer, dataset):\n    nlls = []\n    max_length = 2048\n    stride = 512\n    for s in tqdm(range(len(dataset[\"text\"]))):\n        encodings = tokenizer(dataset[\"text\"][s], return_tensors = \"pt\")\n        seq_len = encodings.input_ids.size(1)\n        prev_end_loc = 0\n        for begin_loc in range(0, seq_len, stride):\n            end_loc = min(begin_loc + max_length, seq_len)\n            trg_len = end_loc - prev_end_loc\n            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(\"cuda\")\n            target_ids = input_ids.clone()\n            target_ids[:, :-trg_len] = -100\n            # Create attention mask based on pad token id\n            pad_token_id = (\n                tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0\n            )\n            attention_mask = (input_ids != pad_token_id).long()\n            with torch.no_grad():\n                outputs = model(\n                    input_ids, labels = target_ids, attention_mask = attention_mask\n                )\n                neg_log_likelihood = outputs.loss\n            nlls.append(neg_log_likelihood)\n            prev_end_loc = end_loc\n            if end_loc == seq_len:\n                break\n    ppl = torch.exp(torch.stack(nlls).mean())\n    return ppl\n\n\n# --------------------------------------------------------------------\n\n\n## ----------- Reporting helper function ----------- ##\n\n\n# Create a simple function to add results to the comparison\ndef add_to_comparison(model_name, ppl):\n    \"\"\"Add model results to the comparison tracker\"\"\"\n    model_comparison_results[model_name] = {\"ppl\": ppl}\n    # return model_comparison_results\n\n\n# Create a function to print the comparison report whenever needed\ndef print_model_comparison():\n    \"\"\"Print a comparison of all models evaluated so far\"\"\"\n    if not model_comparison_results:\n        print(\"No model results available for comparison\")\n        return\n\n    print(\"\\n==== MODEL COMPARISON REPORT ====\")\n\n    # Create a comparison dataframe\n    comparison_df = pd.DataFrame(\n        {\n            \"Model\": list(model_comparison_results.keys()),\n            # \"Perplexity\": [results[\"ppl\"] for results in model_comparison_results.values()],\n            \"Perplexity\": [\n                # Convert tensors to CPU and then to float if needed\n                results[\"ppl\"].cpu().item()\n                if torch.is_tensor(results[\"ppl\"])\n                else results[\"ppl\"]\n                for results in model_comparison_results.values()\n            ],\n        }\n    )\n\n    # Display the comparison table\n    print(\"\\nComparison Table:\")\n    print(comparison_df.to_string(index = False))\n"
  },
  {
    "path": "tests/utils/test_attention_masks.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"Unit tests for packed-attention mask helpers with sliding-window logic.\"\"\"\n\nimport math\n\nimport torch\n\nfrom unsloth.utils import attention_dispatch\nfrom unsloth.utils import packing as packing_utils\n\n\ndef _make_seq_info(lengths):\n    lengths = torch.tensor(lengths, dtype = torch.int32)\n    cu = torch.cat(\n        [\n            torch.zeros(1, dtype = torch.int32),\n            torch.cumsum(lengths, dim = 0, dtype = torch.int32),\n        ]\n    )\n    max_len = int(lengths.max().item())\n    return lengths, cu, max_len\n\n\ndef test_sdpa_packed_attention_mask_sliding_window():\n    seq_info = _make_seq_info([5, 3])\n    mask = packing_utils.build_sdpa_packed_attention_mask(\n        seq_info,\n        dtype = torch.float32,\n        device = torch.device(\"cpu\"),\n        sliding_window = 3,\n    )\n\n    assert mask.shape == (1, 1, 8, 8)\n\n    block_first = mask[0, 0, :5, :5]\n    upper = torch.triu(torch.ones_like(block_first), diagonal = 1).bool()\n    assert torch.all(block_first[upper] == float(\"-inf\"))\n    assert block_first[3, 0].item() == float(\"-inf\")\n    assert block_first[4, 1].item() == float(\"-inf\")\n    assert block_first[4, 2].item() > -math.inf\n    assert mask[0, 0, 0, 6].item() == float(\"-inf\")\n\n\ndef test_xformers_block_mask_sliding_window(monkeypatch):\n    class _FakeMask:\n        def __init__(self, lengths, window = None):\n            self.lengths = lengths\n            self.window = window\n\n        @classmethod\n        def from_seqlens(cls, lengths):\n            return cls(tuple(lengths))\n\n        def make_local_attention(self, window_size):\n            return _FakeMask(self.lengths, window = window_size)\n\n    monkeypatch.setattr(packing_utils, \"_XFormersBlockMask\", _FakeMask, raising = False)\n\n    seq_info = _make_seq_info([4, 4])\n    mask = packing_utils.build_xformers_block_causal_mask(\n        seq_info,\n        sliding_window = 2,\n    )\n\n    assert isinstance(mask, _FakeMask)\n    assert mask.window == 2\n\n\ndef test_run_attention_sdpa_passes_sliding_window(monkeypatch):\n    seq_info = _make_seq_info([3, 2])\n    sliding_window = 2\n\n    original_builder = attention_dispatch.build_sdpa_packed_attention_mask\n    captured = {}\n\n    def _capture_builder(seq_info_arg, *, dtype, device, sliding_window = None):\n        captured[\"window\"] = sliding_window\n        return original_builder(\n            seq_info_arg,\n            dtype = dtype,\n            device = device,\n            sliding_window = sliding_window,\n        )\n\n    monkeypatch.setattr(\n        attention_dispatch,\n        \"build_sdpa_packed_attention_mask\",\n        _capture_builder,\n    )\n\n    def _fake_sdpa(Q, K, V, **kwargs):\n        captured[\"mask\"] = 
kwargs.get(\"attn_mask\")\n        return torch.zeros_like(Q)\n\n    monkeypatch.setattr(attention_dispatch, \"scaled_dot_product_attention\", _fake_sdpa)\n\n    config = attention_dispatch.AttentionConfig(\n        backend = attention_dispatch.SDPA,\n        n_kv_heads = 1,\n        n_groups = 1,\n    )\n\n    context = attention_dispatch.AttentionContext(\n        bsz = 1,\n        q_len = 5,\n        kv_seq_len = 5,\n        n_heads = 1,\n        head_dim = 1,\n        requires_grad = False,\n        seq_info = seq_info,\n        attention_mask = None,\n        causal_mask = None,\n        sliding_window = sliding_window,\n    )\n\n    Q = torch.zeros(1, 1, 5, 1)\n    K = torch.zeros_like(Q)\n    V = torch.zeros_like(Q)\n\n    attention_dispatch.run_attention(\n        config = config,\n        context = context,\n        Q = Q,\n        K = K,\n        V = V,\n    )\n\n    assert captured[\"window\"] == sliding_window\n    mask = captured[\"mask\"]\n    assert mask is not None and mask.shape == (1, 1, 5, 5)\n    assert mask[0, 0, 4, 1].item() == float(\"-inf\")\n\n\ndef test_run_attention_xformers_passes_sliding_window(monkeypatch):\n    seq_info = _make_seq_info([4])\n    sliding_window = 3\n\n    class _FakeBias:\n        pass\n\n    captured = {}\n\n    def _fake_builder(seq_info_arg, *, sliding_window = None, base_mask = None):\n        captured[\"window\"] = sliding_window\n        captured[\"base\"] = base_mask\n        return _FakeBias()\n\n    def _fake_attention(Q, K, V, attn_bias = None, **_):\n        captured[\"bias\"] = attn_bias\n        return torch.zeros_like(Q)\n\n    monkeypatch.setattr(\n        attention_dispatch, \"build_xformers_block_causal_mask\", _fake_builder\n    )\n    monkeypatch.setattr(\n        attention_dispatch, \"xformers_attention\", _fake_attention, raising = False\n    )\n    monkeypatch.setattr(\n        attention_dispatch, \"XFORMERS_BLOCK_DIAG_CLS\", _FakeBias, raising = False\n    )\n\n    config = attention_dispatch.AttentionConfig(\n        backend = attention_dispatch.XFORMERS,\n        n_kv_heads = 1,\n        n_groups = 1,\n    )\n\n    context = attention_dispatch.AttentionContext(\n        bsz = 1,\n        q_len = 4,\n        kv_seq_len = 4,\n        n_heads = 1,\n        head_dim = 1,\n        requires_grad = False,\n        seq_info = seq_info,\n        attention_mask = None,\n        causal_mask = None,\n        sliding_window = sliding_window,\n    )\n\n    Q = torch.zeros(1, 1, 4, 1)\n    K = torch.zeros_like(Q)\n    V = torch.zeros_like(Q)\n\n    attention_dispatch.run_attention(\n        config = config,\n        context = context,\n        Q = Q,\n        K = K,\n        V = V,\n    )\n\n    assert captured[\"window\"] == sliding_window\n    assert isinstance(captured[\"bias\"], _FakeBias)\n\n\ndef test_run_attention_flash_varlen_receives_window_and_softcap(monkeypatch):\n    seq_info = _make_seq_info([4])\n    sliding_window = 3\n    softcap = 0.5\n    window_tuple = (sliding_window, sliding_window)\n\n    captured = {}\n\n    def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):\n        captured[\"kwargs\"] = kwargs\n        return torch.zeros_like(Q)\n\n    monkeypatch.setattr(\n        attention_dispatch,\n        \"flash_attn_varlen_func\",\n        _fake_flash_varlen,\n    )\n    monkeypatch.setattr(attention_dispatch, \"HAS_FLASH_ATTENTION\", True)\n\n    config = attention_dispatch.AttentionConfig(\n        backend = attention_dispatch.FLASH_VARLEN,\n        n_kv_heads = 1,\n        n_groups = 1,\n        
flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"softmax_scale\": 1.0,\n            \"causal\": True,\n            \"softcap\": softcap,\n            \"window_size\": window_tuple,\n        },\n    )\n\n    context = attention_dispatch.AttentionContext(\n        bsz = 1,\n        q_len = 4,\n        kv_seq_len = 4,\n        n_heads = 1,\n        head_dim = 2,\n        requires_grad = False,\n        seq_info = seq_info,\n        attention_mask = None,\n        causal_mask = None,\n        sliding_window = sliding_window,\n    )\n\n    Q = torch.zeros(1, 1, 4, 2)\n    K = torch.zeros_like(Q)\n    V = torch.zeros_like(Q)\n\n    attention_dispatch.run_attention(\n        config = config,\n        context = context,\n        Q = Q,\n        K = K,\n        V = V,\n    )\n\n    assert captured[\"kwargs\"][\"softcap\"] == softcap\n    assert captured[\"kwargs\"][\"window_size\"] == window_tuple\n"
  },
  {
    "path": "tests/utils/test_packing.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nfrom unsloth import FastLanguageModel\nfrom unsloth.utils import attention_dispatch as attention_dispatch_utils\nfrom unsloth.utils.packing import (\n    configure_padding_free,\n    configure_sample_packing,\n    enable_padding_free_metadata,\n    enable_sample_packing,\n    mask_packed_sequence_boundaries,\n)\n\nfrom contextlib import ExitStack\nfrom types import SimpleNamespace\nfrom unittest.mock import patch\n\nimport pytest\nimport torch\nfrom datasets import Dataset\nfrom trl import SFTConfig, SFTTrainer\nfrom trl.trainer.sft_trainer import DataCollatorForLanguageModeling\n\n\ndef _build_packed_training_setup(tmp_path, device):\n    dtype = None\n    if device.type == \"cuda\":\n        if torch.cuda.is_bf16_supported():\n            dtype = torch.bfloat16\n        else:\n            dtype = torch.float16\n\n    try:\n        model, tokenizer = FastLanguageModel.from_pretrained(\n            model_name = \"hf-internal-testing/tiny-random-LlamaForCausalLM\",\n            max_seq_length = 64,\n            load_in_4bit = False,\n            dtype = dtype,\n        )\n    except OSError as exc:  # pragma: no cover - offline CI\n        pytest.skip(f\"Requires access to tiny llama checkpoint: {exc}\")\n\n    model.to(device)\n\n    dataset = Dataset.from_dict(\n        {\n            \"text\": [\n                \"Hello world!\",\n                \"Short sample.\",\n                \"This is a slightly longer packed example to test batching.\",\n                \"Another response to include in the batch.\",\n            ]\n        }\n    )\n\n    training_args = SFTConfig(\n        per_device_train_batch_size = 1,\n        per_device_eval_batch_size = 1,\n        gradient_accumulation_steps = 1,\n        dataset_text_field = \"text\",\n        max_length = 64,\n        logging_steps = 1,\n        max_steps = 1,\n        fp16 = device.type == \"cuda\" and not torch.cuda.is_bf16_supported(),\n        bf16 = device.type == \"cuda\" and torch.cuda.is_bf16_supported(),\n        dataset_num_proc = 1,\n        output_dir = str(tmp_path),\n        packing = True,\n    )\n\n    trainer = SFTTrainer(\n        model = model,\n        processing_class = tokenizer,\n        train_dataset = dataset,\n        args = training_args,\n    )\n\n    enable_sample_packing(model, trainer)\n\n    dataloader = trainer.get_train_dataloader()\n    batch = next(iter(dataloader))\n\n    model_device = next(model.parameters()).device\n\n    for key, value in list(batch.items()):\n        if torch.is_tensor(value):\n            batch[key] = value.to(model_device)\n\n    from unsloth.models import llama as llama_mod\n\n    return model, batch, trainer, llama_mod\n\n\ndef _trim_batch_to_total_tokens(data, total_tokens):\n    def _trim_tensor(t: 
torch.Tensor):\n        if t.ndim >= 2 and t.size(1) > total_tokens:\n            return t[:, :total_tokens].contiguous()\n        return t\n\n    trimmed = {}\n    for key, value in data.items():\n        if torch.is_tensor(value):\n            trimmed[key] = _trim_tensor(value)\n        else:\n            trimmed[key] = value\n    return trimmed\n\n\ndef test_mask_packed_sequence_boundaries_marks_single_row():\n    shift_labels = torch.arange(6, dtype = torch.long).view(1, 6)\n    changed = mask_packed_sequence_boundaries(\n        shift_labels,\n        torch.tensor([2, 1, 3], dtype = torch.int32),\n    )\n    assert changed is True\n    flat = shift_labels.view(-1)\n    assert flat[1].item() == -100\n    assert flat[2].item() == -100\n    assert flat[5].item() == -100\n    assert flat[0].item() != -100\n\n\ndef test_mask_packed_sequence_boundaries_across_multiple_rows():\n    shift_labels = torch.arange(10, dtype = torch.long).view(2, 5)\n    lengths = torch.tensor([3, 2, 4, 1], dtype = torch.int32)\n    changed = mask_packed_sequence_boundaries(shift_labels, lengths)\n    assert changed is True\n    flat = shift_labels.view(-1)\n    for idx in (2, 4, 8, 9):\n        assert flat[idx].item() == -100\n    assert torch.any(flat != -100)\n\n\ndef test_configure_sample_packing():\n    config = SimpleNamespace()\n    configure_sample_packing(config)\n\n    assert config.packing is True\n    assert config.padding_free is True\n    assert config.remove_unused_columns is False\n\n\ndef test_configure_padding_free():\n    config = SimpleNamespace(remove_unused_columns = True)\n    configure_padding_free(config)\n\n    assert config.padding_free is True\n    assert config.remove_unused_columns is False\n\n\nclass _DummyChild(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.max_seq_length = 8\n\n\nclass _DummyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.max_seq_length = 16\n        self.child = _DummyChild()\n        self.config = SimpleNamespace(_attn_implementation = \"sdpa\")\n        self.generation_config = SimpleNamespace(attn_implementation = \"sdpa\")\n\n\nclass _DummyTrainer:\n    def __init__(self):\n        self.args = SimpleNamespace(remove_unused_columns = True)\n        collator_args = {\n            \"pad_token_id\": 0,\n            \"completion_only_loss\": False,\n            \"return_tensors\": \"pt\",\n        }\n        optional_flags = [\n            {\"padding_free\": True, \"return_position_ids\": False},\n            {\"padding_free\": True},\n            {},\n        ]\n        for extra in optional_flags:\n            try:\n                self.data_collator = DataCollatorForLanguageModeling(\n                    **collator_args, **extra\n                )\n                break\n            except TypeError:\n                continue\n        # Ensure attributes exist even if the constructor did not accept them\n        if not hasattr(self.data_collator, \"padding_free\"):\n            self.data_collator.padding_free = True\n        if not hasattr(self.data_collator, \"return_position_ids\"):\n            self.data_collator.return_position_ids = False\n\n\nclass _PaddingFreeCollator:\n    def __init__(self):\n        self.padding_free = True\n        self.return_position_ids = False\n        self.calls = 0\n\n    def torch_call(self, examples):\n        self.calls += 1\n        return {\n            \"input_ids\": torch.tensor([[0]], dtype = torch.long),\n            \"examples_seen\": 
self.calls,\n        }\n\n\ndef test_enable_sample_packing():\n    model = _DummyModel()\n    trainer = _DummyTrainer()\n\n    enable_sample_packing(model, trainer)\n\n    # model hierarchy should now allow packed overlength inputs\n    assert getattr(model, \"_unsloth_allow_packed_overlength\") is True\n    assert getattr(model.child, \"_unsloth_allow_packed_overlength\") is True\n\n    collator = trainer.data_collator\n    assert collator.return_position_ids is True\n    assert getattr(collator, \"_unsloth_packing_wrapped\") is True\n\n    examples = [\n        {\n            \"input_ids\": [0, 1, 2],\n            \"labels\": [0, 1, 2],\n            \"seq_lengths\": [2, 1],\n        },\n        {\n            \"input_ids\": [3, 4, 5],\n            \"labels\": [3, 4, 5],\n            \"seq_lengths\": [3],\n        },\n    ]\n    batch = collator.torch_call(examples)\n\n    # packed lengths are aggregated into a single tensor\n    assert \"packed_seq_lengths\" in batch\n    assert torch.equal(\n        batch[\"packed_seq_lengths\"],\n        torch.tensor([2, 1, 3], dtype = torch.int32),\n    )\n\n    assert batch[\"input_ids\"].shape == (1, 6)\n    expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)\n    assert torch.equal(batch[\"position_ids\"].view(-1)[:6], expected_positions)\n\n\ndef test_enable_sample_packing_trl_collator(tmp_path):\n    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    model, _, trainer, _ = _build_packed_training_setup(tmp_path, device)\n\n    enable_sample_packing(model, trainer)\n\n    examples = [\n        {\n            \"input_ids\": [0, 1, 2],\n            \"labels\": [0, 1, 2],\n            \"seq_lengths\": [2, 1],\n        },\n        {\n            \"input_ids\": [3, 4, 5],\n            \"labels\": [3, 4, 5],\n            \"seq_lengths\": [3],\n        },\n    ]\n\n    batch = trainer.data_collator.torch_call(examples)\n\n    assert batch[\"input_ids\"].shape == (1, 6)\n    assert torch.equal(\n        batch[\"packed_seq_lengths\"],\n        torch.tensor([2, 1, 3], dtype = torch.int32),\n    )\n\n    expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)\n    assert torch.equal(batch[\"position_ids\"].view(-1)[:6], expected_positions)\n\n    if hasattr(trainer, \"accelerator\"):\n        trainer.accelerator.free_memory()\n\n\ndef test_enable_padding_free_metadata():\n    model = _DummyModel()\n    trainer = SimpleNamespace(\n        args = SimpleNamespace(remove_unused_columns = True),\n        data_collator = _PaddingFreeCollator(),\n    )\n\n    enable_padding_free_metadata(model, trainer)\n\n    assert getattr(model, \"_unsloth_allow_packed_overlength\") is True\n    assert getattr(model.child, \"_unsloth_allow_packed_overlength\") is True\n\n    collator = trainer.data_collator\n    assert collator.return_position_ids is True\n    assert getattr(collator, \"_unsloth_padding_free_lengths_wrapped\") is True\n\n    examples = [\n        {\"input_ids\": [0, 1, 2]},\n        {\"input_ids\": [3, 4]},\n    ]\n    batch = collator.torch_call(examples)\n    assert torch.equal(\n        batch[\"packed_seq_lengths\"],\n        torch.tensor([3, 2], dtype = torch.int32),\n    )\n    assert trainer.args.remove_unused_columns is False\n\n\ndef test_packing_sdpa(tmp_path):\n    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)\n\n    assert \"packed_seq_lengths\" 
in batch\n    assert \"attention_mask\" not in batch\n    assert batch[\"packed_seq_lengths\"].dtype == torch.int32\n\n    total_tokens = batch[\"input_ids\"].size(-1)\n    assert int(batch[\"packed_seq_lengths\"].sum().item()) == total_tokens\n\n    packed_tokens = int(batch[\"packed_seq_lengths\"].sum().item())\n    assert \"position_ids\" in batch\n    flat_positions = batch[\"position_ids\"].reshape(-1)[:packed_tokens]\n    expected_positions = torch.cat(\n        [\n            torch.arange(length, dtype = torch.long)\n            for length in batch[\"packed_seq_lengths\"].tolist()\n        ]\n    )\n    assert torch.equal(flat_positions.cpu(), expected_positions)\n    inputs = _trim_batch_to_total_tokens(batch, packed_tokens)\n\n    seq_info = llama_mod.get_packed_info_from_kwargs(\n        {\"packed_seq_lengths\": batch[\"packed_seq_lengths\"]},\n        inputs[\"input_ids\"].device,\n    )\n    assert seq_info is not None\n\n    original_mask = attention_dispatch_utils.build_sdpa_packed_attention_mask\n    mask_calls = []\n    captured_loss_labels = {}\n\n    def _capture_mask(seq_info, dtype, device, *, sliding_window = None):\n        mask_calls.append(tuple(seq_info[0].tolist()))\n        return original_mask(\n            seq_info,\n            dtype = dtype,\n            device = device,\n            sliding_window = sliding_window,\n        )\n\n    def _capture_loss(*, logits, labels, **loss_kwargs):\n        captured_loss_labels[\"labels\"] = labels.detach().to(\"cpu\")\n        return torch.zeros((), device = logits.device, dtype = logits.dtype)\n\n    with ExitStack() as stack:\n        stack.enter_context(\n            patch.object(attention_dispatch_utils, \"HAS_FLASH_ATTENTION\", False)\n        )\n        stack.enter_context(\n            patch.object(attention_dispatch_utils, \"HAS_XFORMERS\", False)\n        )\n        stack.enter_context(\n            patch.object(\n                attention_dispatch_utils,\n                \"build_sdpa_packed_attention_mask\",\n                side_effect = _capture_mask,\n            )\n        )\n        stack.enter_context(\n            patch.object(\n                llama_mod,\n                \"fast_cross_entropy_loss\",\n                side_effect = _capture_loss,\n            )\n        )\n        with torch.no_grad():\n            outputs = model(**inputs)\n\n    assert mask_calls, \"SDPA packed mask was not constructed\"\n    assert outputs.loss is not None\n    assert \"labels\" in captured_loss_labels\n    flat_loss_labels = captured_loss_labels[\"labels\"].reshape(-1)\n    boundaries = (\n        torch.cumsum(\n            batch[\"packed_seq_lengths\"].to(device = \"cpu\", dtype = torch.long), dim = 0\n        )\n        - 1\n    )\n    for idx in boundaries.tolist():\n        assert flat_loss_labels[idx].item() == -100\n    assert torch.any(flat_loss_labels != -100)\n\n    if hasattr(trainer, \"accelerator\"):\n        trainer.accelerator.free_memory()\n"
  },
  {
    "path": "tests/utils/test_qat.py",
    "content": "from unsloth import FastLanguageModel\n\nfrom typing import Dict\n\nimport pytest\nimport torch\n\ntry:\n    from torchao.quantization.qat import FakeQuantizedLinear\n    from torchao.quantization.qat.fake_quantizer import (\n        FakeQuantizerBase,\n        Float8FakeQuantizer,\n        Int4WeightFakeQuantizer,\n        IntxFakeQuantizer,\n    )\nexcept ImportError:\n    print(\n        \"Missing torchao import, please install or upgrade torchao with: pip install 'torchao>=0.15.0'\"\n    )\n\n\nclass _CountingFakeQuantizer(torch.nn.Module):\n    \"\"\"\n    Dummy fake quantizer that counts the number of times it has been called.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.count = 0\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        self.count += 1\n        return x\n\n\ndef _get_model(qat_scheme: str, full_finetuning: bool):\n    \"\"\"\n    Return a 2-tuple of (model, tokenizer), where the model has been configured\n    to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.\n    \"\"\"\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"unsloth/Qwen3-1.7B\",\n        load_in_4bit = False,\n        full_finetuning = full_finetuning,\n        qat_scheme = qat_scheme if full_finetuning else None,\n    )\n    if not full_finetuning:\n        model = FastLanguageModel.get_peft_model(\n            model,\n            qat_scheme = qat_scheme,\n        )\n    return model, tokenizer\n\n\ndef _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):\n    \"\"\"\n    Verify that the given linear contains fake quantizers according to the `qat_scheme`.\n    \"\"\"\n    weight_only = False\n    if qat_scheme == \"fp8-int4\":\n        act_fq_class = Float8FakeQuantizer\n        weight_fq_class = Int4WeightFakeQuantizer\n        min_in_features = 128\n    elif qat_scheme == \"fp8-fp8\":\n        act_fq_class = Float8FakeQuantizer\n        weight_fq_class = Float8FakeQuantizer\n        min_in_features = -1\n    elif qat_scheme == \"int8\":\n        act_fq_class = None\n        weight_fq_class = IntxFakeQuantizer\n        min_in_features = 128\n        weight_only = True\n    else:\n        raise ValueError(f\"Unknown qat_scheme: {qat_scheme}\")\n\n    # Check base layer activations and weights\n    base_layer = getattr(linear, \"base_layer\", linear)\n    if base_layer.in_features >= min_in_features:\n        assert isinstance(base_layer, FakeQuantizedLinear)\n        if not weight_only:\n            assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)\n        assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)\n\n    # Check lora A and B (only for full_finetuning=False)\n    if hasattr(linear, \"lora_A\") and hasattr(linear, \"lora_B\"):\n        lora_A = linear.lora_A.default\n        lora_B = linear.lora_B.default\n        if lora_A.in_features >= min_in_features:\n            assert isinstance(lora_A, FakeQuantizedLinear)\n            if not weight_only:\n                assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)\n            assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)\n        if lora_B.in_features >= min_in_features:\n            assert isinstance(lora_B, FakeQuantizedLinear)\n            if not weight_only:\n                assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)\n            assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)\n\n\ndef 
_test_fake_quantizers_are_called(\n    model: torch.nn.Module,\n    example_inputs: Dict,\n    full_finetuning: bool,\n    qat_scheme: str,\n):\n    \"\"\"\n    Verify that the fake quantizers are actually called when the model is called.\n    \"\"\"\n    weight_only = qat_scheme == \"int8\"\n\n    def _swap_fake_quantizers(model: torch.nn.Module):\n        for name, child in model.named_children():\n            if isinstance(child, FakeQuantizerBase):\n                setattr(model, name, _CountingFakeQuantizer())\n\n    def _assert_fake_quantizers_are_called(model: torch.nn.Module):\n        for name, child in model.named_children():\n            if full_finetuning:\n                if isinstance(child, FakeQuantizedLinear):\n                    if not weight_only:\n                        assert child.activation_fake_quantizer.count == 1\n                    assert child.weight_fake_quantizer.count == 1\n            else:\n                # For LoRA, we only fake quantize the input activations once per block:\n                # For self_attn, we only fake quantize the q_proj's input activations\n                # For mlp, we only fake quantize the gate_proj's input activations\n                if name == \"self_attn\":\n                    base_layer = child.q_proj.base_layer\n                    if not weight_only:\n                        assert hasattr(base_layer, \"activation_fake_quantizer\")\n                        assert base_layer.activation_fake_quantizer.count == 1\n                elif name == \"mlp\":\n                    base_layer = child.gate_proj.base_layer\n                    if not weight_only:\n                        assert hasattr(base_layer, \"activation_fake_quantizer\")\n                        assert base_layer.activation_fake_quantizer.count == 1\n                elif isinstance(child, FakeQuantizedLinear):\n                    # Weight fake quantizers should always be called\n                    assert child.weight_fake_quantizer.count == 1\n\n    for k, v in example_inputs.items():\n        example_inputs[k] = v.cuda()\n    model.apply(_swap_fake_quantizers)\n    model(**example_inputs)\n    model.apply(_assert_fake_quantizers_are_called)\n\n\ndef _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):\n    \"\"\"\n    Test that all linear layers in the model are fake quantized according to the `qat_scheme`.\n    \"\"\"\n    model, tokenizer = _get_model(qat_scheme, full_finetuning)\n    if full_finetuning:\n        model = model.model\n    else:\n        model = model.base_model.model.model\n    for layer in model.layers:\n        _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)\n        _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)\n        _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)\n        _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)\n        _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)\n        _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)\n    inputs = tokenizer(\"How are you?\", return_tensors = \"pt\")\n    _test_fake_quantizers_are_called(model, inputs, full_finetuning, qat_scheme)\n\n\n# TODO: there are bad interactions across tests right now, need to figure out\n# how to disable model caching before re-enabling this test\n@pytest.mark.parametrize(\"qat_scheme\", [\"fp8-int4\", \"fp8-fp8\", \"int8\"])\ndef _test_full_model_fake_quantize(qat_scheme: str):\n    _test_model_fake_quantize(qat_scheme, full_finetuning = 
True)\n\n\n@pytest.mark.parametrize(\"qat_scheme\", [\"fp8-int4\", \"fp8-fp8\", \"int8\"])\ndef test_lora_model_fake_quantize(qat_scheme: str):\n    _test_model_fake_quantize(qat_scheme, full_finetuning = False)\n"
  },
  {
    "path": "tests/utils/test_trunc_normal_patch.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"Tests for trunc_normal low-precision patch compatibility.\"\"\"\n\nimport importlib.util\nimport inspect\nfrom pathlib import Path\n\nimport pytest\nimport torch\n\n\n_MISSING = object()\n\n\ndef _load_import_fixes_module():\n    repo_root = Path(__file__).resolve().parents[2]\n    import_fixes_path = repo_root / \"unsloth\" / \"import_fixes.py\"\n    spec = importlib.util.spec_from_file_location(\n        \"unsloth_import_fixes_local\", import_fixes_path\n    )\n    assert spec is not None and spec.loader is not None\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n    return module\n\n\ndef _getattr_or_missing(obj, name):\n    return getattr(obj, name) if hasattr(obj, name) else _MISSING\n\n\ndef _restore_attr(obj, name, value):\n    if value is _MISSING:\n        if hasattr(obj, name):\n            delattr(obj, name)\n        return\n    setattr(obj, name, value)\n\n\ndef test_trunc_normal_patch_accepts_positional_generator():\n    import_fixes = _load_import_fixes_module()\n    patch_fn = import_fixes.patch_trunc_normal_precision_issue\n\n    init_mod = torch.nn.init\n    old_fn = init_mod.trunc_normal_\n    old_patched = _getattr_or_missing(init_mod, \"_unsloth_trunc_normal_patched\")\n    old_original = _getattr_or_missing(init_mod, \"_unsloth_trunc_normal_original\")\n    try:\n        # Normalize to an unpatched baseline before applying the patch.\n        if old_original is not _MISSING:\n            init_mod.trunc_normal_ = old_original\n        if hasattr(init_mod, \"_unsloth_trunc_normal_patched\"):\n            delattr(init_mod, \"_unsloth_trunc_normal_patched\")\n        if hasattr(init_mod, \"_unsloth_trunc_normal_original\"):\n            delattr(init_mod, \"_unsloth_trunc_normal_original\")\n\n        patch_fn()\n        sig = inspect.signature(init_mod.trunc_normal_)\n        assert \"generator\" in sig.parameters\n        assert sig.parameters[\"generator\"].kind is not inspect.Parameter.KEYWORD_ONLY\n\n        tensor = torch.empty(1024, dtype = torch.float32)\n        gen = torch.Generator()\n        gen.manual_seed(3407)\n\n        init_mod.trunc_normal_(tensor, 0.0, 1.0, -2.0, 2.0, gen)\n        init_mod.trunc_normal_(tensor, mean = 0.0, std = 1.0, a = -2.0, b = 2.0, generator = gen)\n    finally:\n        init_mod.trunc_normal_ = old_fn\n        _restore_attr(init_mod, \"_unsloth_trunc_normal_patched\", old_patched)\n        _restore_attr(init_mod, \"_unsloth_trunc_normal_original\", old_original)\n\n\ndef test_trunc_normal_patch_rejects_invalid_generator():\n    import_fixes = _load_import_fixes_module()\n    patch_fn = import_fixes.patch_trunc_normal_precision_issue\n\n    init_mod = torch.nn.init\n    old_fn = init_mod.trunc_normal_\n    old_patched 
= _getattr_or_missing(init_mod, \"_unsloth_trunc_normal_patched\")\n    old_original = _getattr_or_missing(init_mod, \"_unsloth_trunc_normal_original\")\n    try:\n        if old_original is not _MISSING:\n            init_mod.trunc_normal_ = old_original\n        if hasattr(init_mod, \"_unsloth_trunc_normal_patched\"):\n            delattr(init_mod, \"_unsloth_trunc_normal_patched\")\n        if hasattr(init_mod, \"_unsloth_trunc_normal_original\"):\n            delattr(init_mod, \"_unsloth_trunc_normal_original\")\n\n        patch_fn()\n        sig = inspect.signature(init_mod.trunc_normal_)\n        if \"generator\" not in sig.parameters:\n            pytest.skip(\"torch.nn.init.trunc_normal_ lacks a generator parameter\")\n\n        tensor = torch.empty(16, dtype = torch.float32)\n        with pytest.raises(TypeError):\n            init_mod.trunc_normal_(tensor, generator = 123)\n    finally:\n        init_mod.trunc_normal_ = old_fn\n        _restore_attr(init_mod, \"_unsloth_trunc_normal_patched\", old_patched)\n        _restore_attr(init_mod, \"_unsloth_trunc_normal_original\", old_original)\n"
  },
  {
    "path": "unsloth/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings, importlib, sys\nfrom packaging.version import Version\nimport os, re, subprocess, inspect, functools\nimport numpy as np\n\n# Log Unsloth is being used\nos.environ[\"UNSLOTH_IS_PRESENT\"] = \"1\"\n\n# Check if modules that need patching are already imported\ncritical_modules = [\"trl\", \"transformers\", \"peft\"]\nalready_imported = [mod for mod in critical_modules if mod in sys.modules]\n\n# Fix some issues before importing other packages\nfrom .import_fixes import (\n    fix_message_factory_issue,\n    check_fbgemm_gpu_version,\n    disable_broken_causal_conv1d,\n    disable_broken_vllm,\n    configure_amdgpu_asic_id_table_path,\n    torchvision_compatibility_check,\n    fix_diffusers_warnings,\n    fix_huggingface_hub,\n)\n\n# Configure libdrm ids table path early so ROCm can resolve AMD GPU names.\nconfigure_amdgpu_asic_id_table_path()\ndisable_broken_causal_conv1d()\ndisable_broken_vllm()\nfix_message_factory_issue()\ncheck_fbgemm_gpu_version()\ntorchvision_compatibility_check()\nfix_diffusers_warnings()\nfix_huggingface_hub()\ndel configure_amdgpu_asic_id_table_path\ndel disable_broken_causal_conv1d\ndel disable_broken_vllm\ndel fix_message_factory_issue\ndel check_fbgemm_gpu_version\ndel torchvision_compatibility_check\ndel fix_diffusers_warnings\ndel fix_huggingface_hub\n\n# This check is critical because Unsloth optimizes these libraries by modifying\n# their code at import time. If they're imported first, the original (slower,\n# more memory-intensive) implementations will be used instead of Unsloth's\n# optimized versions, potentially causing OOM errors or slower training.\nif already_imported:\n    # stacklevel=2 makes warning point to user's import line rather than this library code,\n    # showing them exactly where to fix the import order in their script\n    warnings.warn(\n        f\"WARNING: Unsloth should be imported before [{', '.join(already_imported)}] \"\n        f\"to ensure all optimizations are applied. Your code may run slower or encounter \"\n        f\"memory issues without these optimizations.\\n\\n\"\n        f\"Please restructure your imports with 'import unsloth' at the top of your file.\",\n        stacklevel = 2,\n    )\ndel already_imported, critical_modules\n\n# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so\n# enabling it will require much more work, so we have to prioritize. 
Please understand!\n# We do have a beta version, which you can contact us about!\n# Thank you for your understanding and we appreciate it immensely!\n\n# Fixes https://github.com/unslothai/unsloth/issues/1266\nos.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"] = \"python\"\n\n# [TODO] Check why some GPUs don't work\n#    \"pinned_use_cuda_host_register:True,\"\\\n#    \"pinned_num_register_threads:8\"\n\n\nfrom importlib.metadata import version as importlib_version\nfrom importlib.metadata import PackageNotFoundError\n\n# Check for unsloth_zoo\ntry:\n    unsloth_zoo_version = importlib_version(\"unsloth_zoo\")\n    if Version(unsloth_zoo_version) < Version(\"2026.3.4\"):\n        print(\n            \"Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\\n\"\n            \"Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`\"\n        )\n        # if os.environ.get(\"UNSLOTH_DISABLE_AUTO_UPDATES\", \"0\") == \"0\":\n        #     try:\n        #         os.system(\"pip install --upgrade --no-cache-dir --no-deps unsloth_zoo\")\n        #     except:\n        #         try:\n        #             os.system(\"pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo\")\n        #         except:\n        #             raise ImportError(\"Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`\")\n    import unsloth_zoo\nexcept PackageNotFoundError:\n    raise ImportError(\n        f\"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` then retry!\"\n    )\nexcept:\n    raise\ndel PackageNotFoundError, importlib_version\n\n# Try importing PyTorch and check version\ntry:\n    import torch\nexcept ModuleNotFoundError:\n    raise ImportError(\n        \"Unsloth: Pytorch is not installed. 
Go to https://pytorch.org/.\\n\"\n        \"We have some installation instructions on our Github page.\"\n    )\nexcept:\n    raise\n\nfrom unsloth_zoo.device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n)\n\n# Fix other issues\nfrom .import_fixes import (\n    fix_xformers_performance_issue,\n    fix_vllm_aimv2_issue,\n    check_vllm_torch_sm100_compatibility,\n    fix_vllm_guided_decoding_params,\n    fix_vllm_pdl_blackwell,\n    fix_triton_compiled_kernel_missing_attrs,\n    patch_trunc_normal_precision_issue,\n    ignore_logger_messages,\n    patch_ipykernel_hf_xet,\n    patch_trackio,\n    patch_datasets,\n    patch_enable_input_require_grads,\n    fix_openenv_no_vllm,\n    patch_openspiel_env_async,\n    fix_executorch,\n    patch_vllm_for_notebooks,\n    patch_torchcodec_audio_decoder,\n    disable_torchcodec_if_broken,\n    disable_broken_wandb,\n)\n\nfix_xformers_performance_issue()\nfix_vllm_aimv2_issue()\n# Check vLLM + torch < 2.9.0 + SM100 compatibility BEFORE importing vLLM\ncheck_vllm_torch_sm100_compatibility()\nfix_vllm_guided_decoding_params()\nfix_vllm_pdl_blackwell()\nfix_triton_compiled_kernel_missing_attrs()\npatch_trunc_normal_precision_issue()\nignore_logger_messages()\npatch_ipykernel_hf_xet()\npatch_trackio()\npatch_datasets()\npatch_enable_input_require_grads()\nfix_openenv_no_vllm()\npatch_openspiel_env_async()\nfix_executorch()\npatch_vllm_for_notebooks()\npatch_torchcodec_audio_decoder()\ndisable_torchcodec_if_broken()\ndisable_broken_wandb()\n\ndel fix_xformers_performance_issue\ndel fix_vllm_aimv2_issue\ndel check_vllm_torch_sm100_compatibility\ndel fix_vllm_guided_decoding_params\ndel fix_vllm_pdl_blackwell\ndel fix_triton_compiled_kernel_missing_attrs\ndel patch_trunc_normal_precision_issue\ndel ignore_logger_messages\ndel patch_ipykernel_hf_xet\ndel patch_trackio\ndel patch_datasets\ndel patch_enable_input_require_grads\ndel fix_openenv_no_vllm\ndel patch_openspiel_env_async\ndel fix_executorch\ndel patch_vllm_for_notebooks\ndel patch_torchcodec_audio_decoder\ndel disable_torchcodec_if_broken\ndel disable_broken_wandb\n\n# Torch 2.4 has including_emulation\nif DEVICE_TYPE == \"cuda\":\n    major_version, minor_version = torch.cuda.get_device_capability()\n    SUPPORTS_BFLOAT16 = major_version >= 8\n\n    old_is_bf16_supported = torch.cuda.is_bf16_supported\n    if \"including_emulation\" in str(inspect.signature(old_is_bf16_supported)):\n\n        def is_bf16_supported(including_emulation = False):\n            return old_is_bf16_supported(including_emulation)\n\n        torch.cuda.is_bf16_supported = is_bf16_supported\n    else:\n\n        def is_bf16_supported():\n            return SUPPORTS_BFLOAT16\n\n        torch.cuda.is_bf16_supported = is_bf16_supported\n    del major_version, minor_version\nelif DEVICE_TYPE == \"hip\":\n    SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()\nelif DEVICE_TYPE == \"xpu\":\n    # torch.xpu.is_bf16_supported() does not have including_emulation\n    # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()\n    SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()\n\n# For Gradio HF Spaces?\n# if \"SPACE_AUTHOR_NAME\" not in os.environ and \"SPACE_REPO_NAME\" not in os.environ:\nimport triton\n\nif DEVICE_TYPE == \"cuda\":\n    libcuda_dirs = lambda: None\n    if Version(triton.__version__) >= Version(\"3.0.0\"):\n        try:\n            from triton.backends.nvidia.driver import libcuda_dirs\n        except:\n            
pass\n    else:\n        from triton.common.build import libcuda_dirs\n\n    # Try loading bitsandbytes and triton\n    try:\n        import bitsandbytes as bnb\n    except:\n        print(\n            \"Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!\"\n        )\n        bnb = None\n    try:\n        cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32\n        libcuda_dirs()\n    except:\n        warnings.warn(\"Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.\")\n\n        if os.path.exists(\"/usr/lib64-nvidia\"):\n            os.system(\"ldconfig /usr/lib64-nvidia\")\n        elif os.path.exists(\"/usr/local\"):\n            # Sometimes bitsandbytes cannot be linked properly in Runpod for example\n            possible_cudas = (\n                subprocess.check_output([\"ls\", \"-al\", \"/usr/local\"])\n                .decode(\"utf-8\")\n                .split(\"\\n\")\n            )\n            find_cuda = re.compile(r\"[\\s](cuda\\-[\\d\\.]{2,})$\")\n            possible_cudas = [find_cuda.search(x) for x in possible_cudas]\n            possible_cudas = [x.group(1) for x in possible_cudas if x is not None]\n\n            # Try linking cuda folder, or everything in local\n            if len(possible_cudas) == 0:\n                os.system(\"ldconfig /usr/local/\")\n            else:\n                find_number = re.compile(r\"([\\d\\.]{2,})\")\n                latest_cuda = np.argsort(\n                    [float(find_number.search(x).group(1)) for x in possible_cudas]\n                )[::-1][0]\n                latest_cuda = possible_cudas[latest_cuda]\n                os.system(f\"ldconfig /usr/local/{latest_cuda}\")\n                del find_number, latest_cuda\n            del possible_cudas, find_cuda\n\n        if bnb is not None:\n            importlib.reload(bnb)\n        importlib.reload(triton)\n        try:\n            libcuda_dirs = lambda: None\n            if Version(triton.__version__) >= Version(\"3.0.0\"):\n                try:\n                    from triton.backends.nvidia.driver import libcuda_dirs\n                except:\n                    pass\n            else:\n                from triton.common.build import libcuda_dirs\n            cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32\n            libcuda_dirs()\n        except:\n            warnings.warn(\n                \"Unsloth: CUDA is not linked properly.\\n\"\n                \"Try running `python -m bitsandbytes` then `python -m xformers.info`\\n\"\n                \"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\\n\"\n                \"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\\n\"\n                \"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\\n\"\n                \"Unsloth will still run for now, but maybe it might crash - let's hope it works!\"\n            )\n    del libcuda_dirs\nelif DEVICE_TYPE == \"hip\":\n    # NO-OP for rocm device\n    pass\nelif DEVICE_TYPE == \"xpu\":\n    import bitsandbytes as bnb\n\n    # TODO: check triton for intel installed properly.\n    pass\n\nfrom .models import *\nfrom .models import __version__\nfrom .save import *\nfrom .chat_templates import *\nfrom .tokenizer_utils import *\nfrom .trainer import *\n\n# Export dataprep utilities for CLI and downstream users\nfrom .dataprep.raw_text import RawTextDataLoader, 
TextPreprocessor\nfrom unsloth_zoo.rl_environments import (\n    check_python_modules,\n    create_locked_down_function,\n    execute_with_time_limit,\n    Benchmarker,\n    is_port_open,\n    launch_openenv,\n)\n\n# Patch TRL trainers for backwards compatibility\n_patch_trl_trainer()\n"
  },
  {
    "path": "unsloth/_auto_install.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry: import torch\nexcept: raise ImportError('Install torch via `pip install torch`')\nfrom packaging.version import Version as V\nimport re\nv = V(re.match(r\"[0-9\\.]{3,}\", torch.__version__).group(0))\ncuda = str(torch.version.cuda)\nis_ampere = torch.cuda.get_device_capability()[0] >= 8\nUSE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI\nif cuda not in (\"11.8\", \"12.1\", \"12.4\", \"12.6\", \"12.8\", \"13.0\"): raise RuntimeError(f\"CUDA = {cuda} not supported!\")\nif   v <= V('2.1.0'): raise RuntimeError(f\"Torch = {v} too old!\")\nelif v <= V('2.1.1'): x = 'cu{}{}-torch211'\nelif v <= V('2.1.2'): x = 'cu{}{}-torch212'\nelif v  < V('2.3.0'): x = 'cu{}{}-torch220'\nelif v  < V('2.4.0'): x = 'cu{}{}-torch230'\nelif v  < V('2.5.0'): x = 'cu{}{}-torch240'\nelif v  < V('2.5.1'): x = 'cu{}{}-torch250'\nelif v <= V('2.5.1'): x = 'cu{}{}-torch251'\nelif v  < V('2.7.0'): x = 'cu{}{}-torch260'\nelif v  < V('2.7.9'): x = 'cu{}{}-torch270'\nelif v  < V('2.8.0'): x = 'cu{}{}-torch271'\nelif v  < V('2.8.9'): x = 'cu{}{}-torch280'\nelif v  < V('2.9.1'): x = 'cu{}{}-torch290'\nelif v  < V('2.9.2'): x = 'cu{}{}-torch291'\nelif v  < V('2.10.1'): x = 'cu{}{}-torch2100'\nelse: raise RuntimeError(f\"Torch = {v} too new!\")\nif v > V('2.6.9') and cuda not in (\"11.8\", \"12.6\", \"12.8\", \"13.0\"): raise RuntimeError(f\"CUDA = {cuda} not supported!\")\nif v >= V('2.10.0') and cuda not in (\"12.6\", \"12.8\", \"13.0\"): raise RuntimeError(f\"Torch 2.10 requires CUDA 12.6, 12.8, or 13.0! Got CUDA = {cuda}\")\nx = x.format(cuda.replace(\".\", \"\"), \"-ampere\" if False else \"\") # is_ampere is broken due to flash-attn\nprint(f'pip install --upgrade pip && pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git && pip install \"unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git\" --no-build-isolation')"
  },
  {
    "path": "unsloth/chat_templates.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"get_chat_template\",\n    \"test_chat_templates\",\n    \"test_hf_gguf_equivalence\",\n    \"remove_special_tokens\",\n\n    \"to_sharegpt\",\n    \"standardize_sharegpt\",\n    \"standardize_data_formats\",\n    \"apply_chat_template\",\n    \"train_on_responses_only\",\n\n    \"test_construct_chat_template\",\n]\n\nfrom transformers import StoppingCriteria, StoppingCriteriaList\nfrom torch import LongTensor, FloatTensor\nfrom transformers.models.llama.modeling_llama import logger\nfrom .save import patch_saving_functions\nimport os\nimport shutil\nfrom .tokenizer_utils import *\nfrom .models._utils import patch_tokenizer\nimport re\nfrom .ollama_template_mappers import OLLAMA_TEMPLATES\nfrom unsloth_zoo.dataset_utils import (\n    train_on_responses_only,\n    standardize_data_formats,\n)\nstandardize_sharegpt = standardize_data_formats\nCHAT_TEMPLATES = {}\nDEFAULT_SYSTEM_MESSAGE = {}\ndef _ollama_template(name: str):\n    return OLLAMA_TEMPLATES[name]\n\n# =========================================== Unsloth\n# Unsloth efficient template leverages from Zephyr\nunsloth_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{{ messages[0]['content'] + '\\n' }}\"\\\n        \"{% set loop_messages = messages[1:] %}\"\\\n    \"{% else %}\"\\\n        \"{{ '{system_message}' + '\\n' }}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '>>> User: ' + message['content'] + '\\n' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ '>>> Assistant: ' + message['content'] + eos_token + '\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '>>> Assistant: ' }}\"\\\n    \"{% endif %}\"\n\nunsloth_ollama = _ollama_template(\"unsloth\")\n\nunsloth_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"unsloth\"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"unsloth\"] = \"You are a helpful assistant to the user\"\n\n# =========================================== Zephyr\n# Zephyr has no BOS!\nzephyr_template = \\\n    \"{% for message in messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '<|user|>\\n' + message['content'] + eos_token + '\\n' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ '<|assistant|>\\n' + message['content'] + eos_token + '\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ '<|system|>\\n' + message['content'] + eos_token + '\\n' }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt 
%}\"\\\n        \"{{ '<|assistant|>\\n' }}\"\\\n    \"{% endif %}\"\n\nzephyr_ollama = _ollama_template(\"zephyr\")\n\nzephyr_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"zephyr\"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"zephyr\"] = None # No system message in Zephyr\n\n# =========================================== ChatML\n# ChatML has no BOS and not EOS! Rather <|im_start|> and <|im_end|> acts as BOS / EOS.\nchatml_template = \\\n    \"{% for message in messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ '<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n' }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '<|im_start|>assistant\\n' }}\"\\\n    \"{% endif %}\"\n\nchatml_ollama = _ollama_template(\"chatml\")\n\nchatml_eos_token = \"<|im_end|>\"\nCHAT_TEMPLATES[\"chatml\"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"chatml\"] = None # No system message in ChatML\n\n# =========================================== Mistral-1\n# Mistral Instruct doesn't allow system prompts, so we append it to the user message.\nmistral_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{% if messages[1]['role'] == 'user' %}\"\\\n            \"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}\"\\\n            \"{% set loop_messages = messages[2:] %}\"\\\n        \"{% else %}\"\\\n            \"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}\"\\\n            \"{% set loop_messages = messages[1:] %}\"\\\n        \"{% endif %}\"\\\n    \"{% else %}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '[INST] ' + message['content'] + ' [/INST]' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ message['content'] + eos_token }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\n\n# Ollama from https://www.ollama.com/library/mistral\nmistral_ollama = _ollama_template(\"mistral\")\n\nmistral_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"mistral\"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"mistral\"] = None # No system message in Mistral\n\n# =========================================== Llama-2\n# Adds BOS to every convo! 
And weird <<SYS>> system messages.\nllama_template = \\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{% if messages[1]['role'] == 'user' %}\"\\\n            \"{{ bos_token + '[INST] <<SYS>>\\n' + messages[0]['content'] + '\\n<</SYS>>\\n\\n' + messages[1]['content'] + ' [/INST]' }}\"\\\n            \"{% set loop_messages = messages[2:] %}\"\\\n        \"{% else %}\"\\\n            \"{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}\"\\\n            \"{% set loop_messages = messages[1:] %}\"\\\n        \"{% endif %}\"\\\n    \"{% else %}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ ' ' + message['content'].strip() + ' ' + eos_token }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\n\n# Ollama from https://www.ollama.com/library/llama3\nllama_ollama = _ollama_template(\"llama\")\n\nllama_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"llama\"] = (llama_template, llama_eos_token, False, llama_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"llama\"] = None # No system message in Llama\n\n# ===========================================  Vicuna\n# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template\nvicuna_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{{ messages[0]['content'] + ' ' }}\"\\\n        \"{% set loop_messages = messages[1:] %}\"\\\n    \"{% else %}\"\\\n        \"{{ '{system_message}' + ' ' }}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ 'USER: ' + message['content'] + ' ' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ 'ASSISTANT: ' + message['content'] + eos_token }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ 'ASSISTANT:' }}\"\\\n    \"{% endif %}\"\n\n# Ollama from https://www.ollama.com/library/vicuna\nvicuna_ollama = _ollama_template(\"vicuna\")\n\nvicuna_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"vicuna\"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"vicuna\"] = \"A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n\n# =========================================== Vicuna Old\n# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template\nvicuna_old_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{{ messages[0]['content'] + '\\n' }}\"\\\n        \"{% set loop_messages = messages[1:] %}\"\\\n    \"{% else %}\"\\\n        \"{{ '{system_message}' + '\\n' }}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '### Human: ' + message['content'] + '\\n' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ '### Assistant: ' + message['content'] + eos_token + '\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '### Assistant:' }}\"\\\n    \"{% endif %}\"\n\nvicuna_old_ollama = _ollama_template(\"vicuna_old\")\n\nvicuna_old_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"vicuna_old\"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"vicuna_old\"] = \"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\\\'s questions.\"\n\nCHAT_TEMPLATES[\"vicuna old\"] = CHAT_TEMPLATES[\"vicuna_old\"]\nDEFAULT_SYSTEM_MESSAGE[\"vicuna old\"] = DEFAULT_SYSTEM_MESSAGE[\"vicuna_old\"]\n\n# =========================================== Alpaca multi turn\n# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos\nalpaca_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{{ messages[0]['content'] + '\\n\\n' }}\"\\\n        \"{% set loop_messages = messages[1:] %}\"\\\n    \"{% else %}\"\\\n        \"{{ '{system_message}' + '\\n\\n' }}\"\\\n        \"{% set loop_messages = messages %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in loop_messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '### Instruction:\\n' + message['content'] + '\\n\\n' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ '### Response:\\n' + message['content'] + eos_token + '\\n\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '### Response:\\n' }}\"\\\n    \"{% endif %}\"\n\nalpaca_ollama = _ollama_template(\"alpaca\")\n\nalpaca_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"alpaca\"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"alpaca\"] = \"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\"\n\n# =========================================== Gemma\n# https://huggingface.co/google/gemma-7b-it\n# Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.\n# <end_of_turn> maps to 107. 
user and model are normal 1 word tokens.\ngemma_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% if messages[0]['role'] == 'system' %}\"\\\n        \"{{'<start_of_turn>user\\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\\n'}}\"\\\n        \"{% set messages = messages[2:] %}\"\\\n    \"{% endif %}\"\\\n    \"{% for message in messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{'<start_of_turn>user\\n' + message['content'] | trim + '<end_of_turn>\\n'}}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{'<start_of_turn>model\\n' + message['content'] | trim + '<end_of_turn>\\n' }}\"\\\n        \"{% else %}\"\\\n            \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '<start_of_turn>model\\n' }}\"\\\n    \"{% endif %}\"\n\n# Ollama from https://www.ollama.com/library/gemma\ngemma_ollama = _ollama_template(\"gemma\")\n\ngemma_eos_token = \"<end_of_turn>\"\nCHAT_TEMPLATES[\"gemma\"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma\"] = None # No system message in Gemma\n\n# =========================================== Gemma with ChatML instead\n# We find using <eos> is still more appropriate!\ngemma_chatml_template = \"{{ bos_token }}\" + chatml_template\n\ngemma_chatml_ollama = _ollama_template(\"gemma_chatml\")\n\ngemma_chatml_eos_token = (\n    {\"<start_of_turn>\" : \"<|im_start|>\", \"<eos>\" : \"<|im_end|>\"},\n    \"<|im_end|>\",\n)\nCHAT_TEMPLATES[\"gemma_chatml\"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma_chatml\"] = None # No system message in Gemma\n\n# =========================================== Gemma 2\n# Same as Gemma 1, but with sliding window attention!\n# https://ollama.com/library/gemma2/blobs/6522ca797f47\ngemma2_template = gemma_template\ngemma2_ollama = _ollama_template(\"gemma2\")\ngemma2_eos_token = \"<end_of_turn>\"\nCHAT_TEMPLATES[\"gemma2\"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma2\"] = None # No system message in Gemma 2\n\n# =========================================== Gemma 2 with ChatML instead\ngemma2_chatml_template = gemma_chatml_template\ngemma2_chatml_ollama = _ollama_template(\"gemma2_chatml\")\ngemma2_chatml_eos_token = gemma_chatml_eos_token\nCHAT_TEMPLATES[\"gemma2_chatml\"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma2_chatml\"] = None # No system message in Gemma 2\n\n# =========================================== Llama-3\n# Weirdly \\n\\n is needed?\nllama3_template = \\\n    \"{{ bos_token }}\"\\\n    \"{% for message in messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{ '<|start_header_id|>user<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}\"\\\n        \"{% else %}\"\\\n            \"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' 
}}\"\\\n    \"{% endif %}\"\n\n# Ollama from https://www.ollama.com/library/llama3\nllama3_ollama = _ollama_template(\"llama-3\")\n\nllama3_template_eos_token = \"eos_token\"\n\nCHAT_TEMPLATES[\"llama-3\"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"llama-3\"] = None # No system message in Llama-3\n\nCHAT_TEMPLATES[\"llama3\"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"llama3\"] = None # No system message in Llama-3\n\n\n# =========================================== Phi-3\n# \"{{ bos_token }}\"\\ # Phi-3.5 removes BOS?\nphi3_template = \\\n    \"{% for message in messages %}\"\\\n        \"{% if message['role'] == 'user' %}\"\\\n            \"{{'<|user|>\\n' + message['content'] + '<|end|>\\n'}}\"\\\n        \"{% elif message['role'] == 'assistant' %}\"\\\n            \"{{'<|assistant|>\\n' + message['content'] + '<|end|>\\n'}}\"\\\n        \"{% else %}\"\\\n            \"{{'<|' + message['role'] + '|>\\n' + message['content'] + '<|end|>\\n'}}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '<|assistant|>\\n' }}\"\\\n    \"{% endif %}\"\n\n# Ollama from https://www.ollama.com/library/phi3\nphi3_ollama = _ollama_template(\"phi-3\")\n\nphi3_template_eos_token = \"<|end|>\"\nCHAT_TEMPLATES[\"phi-3\"]   = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"phi-3\"] = None # No system message in Phi-3\n\nCHAT_TEMPLATES[\"phi-35\"]  = CHAT_TEMPLATES[\"phi-3\"]\nDEFAULT_SYSTEM_MESSAGE[\"phi-35\"] = None # No system message in Phi-3.5\n\nCHAT_TEMPLATES[\"phi-3.5\"] = CHAT_TEMPLATES[\"phi-3\"]\nDEFAULT_SYSTEM_MESSAGE[\"phi-3.5\"] = None # No system message in Phi-3.5\n\n# =========================================== Llama-3.1\n\"\"\"\nNo trimming in Llama 3.1 Instruct!\nAlso an extra newline for Cutting Knowledge Date\nSee https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing\n\nAlso should be\n\nimport datetime\ntokenizer.apply_chat_template(\n    messages,\n    add_generation_prompt = True,\n    tokenize = False,\n    date_string = datetime.today().strftime(\"%d %B %Y\")),\n)\n\"\"\"\n\nllama31_template = \\\n\"\"\"{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- set date_string = \"26 July 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"{system_message}\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n    {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content'] %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] + '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n            {%- for arg_name, arg_val in tool_call.arguments | items %}\n                {{- arg_name + '=\"' + arg_val + '\"' }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- endif %}\n                {%- endfor %}\n            {{- \")\" }}\n        {%- else  %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n            {{- '\"parameters\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- \"}\" }}\n        {%- endif %}\n        {%- if builtin_tools is defined %}\n            {#- This means we're in ipython mode #}\n            {{- \"<|eom_id|>\" }}\n        {%- else %}\n            {{- \"<|eot_id|>\" }}\n        {%- endif %}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n\"\"\"\n\n# Ollama from https://ollama.com/library/llama3.1 (needs updating!)\nllama31_ollama = _ollama_template(\"llama-3.1\")\n\nllama31_template_eos_token = \"eos_token\"\nCHAT_TEMPLATES[\"llama-3.1\"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"llama-3.1\"] = \"\" # Llama3.1 default system message is empty + the dates\n\nCHAT_TEMPLATES[\"llama-31\"]  = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"llama-31\"] = \"\" # Llama3.1 default system message is empty + the dates\n\nfor version in (\"llama-3.2\", \"llama-3.3\", \"llama-32\", \"llama-33\"):\n    CHAT_TEMPLATES[version] = CHAT_TEMPLATES[\"llama-3.1\"]\n    DEFAULT_SYSTEM_MESSAGE[version] = \"\"\n\n\n# =========================================== Qwen 2.5\nqwen25_template = \\\n\"\"\"{%- if tools %}\n    {{- \\'<|im_start|>system\\\\n\\' }}\n    {%- if messages[0][\\'role\\'] == \\'system\\' %}\n        {{- messages[0][\\'content\\'] }}\n    {%- else %}\n        {{- \\'You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.\\' }}\n    {%- endif %}\n    {{- \"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\"name\\\\\": <function-name>, \\\\\"arguments\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\" }}\\n{%- else %}\n    {%- if messages[0][\\'role\\'] == \\'system\\' %}\n        {{- \\'<|im_start|>system\\\\n\\' + messages[0][\\'content\\'] + \\'<|im_end|>\\\\n\\' }}\n    {%- else %}\n        {{- \\'<|im_start|>system\\\\n{system_message}<|im_end|>\\\\n\\' }}\n    {%- endif %}\\n{%- endif %}\\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- \\'<|im_start|>\\' + message.role + \\'\\\\n\\' + message.content + \\'<|im_end|>\\' + \\'\\\\n\\' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- \\'<|im_start|>\\' + message.role }}\n        {%- if message.content %}\n            {{- \\'\\\\n\\' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- \\'\\\\n<tool_call>\\\\n{\"name\": \"\\' }}\n            {{- tool_call.name }}\n            {{- \\'\", \"arguments\": \\' }}\n            {{- tool_call.arguments | tojson }}\n            {{- \\'}\\\\n</tool_call>\\' }}\n        {%- endfor %}\n        {{- \\'<|im_end|>\\\\n\\' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}            {{- \\'<|im_start|>user\\' }}\n        {%- endif %}\n        {{- \\'\\\\n<tool_response>\\\\n\\' }}\n        {{- message.content }}\n        {{- \\'\\\\n</tool_response>\\' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- \\'<|im_end|>\\\\n\\' }}\n        {%- endif %}\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\n    {{- \\'<|im_start|>assistant\\\\n\\' }}\n{%- endif %}\n\"\"\"\n\n\n# Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78\nqwen25_ollama = _ollama_template(\"qwen-2.5\")\n\nqwen25_template_eos_token = \"eos_token\"\nqwen25_default_system_message = \"You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.\"\nCHAT_TEMPLATES[\"qwen-2.5\"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen-2.5\"] = qwen25_default_system_message # Default Qwen 2.5 system message\n\nCHAT_TEMPLATES[\"qwen-25\"]  = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen-25\"] = qwen25_default_system_message # Default Qwen 2.5 system message\n\nCHAT_TEMPLATES[\"qwen25\"]   = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen25\"] = qwen25_default_system_message # Default Qwen 2.5 system message\n\nCHAT_TEMPLATES[\"qwen2.5\"]  = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen2.5\"] = qwen25_default_system_message # Default Qwen 2.5 system message\n\n# =========================================== Phi-4\n# \"{{ bos_token }}\"\\ # Phi-4 removes BOS?\nphi4_template = \\\n    \"{% for message in messages %}\"\\\n        \"{% if (message['role'] == 'system') %}\"\\\n            \"{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}\"\\\n        \"{% elif (message['role'] == 'user') %}\"\\\n            \"{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}\"\\\n        \"{% elif (message['role'] == 'assistant') %}\"\\\n            \"{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}\"\\\n        \"{% endif %}\"\\\n    \"{% endfor %}\"\\\n    \"{% if add_generation_prompt %}\"\\\n        \"{{ '<|im_start|>assistant<|im_sep|>' }}\"\\\n    \"{% endif %}\"\n\n_phi4_ollama_template = \\\n    \"{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}\"\\\n    \"{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}\"\\\n    \"<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>\"\n\n# Ollama from https://www.ollama.com/library/phi4 is different\nphi4_ollama = _ollama_template(\"phi-4\")\n\nphi4_template_eos_token = \"<|im_end|>\"\nCHAT_TEMPLATES[\"phi-4\"] = (phi4_template, phi4_template_eos_token, False, phi4_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"phi-4\"] = None # No system message in Phi-4\n\n\n# =========================================== Gemma-3\n# Obtained via\n# print(tokenizer.chat_template.replace(\"}\\n\", \"####\").replace(\"\\n\", \"\\\\n\").replace(\"####\", \"}\\n\"))\ngemma3_template = \\\n\"\"\"{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\\n\\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\\n\\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in 
message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{ '<start_of_turn>model\\n' }}\n{%- endif -%}\n\"\"\"\n\n# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802\ngemma3_ollama = _ollama_template(\"gemma-3\")\n\ngemma3_template_eos_token = \"<end_of_turn>\"\nCHAT_TEMPLATES[\"gemma-3\"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma-3\"] = None # No system message in Gemma-3\n\nCHAT_TEMPLATES[\"gemma3\"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma3\"] = None # No system message in Gemma-3\n\n# =========================================== Qwen-3\n# Official Qwen-3 chat template (see https://ollama.com/library/qwen3/blobs/eb4402837c78)\nqwen3_template = \\\n\"\"\"\n{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\\\"name\\\\\": <function-name>, \\\\\"arguments\\\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for forward_message in messages %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- set message = messages[index] %}\n    {%- set current_content = message.content if message.content is not none else '' %}\n    {%- set tool_start = '<tool_response>' %}\n    {%- set tool_start_length = tool_start|length %}\n    {%- set start_of_message = current_content[:tool_start_length] %}\n    {%- set tool_end = '</tool_response>' %}\n    {%- set tool_end_length = tool_end|length %}\n    {%- set start_pos = (current_content|length) - tool_end_length %}\n    {%- if start_pos < 0 %}\n        {%- set start_pos = 0 %}\n    {%- endif %}\n    {%- set end_of_message = current_content[start_pos:] %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(start_of_message == tool_start and end_of_message == tool_end) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and 
message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = (message.content.split('</think>')|last).lstrip('\\n') %}\n                {%- set reasoning_content = (message.content.split('</think>')|first).rstrip('\\n') %}\n                {%- set reasoning_content = (reasoning_content.split('<think>')|last).lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}\n\"\"\"\n\nqwen3_ollama = _ollama_template(\"qwen-3\")\nqwen3_template_eos_token = \"<|im_end|>\"\nCHAT_TEMPLATES[\"qwen-3\"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen-3\"] = None # No default system message for Qwen-3\n\nCHAT_TEMPLATES[\"qwen3\"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"qwen3\"] = None # No default system message for Qwen-3\n\n# =========================================== Gemma-3n\n# Obtained via\n# print(tokenizer.chat_template.replace(\"}\\n\", \"####\").replace(\"\\n\", \"\\\\n\").replace(\"####\", \"}\\n\"))\ngemma3n_template = \\\n\"\"\"{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\\n\\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\\n\\n' -%}\n    {%- 
endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'audio' -%}\n                {{ '<audio_soft_token>' }}\n            {%- elif item['type'] == 'image' -%}\n                {{ '<image_soft_token>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\\n'}}\n{%- endif -%}\n\"\"\"\n\n# Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802\ngemma3n_ollama = _ollama_template(\"gemma-3n\")\ngemma3n_template_eos_token = \"<end_of_turn>\"\nCHAT_TEMPLATES[\"gemma-3n\"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma-3n\"] = None # No system message in Gemma-3n\n\nCHAT_TEMPLATES[\"gemma3n\"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gemma3n\"] = None # No system message in Gemma-3n\n\n# =========================================== GPT-OSS\n# Obtained via\n# print(tokenizer.chat_template.replace(\"}\\n\", \"####\").replace(\"\\n\", \"\\\\n\").replace(\"####\", \"}\\n\"))\ngptoss_template = \\\n\"\"\"{#-\n  In addition to the normal inputs of `messages` and `tools`, this template also accepts the\n  following kwargs:\n  - \"builtin_tools\": A list, can contain \"browser\" and/or \"python\".\n  - \"model_identity\": A string that optionally describes the model identity.\n  - \"reasoning_effort\": A string that describes the reasoning effort, defaults to \"medium\".\n #}\n\n{#- Tool Definition Rendering ============================================== #}\n{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}\n    {%- if param_spec.type == \"array\" -%}\n        {%- if param_spec['items'] -%}\n            {%- if param_spec['items']['type'] == \"string\" -%}\n                {{- \"string[]\" }}\n            {%- elif param_spec['items']['type'] == \"number\" -%}\n                {{- \"number[]\" }}\n            {%- elif param_spec['items']['type'] == \"integer\" -%}\n                {{- \"number[]\" }}\n            {%- elif param_spec['items']['type'] == \"boolean\" -%}\n                {{- \"boolean[]\" }}\n            {%- else -%}\n                {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}\n                {%- if inner_type == \"object | object\" or inner_type|length > 50 -%}\n                    {{- \"any[]\" }}\n                {%- else -%}\n                    {{- inner_type + \"[]\" }}\n                {%- 
endif -%}\n            {%- endif -%}\n            {%- if param_spec.nullable -%}\n                {{- \" | null\" }}\n            {%- endif -%}\n        {%- else -%}\n            {{- \"any[]\" }}\n            {%- if param_spec.nullable -%}\n                {{- \" | null\" }}\n            {%- endif -%}\n        {%- endif -%}\n    {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}\n        {#- Handle array of types like [\"object\", \"object\"] from Union[dict, list] #}\n        {%- if param_spec.type | length > 1 -%}\n            {{- param_spec.type | join(\" | \") }}\n        {%- else -%}\n            {{- param_spec.type[0] }}\n        {%- endif -%}\n    {%- elif param_spec.oneOf -%}\n        {#- Handle oneOf schemas - check for complex unions and fallback to any #}\n        {%- set has_object_variants = false -%}\n        {%- for variant in param_spec.oneOf -%}\n            {%- if variant.type == \"object\" -%}\n                {%- set has_object_variants = true -%}\n            {%- endif -%}\n        {%- endfor -%}\n        {%- if has_object_variants and param_spec.oneOf|length > 1 -%}\n            {{- \"any\" }}\n        {%- else -%}\n            {%- for variant in param_spec.oneOf -%}\n                {{- render_typescript_type(variant, required_params) -}}\n                {%- if variant.description %}\n                    {{- \"// \" + variant.description }}\n                {%- endif -%}\n                {%- if variant.default is defined %}\n                    {{ \"// default: \" + variant.default|tojson }}\n                {%- endif -%}\n                {%- if not loop.last %}\n                    {{- \" | \" }}\n                {% endif -%}\n            {%- endfor -%}\n        {%- endif -%}\n    {%- elif param_spec.type == \"string\" -%}\n        {%- if param_spec.enum -%}\n            {{- '\"' + param_spec.enum|join('\" | \"') + '\"' -}}\n        {%- else -%}\n            {{- \"string\" }}\n            {%- if param_spec.nullable %}\n                {{- \" | null\" }}\n            {%- endif -%}\n        {%- endif -%}\n    {%- elif param_spec.type == \"number\" -%}\n        {{- \"number\" }}\n    {%- elif param_spec.type == \"integer\" -%}\n        {{- \"number\" }}\n    {%- elif param_spec.type == \"boolean\" -%}\n        {{- \"boolean\" }}\n\n    {%- elif param_spec.type == \"object\" -%}\n        {%- if param_spec.properties -%}\n            {{- \"{\\n\" }}\n            {%- for prop_name, prop_spec in param_spec.properties.items() -%}\n                {{- prop_name -}}\n                {%- if prop_name not in (param_spec.required or []) -%}\n                    {{- \"?\" }}\n                {%- endif -%}\n                {{- \": \" }}\n                {{ render_typescript_type(prop_spec, param_spec.required or []) }}\n                {%- if not loop.last -%}\n                    {{-\", \" }}\n                {%- endif -%}\n            {%- endfor -%}\n            {{- \"}\" }}\n        {%- else -%}\n            {{- \"object\" }}\n        {%- endif -%}\n    {%- else -%}\n        {{- \"any\" }}\n    {%- endif -%}\n{%- endmacro -%}\n\n{%- macro render_tool_namespace(namespace_name, tools) -%}\n    {{- \"## \" + namespace_name + \"\\n\\n\" }}\n    {{- \"namespace \" + namespace_name + \" {\\n\\n\" }}\n    {%- for tool in tools %}\n        {%- set tool = tool.function %}\n        {{- \"// \" + tool.description + \"\\n\" }}\n        {{- \"type \"+ tool.name + \" = 
\" }}\n        {%- if tool.parameters and tool.parameters.properties %}\n            {{- \"(_: {\\n\" }}\n            {%- for param_name, param_spec in tool.parameters.properties.items() %}\n                {%- if param_spec.description %}\n                    {{- \"// \" + param_spec.description + \"\\n\" }}\n                {%- endif %}\n                {{- param_name }}\n                {%- if param_name not in (tool.parameters.required or []) -%}\n                    {{- \"?\" }}\n                {%- endif -%}\n                {{- \": \" }}\n                {{- render_typescript_type(param_spec, tool.parameters.required or []) }}\n                {%- if param_spec.default is defined -%}\n                    {%- if param_spec.enum %}\n                        {{- \", // default: \" + param_spec.default }}\n                    {%- elif param_spec.oneOf %}\n                        {{- \"// default: \" + param_spec.default }}\n                    {%- else %}\n                        {{- \", // default: \" + param_spec.default|tojson }}\n                    {%- endif -%}\n                {%- endif -%}\n                {%- if not loop.last %}\n                    {{- \",\\n\" }}\n                {%- else %}\n                    {{- \",\\n\" }}\n                {%- endif -%}\n            {%- endfor %}\n            {{- \"}) => any;\\n\\n\" }}\n        {%- else -%}\n            {{- \"() => any;\\n\\n\" }}\n        {%- endif -%}\n    {%- endfor %}\n    {{- \"} // namespace \" + namespace_name }}\n{%- endmacro -%}\n\n{%- macro render_builtin_tools(browser_tool, python_tool) -%}\n    {%- if browser_tool %}\n        {{- \"## browser\\n\\n\" }}\n        {{- \"// Tool for browsing.\\n\" }}\n        {{- \"// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\\n\" }}\n        {{- \"// Cite information from the tool using the following format:\\n\" }}\n        {{- \"// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\\n\" }}\n        {{- \"// Do not quote more than 10 words directly from the tool output.\\n\" }}\n        {{- \"// sources=web (default: web)\\n\" }}\n        {{- \"namespace browser {\\n\\n\" }}\n        {{- \"// Searches for information related to `query` and displays `topn` results.\\n\" }}\n        {{- \"type search = (_: {\\n\" }}\n        {{- \"query: string,\\n\" }}\n        {{- \"topn?: number, // default: 10\\n\" }}\n        {{- \"source?: string,\\n\" }}\n        {{- \"}) => any;\\n\\n\" }}\n        {{- \"// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\\n\" }}\n        {{- \"// Valid link ids are displayed with the formatting: `【{id}†.*】`.\\n\" }}\n        {{- \"// If `cursor` is not provided, the most recent page is implied.\\n\" }}\n        {{- \"// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\\n\" }}\n        {{- \"// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\\n\" }}\n        {{- \"// Use this function without `id` to scroll to a new location of an opened page.\\n\" }}\n        {{- \"type open = (_: {\\n\" }}\n        {{- \"id?: number | string, // default: -1\\n\" }}\n        {{- \"cursor?: number, // default: -1\\n\" }}\n        {{- \"loc?: number, // default: -1\\n\" }}\n        {{- \"num_lines?: number, // default: -1\\n\" }}\n        {{- \"view_source?: boolean, // default: false\\n\" }}\n        {{- 
\"source?: string,\\n\" }}\n        {{- \"}) => any;\\n\\n\" }}\n        {{- \"// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\\n\" }}\n        {{- \"type find = (_: {\\n\" }}\n        {{- \"pattern: string,\\n\" }}\n        {{- \"cursor?: number, // default: -1\\n\" }}\n        {{- \"}) => any;\\n\\n\" }}\n        {{- \"} // namespace browser\\n\\n\" }}\n    {%- endif -%}\n\n    {%- if python_tool %}\n        {{- \"## python\\n\\n\" }}\n        {{- \"Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\\n\\n\" }}\n        {{- \"When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\\n\\n\" }}\n    {%- endif -%}\n{%- endmacro -%}\n\n{#- System Message Construction ============================================ #}\n{%- macro build_system_message() -%}\n    {%- if model_identity is not defined %}\n        {%- set model_identity = \"You are ChatGPT, a large language model trained by OpenAI.\" %}\n    {%- endif %}\n    {{- model_identity + \"\\n\" }}\n    {{- \"Knowledge cutoff: 2024-06\\n\" }}\n    {{- \"Current date: \" + strftime_now(\"%Y-%m-%d\") + \"\\n\\n\" }}\n    {%- if reasoning_effort is not defined %}\n        {%- set reasoning_effort = \"medium\" %}\n    {%- endif %}\n    {{- \"Reasoning: \" + reasoning_effort + \"\\n\\n\" }}\n    {%- if builtin_tools is defined and builtin_tools is not none %}\n        {{- \"# Tools\\n\\n\" }}\n        {%- set available_builtin_tools = namespace(browser=false, python=false) %}\n        {%- for tool in builtin_tools %}\n            {%- if tool == \"browser\" %}\n                {%- set available_builtin_tools.browser = true %}\n            {%- elif tool == \"python\" %}\n                {%- set available_builtin_tools.python = true %}\n            {%- endif %}\n        {%- endfor %}\n        {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}\n    {%- endif -%}\n    {{- \"# Valid channels: analysis, commentary, final. 
Channel must be included for every message.\" }}\n    {%- if tools -%}\n        {{- \"\\nCalls to these tools must go to the commentary channel: 'functions'.\" }}\n    {%- endif -%}\n{%- endmacro -%}\n\n{#- Main Template Logic ================================================= #}\n{#- Set defaults #}\n\n{#- Render system message #}\n{{- \"<|start|>system<|message|>\" }}\n{{- build_system_message() }}\n{{- \"<|end|>\" }}\n\n{#- Extract developer message #}\n{%- if developer_instructions is defined and developer_instructions is not none %}\n    {%- set developer_message = developer_instructions %}\n    {%- set loop_messages = messages %}\n{%- elif messages[0].role == \"developer\" or messages[0].role == \"system\" %}\n    {%- set developer_message = messages[0].content %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set developer_message = \"\" %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{#- Render developer message #}\n{%- if developer_message or tools %}\n    {{- \"<|start|>developer<|message|>\" }}\n    {%- if developer_message %}\n        {{- \"# Instructions\\n\\n\" }}\n        {{- developer_message }}\n    {%- endif %}\n    {%- if tools -%}\n        {%- if developer_message %}\n            {{- \"\\n\\n\" }}\n        {%- endif %}\n        {{- \"# Tools\\n\\n\" }}\n        {{- render_tool_namespace(\"functions\", tools) }}\n    {%- endif -%}\n    {{- \"<|end|>\" }}\n{%- endif %}\n\n{#- Render messages #}\n{%- set last_tool_call = namespace(name=none) %}\n{%- for message in loop_messages -%}\n    {#- At this point only assistant/user/tool messages should remain #}\n    {%- if message.role == 'assistant' -%}\n        {#- Checks to ensure the messages are being passed in the format we expect #}\n        {%- if \"content\" in message %}\n            {%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}\n                {{- raise_exception(\"You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}\n            {%- endif %}\n        {%- endif %}\n        {%- if \"thinking\" in message %}\n            {%- if \"<|channel|>analysis<|message|>\" in message.thinking or \"<|channel|>final<|message|>\" in message.thinking %}\n                {{- raise_exception(\"You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}\n            {%- endif %}\n        {%- endif %}\n        {%- if \"tool_calls\" in message %}\n            {#- We need very careful handling here - we want to drop the tool call analysis message if the model #}\n            {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #}\n            {#- when we render CoT/analysis messages in inference. 
#}\n            {%- set future_final_message = namespace(found=false) %}\n            {%- for future_message in loop_messages[loop.index:] %}\n                {%- if future_message.role == 'assistant' and \"tool_calls\" not in future_message %}\n                    {%- set future_final_message.found = true %}\n                {%- endif %}\n            {%- endfor %}\n            {#- We assume max 1 tool call per message, and so we infer the tool call name #}\n            {#- in \"tool\" messages from the most recent assistant tool call name #}\n            {%- set tool_call = message.tool_calls[0] %}\n            {%- if tool_call.function %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {%- if message.content and message.thinking %}\n                {{- raise_exception(\"Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.\") }}\n            {%- elif message.content and not future_final_message.found %}\n                {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.content + \"<|end|>\" }}\n            {%- elif message.thinking and not future_final_message.found %}\n                {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}\n            {%- endif %}\n            {{- \"<|start|>assistant to=\" }}\n            {{- \"functions.\" + tool_call.name + \"<|channel|>commentary \" }}\n            {{- (tool_call.content_type if tool_call.content_type is defined else \"json\") + \"<|message|>\" }}\n            {%- if tool_call.arguments is string %}\n                {{- tool_call.arguments }}\n            {%- else %}\n                {{- tool_call.arguments|tojson }}\n            {%- endif %}\n            {{- \"<|call|>\" }}\n            {%- set last_tool_call.name = tool_call.name %}\n        {%- elif loop.last and not add_generation_prompt %}\n            {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}\n            {#- This is a situation that should only occur in training, never in inference. #}\n            {%- if \"thinking\" in message %}\n                {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}\n            {%- endif %}\n            {#- <|return|> indicates the end of generation, but <|end|> does not #}\n            {#- <|return|> should never be an input to the model, but we include it as the final token #}\n            {#- when training, so the model learns to emit it. 
#}\n            {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|end|>\" }}\n        {%- elif \"thinking\" in message %}\n            {#- CoT is dropped during all previous turns, so we never render it for inference #}\n            {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.content + \"<|end|>\" }}\n            {%- set last_tool_call.name = none %}\n        {%- else %}\n            {#- CoT is dropped during all previous turns, so we never render it for inference #}\n            {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|end|>\" }}\n            {%- set last_tool_call.name = none %}\n        {%- endif %}\n    {%- elif message.role == 'tool' -%}\n        {%- if last_tool_call.name is none %}\n            {{- raise_exception(\"Message has tool role, but there was no previous assistant message with a tool call!\") }}\n        {%- endif %}\n        {{- \"<|start|>functions.\" + last_tool_call.name }}\n        {%- if message.content is string %}\n            {{- \" to=assistant<|channel|>commentary<|message|>\" + message.content + \"<|end|>\" }}\n        {%- else %}\n            {{- \" to=assistant<|channel|>commentary<|message|>\" + message.content|tojson + \"<|end|>\" }}\n        {%- endif %}\n    {%- elif message.role == 'user' -%}\n        {{- \"<|start|>user<|message|>\" + message.content + \"<|end|>\" }}\n    {%- endif -%}\n{%- endfor -%}\n\n{#- Generation prompt #}\n{%- if add_generation_prompt -%}\n<|start|>assistant\n{%- endif -%}\"\"\"\n\n# Ollama from https://ollama.com/library/gpt-oss\ngptoss_ollama = \\\n'''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: {{ currentDate }}\n{{- if and .IsThinkSet .Think (ne .ThinkLevel \"\") }}\n\nReasoning: {{ .ThinkLevel }}\n{{- else if or (not .IsThinkSet) (and .IsThinkSet .Think) }}\n\nReasoning: medium\n{{- end }}\n\n{{- $hasNonBuiltinTools := false }}\n{{- if .Tools -}}\n{{- $hasBrowserSearch := false }}\n{{- $hasBrowserOpen := false }}\n{{- $hasBrowserFind := false }}\n{{- $hasPython := false }}\n  {{- range .Tools }}\n    {{- if eq .Function.Name \"browser.search\" -}}{{- $hasBrowserSearch = true -}}\n    {{- else if eq .Function.Name \"browser.open\" -}}{{- $hasBrowserOpen = true -}}\n    {{- else if eq .Function.Name \"browser.find\" -}}{{- $hasBrowserFind = true -}}\n    {{- else if eq .Function.Name \"python\" -}}{{- $hasPython = true -}}\n    {{- else }}{{ $hasNonBuiltinTools = true -}}\n    {{- end }}\n  {{- end }}\n{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind $hasPython }}\n\n# Tools\n{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind }}\n\n## browser\n\n// Tool for browsing.\n// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n// Cite information from the tool using the following format:\n// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n// Do not quote more than 10 words directly from the tool output.\n// sources=web (default: web)\nnamespace browser {\n{{- if $hasBrowserSearch }}\n\n// Searches for information related to `query` and displays `topn` results.\ntype search = (_: {\nquery: string,\ntopn?: number, // default: 10\nsource?: string,\n}) => any;\n{{- end }}\n{{- if $hasBrowserOpen }}\n\n// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n// Valid link ids are 
displayed with the formatting: `【{id}†.*】`.\n// If `cursor` is not provided, the most recent page is implied.\n// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n// Use this function without `id` to scroll to a new location of an opened page.\ntype open = (_: {\nid?: number | string, // default: -1\ncursor?: number, // default: -1\nloc?: number, // default: -1\nnum_lines?: number, // default: -1\nview_source?: boolean, // default: false\nsource?: string,\n}) => any;\n{{- end }}\n{{- if $hasBrowserFind }}\n\n// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\ntype find = (_: {\npattern: string,\ncursor?: number, // default: -1\n}) => any;\n{{- end }}\n\n} // namespace browser\n{{- end }}{{/* end if has browser tools */}}\n{{- if $hasPython }}\n\n## python\n\nUse this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n{{- end }}{{/* end if hasPython */}}\n{{- end }}{{/* end if has any built-in tools */}}\n{{- end }}{{/* end if .Tools */}}\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.{{ if $hasNonBuiltinTools }}\nCalls to these tools must go to the commentary channel: 'functions'.\n{{- end -}}<|end|>{{/* end of system */ -}}\n{{- if or $hasNonBuiltinTools .System -}}\n<|start|>developer<|message|>{{- if $hasNonBuiltinTools }}# Tools\n\n## functions\n\nnamespace functions {\n{{- range .Tools }}\n{{- if not (or (eq .Function.Name \"browser.search\") (eq .Function.Name \"browser.open\") (eq .Function.Name \"browser.find\") (eq .Function.Name \"python\")) }}\n{{if .Function.Description }}\n// {{ .Function.Description }}\n{{- end }}\n{{- if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0) }}\ntype {{ .Function.Name }} = (_: {\n{{- range $name, $prop := .Function.Parameters.Properties }}\n{{- if $prop.Description }}\n  // {{ $prop.Description }}\n{{- end }}\n  {{ $name }}: {{ if gt (len $prop.Type) 1 }}{{ range $i, $t := $prop.Type }}{{ if $i }} | {{ end }}{{ $t }}{{ end }}{{ else }}{{ index $prop.Type 0 }}{{ end }},\n{{- end }}\n}) => any;\n{{- else }}\ntype {{ .Function.Name }} = () => any;\n{{- end }}\n{{- end }}{{/* end if not browser tool */}}\n{{- end }}{{/* end of range .Tools */}}\n\n} // namespace functions\n{{- end }}{{/* end if hasNonBuiltinTools */}}\n{{- if .System}}\n\n# Instructions\n\n{{ .System }}\n{{- end -}}\n<|end|>\n{{- end -}}\n{{- /* Find the index of the last user message */ -}}\n{{- $lastUserIdx := -1 }}\n{{- $prefillingContent := false }}\n{{- $prefillingThinkingOnly := false }}\n{{- range $i, $msg := .Messages }}\n  {{- $last := eq (len (slice $.Messages $i)) 1 -}}\n  {{- if eq $msg.Role \"user\" }}\n    {{- $lastUserIdx = $i }}\n  {{- end -}}\n  {{- if and $last (eq $msg.Role \"assistant\") (gt (len $msg.Content) 0) }}\n    {{- 
$prefillingContent = true }}\n  {{- else if and $last (eq $msg.Role \"assistant\") (gt (len $msg.Thinking) 0) }}\n    {{- $prefillingThinkingOnly = true }}\n  {{- end }}\n{{- end -}}\n{{- /* Now render messages */ -}}\n{{- range $i, $msg := .Messages }}\n  {{- $last := eq (len (slice $.Messages $i)) 1 -}}\n  {{- if (ne $msg.Role \"system\") -}}\n    {{- if eq $msg.Role \"tool\" -}}\n      {{- if or (eq $msg.ToolName \"python\") (eq $msg.ToolName \"browser.search\") (eq $msg.ToolName \"browser.open\") (eq $msg.ToolName \"browser.find\") -}}\n        <|start|>{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>\n      {{- else -}}\n        <|start|>functions.{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>\n      {{- end -}}\n    {{- else if eq $msg.Role \"assistant\" -}}\n      {{- if and $msg.Thinking (gt $i $lastUserIdx) -}}{{- /* Show thinking only after last user message */ -}}\n      <|start|>assistant<|channel|>analysis<|message|>{{ $msg.Thinking }}{{- if not $prefillingThinkingOnly -}}<|end|>{{- end -}}\n      {{- end -}}\n      {{- if gt (len $msg.Content) 0 -}}\n        <|start|>assistant<|channel|>final<|message|>{{ $msg.Content }}{{- if not $prefillingContent -}}<|end|>{{- end -}}\n      {{- end -}}\n      {{- if gt (len $msg.ToolCalls) 0 -}}\n        {{- range $j, $toolCall := $msg.ToolCalls -}}\n          {{- $isBuiltin := or (eq $toolCall.Function.Name \"python\") (eq $toolCall.Function.Name \"browser.search\") (eq $toolCall.Function.Name \"browser.open\") (eq $toolCall.Function.Name \"browser.find\") -}}\n          <|start|>assistant<|channel|>{{ if $isBuiltin }}analysis{{ else }}commentary{{ end }} to={{ if not $isBuiltin}}functions.{{end}}{{ $toolCall.Function.Name }} <|constrain|>json<|message|>{{ $toolCall.Function.Arguments }}<|call|>\n        {{- end -}}\n      {{- end -}}\n    {{- else if eq $msg.Role \"user\" -}}\n      <|start|>{{ $msg.Role }}<|message|>{{ $msg.Content }}<|end|>\n    {{- end }}\n  {{- else }}\n  {{- end }}\n{{- end -}}\n{{- if not (or $prefillingContent $prefillingThinkingOnly) -}}\n<|start|>assistant\n{{- end -}}\"\"\"\nPARAMETER temperature 1.0\nPARAMETER top_k 0\nPARAMETER top_p 1.0\n'''\n\ngptoss_template_template_eos_token = \"<|return|>\"\nCHAT_TEMPLATES[\"gpt-oss\"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gpt-oss\"] = None # No system message in GPT-oss\n\nCHAT_TEMPLATES[\"gptoss\"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)\nDEFAULT_SYSTEM_MESSAGE[\"gptoss\"] = None # No system message in GPT-oss\n\n# =========================================== Qwen3-Instruct\nqwen3_instruct_template = \\\n'''{%- if tools %}\n    {{- '<|im_start|>system\\\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\\\n\\\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\"name\\\\\": <function-name>, \\\\\"arguments\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- 
'<|im_start|>system\\\\n' + messages[0].content + '<|im_end|>\\\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\\\n' + content + '<|im_end|>' + '\\\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\\\n').split('<think>')[-1].lstrip('\\\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if reasoning_content %}\n                {{- '<|im_start|>' + message.role + '\\\\n<think>\\\\n' + reasoning_content.strip('\\\\n') + '\\\\n</think>\\\\n\\\\n' + content.lstrip('\\\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\\\n<tool_response>\\\\n' }}\n        {{- content }}\n        {{- '\\\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\\\n' }}\n{%- endif %}'''\n\nqwen3_template_eos_token = \"<|im_end|>\"\nCHAT_TEMPLATES[\"qwen3-instruct\"] = (qwen3_instruct_template, 
qwen3_template_eos_token, False, _ollama_template(\"qwen3-instruct\"),)\nDEFAULT_SYSTEM_MESSAGE[\"qwen3-instruct\"] = None # No system message in Qwen3\n\n\n# =========================================== Qwen3-Thinking\nqwen3_thinking_template = \\\n'''{%- if tools %}\n    {{- '<|im_start|>system\\\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\\\n\\\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\"name\\\\\": <function-name>, \\\\\"arguments\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\\\n' + messages[0].content + '<|im_end|>\\\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\\\n' + content + '<|im_end|>' + '\\\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\\\n').split('<think>')[-1].lstrip('\\\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\\\n<think>\\\\n' + reasoning_content.strip('\\\\n') + '\\\\n</think>\\\\n\\\\n' + content.lstrip('\\\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\\\n{\"name\": 
\"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\\\n<tool_response>\\\\n' }}\n        {{- content }}\n        {{- '\\\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n<think>\\n' }}\n{%- endif %}'''\n\nCHAT_TEMPLATES[\"qwen3-thinking\"] = (\n    qwen3_thinking_template,\n    qwen3_template_eos_token,\n    False,\n    _ollama_template(\"qwen3-thinking\"),\n)\nDEFAULT_SYSTEM_MESSAGE[\"qwen3-thinking\"] = None # No system message in Qwen3\n\n\n# =========================================== Liquid-LFM2\nliquid_lfm2_template = \\\n'''\n{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}'''\n\nliquid_lfm2_template_eos_token = \"<|im_end|>\"\nCHAT_TEMPLATES[\"lfm-2\"] = (liquid_lfm2_template, liquid_lfm2_template_eos_token, False, None)\nDEFAULT_SYSTEM_MESSAGE[\"lfm-2\"] = None # No system message in Phi-3\n\n\n# =========================================== Starling-LM\n\nstarling_template = \\\n\"\"\"{{ bos_token }}\n{%- for message in messages %}\n    {{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{ 'GPT4 Correct Assistant:' }}\n{%- endif %}\"\"\"\n\n# Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4\nstarling_ollama = _ollama_template(\"starling\")\n\nstarling_template_eos_token = \"<|end_of_turn|>\"\nCHAT_TEMPLATES[\"starling\"] = (starling_template, starling_template_eos_token, False, starling_ollama)\nDEFAULT_SYSTEM_MESSAGE[\"starling\"] = None\n\n\n# =========================================== Yi-chat\n\nyi_chat_template = \\\n\"\"\"\n{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}\n\"\"\"\n\n# Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093\nyi_chat_ollama = _ollama_template(\"yi-chat\")\n\nyi_chat_template_eos_token = \"<|endoftext|>\"\nCHAT_TEMPLATES[\"yi-chat\"] = (yi_chat_template, yi_chat_template_eos_token, False, yi_chat_ollama)\nDEFAULT_SYSTEM_MESSAGE[\"yi-chat\"] = None\n\ndef _change_system_message(template: str, type_chat_template: str, system_message: str = None):\n    system_message_pattern = r\"\\{system_message\\}\"\n\n    # For predefined templates, check if default system message exists\n    default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f\"{type_chat_template}\", None)\n    if default_system_message is None:\n        if system_message is 
not None:\n            logger.warning_once(\n                f\"Unsloth: You tried to change the system message for {type_chat_template}, \"\n                \"but it doesn't have a default system message. \"\n                \"You need to manually add the system message in your data.\"\n            )\n        return template, system_message\n\n    # For custom templates\n    if type_chat_template is None:\n        has_placeholder = re.search(system_message_pattern, template) is not None\n\n        if has_placeholder:\n            if system_message is None:\n                raise ValueError(\"Unsloth: You need to provide a system message for custom templates.\")\n            new_template = re.sub(system_message_pattern, system_message, template)\n            return new_template, system_message\n\n        return template, system_message\n\n    # For predefined templates with default system message\n    message_to_use = system_message if system_message is not None else default_system_message\n    new_template = re.sub(system_message_pattern, message_to_use, template)\n\n    return new_template, message_to_use\n\n\ndef get_chat_template(\n    tokenizer,\n    chat_template = \"chatml\",\n    mapping = {\"role\" : \"role\", \"content\" : \"content\", \"user\" : \"user\", \"assistant\" : \"assistant\"},\n    map_eos_token = True,\n    system_message = None,\n):\n    assert(type(map_eos_token) is bool)\n    old_tokenizer = tokenizer\n\n    IS_GEMMA = False\n    if tokenizer.__class__.__name__.startswith(\"Gemma\"):\n        if chat_template == \"chatml\": chat_template = \"gemma_chatml\"\n        IS_GEMMA = True\n\n    # We add a check for Llama-3\n    # if chat_template == \"llama-3\":\n    #     tokenizer._using_llama3_template = True\n    # else:\n    #     llama3_tokens = set([\"<|end_header_id|>\", \"<|eot_id|>\", \"<|start_header_id|>\"])\n    #     check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())\n    #     if len(check_llama3_tokens) == len(llama3_tokens):\n    #         tokenizer._using_llama3_template = True\n    #     pass\n    # pass\n\n    # We first check if the tokenizer is a fast one. 
If not, we cannot convert this!\n    is_fast_tokenizer = getattr(tokenizer, \"is_fast\", False)\n    old_padding_side = tokenizer.padding_side\n\n    same_padding_token = False\n    type_chat_template = None\n\n    if type(chat_template) in (list, tuple,):\n        # For changing system message later\n        # Since it's not supported yet, we will raise an error first!\n        type_chat_template = chat_template[0].lower()\n        chat_template, stop_word = chat_template\n        assert(type(chat_template) is str)\n        assert(type(stop_word) is str)\n        ollama_modelfile = None\n\n    elif type(chat_template) is str:\n        # For changing system message later\n        type_chat_template = chat_template.lower()\n\n        chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]\n\n        # Check mapping to eos_token\n        if not map_eos_token and yes_map_eos_token: map_eos_token = True\n        if not yes_map_eos_token and map_eos_token: map_eos_token = False\n\n        if type(stop_word) in (list, tuple,):\n            token_mapping, stop_word = stop_word\n            assert(type(token_mapping) is dict)\n        else:\n            token_mapping = None\n\n        assert(type(stop_word) is str)\n\n        # Check fast tokenizer\n        if not is_fast_tokenizer:\n            pass\n            # print(\n            #     \"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\\n\"\\\n            #     \"Please log a Github issue if you want this as a new feature!\\n\"\\\n            #     \"Your chat template will still work, but it won't add or edit tokens.\"\n            # )\n\n        elif token_mapping is not None:\n            # token_mapping = {\"<start_of_turn>\" : \"<|im_start|>\", \"<end_of_turn>\" : \"<|im_end|>\"}\n            # For Gemma :)\n\n            string_vocab = tokenizer._tokenizer.to_str()\n\n            skipped = 0\n            for old_token, new_token in token_mapping.items():\n                old_count = string_vocab.count(f'\"{old_token}\"')\n                new_count = string_vocab.count(f'\"{new_token}\"')\n                if new_count != 0:\n                    print(f\"{new_token} is already a token. Skipping.\")\n                    skipped += 1\n                elif old_count == 0:\n                    raise RuntimeError(f\"{old_token} was not part of the tokenizer!\")\n                else:\n                    string_vocab = string_vocab.replace(f'\"{old_token}\"', f'\"{new_token}\"')\n                pass\n            pass\n\n            if map_eos_token and (not stop_word in token_mapping.values()):\n                # Do not map 107 = <|im_end|> and 1 = <|im_end|>. 
This will reduce the vocab size by 1\n                logger.warning_once(f\"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.\")\n                string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)\n            pass\n\n            if skipped != len(token_mapping):\n                new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)\n\n                # Careful on pad_token\n                old_pad_token = tokenizer.pad_token\n                if old_pad_token == tokenizer.eos_token:\n                    old_pad_token = stop_word\n                    same_padding_token = True\n                pass\n\n                if map_eos_token:\n                    new_tokenizer = tokenizer.__class__(\n                        tokenizer_object = new_tokenizer,\n                        eos_token = stop_word,\n                        pad_token = old_pad_token,\n                    )\n                else:\n                    new_tokenizer = tokenizer.__class__(\n                        tokenizer_object = new_tokenizer,\n                        pad_token = old_pad_token,\n                    )\n                pass\n\n                # Must fix the sentence piece tokenizer since there's no tokenizer.model file!\n                tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)\n            else:\n                pass\n\n        elif map_eos_token and (stop_word != \"eos_token\"):\n            logger.warning_once(f\"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.\")\n\n            # Replaces the old EOS token with a new one.\n            # Useful for ChatML <|im_end|> for example.\n            # Usually we train 2 more tokens <|im_start|> and <|im_end|>\n            # But training the lm_head and embeddings are slow!\n            # This is a HACK!\n            # Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser\n\n            old_bos_token = getattr(tokenizer, \"bos_token\", None)\n            old_eos_token = getattr(tokenizer, \"eos_token\", None)\n            old_pad_token = getattr(tokenizer, \"pad_token\", None)\n            old_unk_token = getattr(tokenizer, \"unk_token\", None)\n\n            string_vocab = tokenizer._tokenizer.to_str()\n            # First check if new stop_word is in the tokenizer\n            if stop_word in string_vocab:\n                # We shall swap them around\n                temporary_stop_token = \"<|:__TEMP//STOP//TOKEN__:|>\"\n                string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)\n                string_vocab = string_vocab.replace(stop_word, old_eos_token)\n                string_vocab = string_vocab.replace(temporary_stop_token, stop_word)\n            else:\n                string_vocab = string_vocab.replace(old_eos_token, stop_word)\n            pass\n            new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)\n\n            # Careful on pad_token\n            if old_pad_token == old_eos_token:\n                old_pad_token = stop_word\n                same_padding_token = True\n            pass\n\n            new_tokenizer = tokenizer.__class__(\n                tokenizer_object = new_tokenizer,\n                bos_token = old_bos_token,\n                eos_token = stop_word,\n                unk_token = old_unk_token,\n                pad_token = old_pad_token,\n            )\n\n            # Must fix the sentence piece tokenizer since there's no tokenizer.model file!\n            token_mapping 
= { old_eos_token : stop_word, }\n            tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)\n        pass\n\n    else:\n        raise TypeError(\n            f\"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\\n\"\\\n            f\"{CHAT_TEMPLATES.keys()}\"\n        )\n\n    # Careful on Gemma\n    # bos_token is a must or else losses become too high\n    if IS_GEMMA and not chat_template.startswith((\"{{ bos_token }}\", \"{{- bos_token }}\")):\n        chat_template = \"{{ bos_token }}\" + chat_template\n\n    # For ShareGPT role -> from and content -> value\n    new_chat_template = chat_template\\\n        .replace(\"'role'\",      \"'\" + mapping[\"role\"]      + \"'\")\\\n        .replace(\"'content'\",   \"'\" + mapping[\"content\"]   + \"'\")\\\n        .replace(\"'user'\",      \"'\" + mapping[\"user\"]      + \"'\")\\\n        .replace(\"'assistant'\", \"'\" + mapping[\"assistant\"] + \"'\")\n\n    _, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)\n    tokenizer.padding_side = old_padding_side\n\n    # If not normal HF, we add a check to make old templates work\n    if mapping != {\"role\" : \"role\", \"content\" : \"content\", \"user\" : \"user\", \"assistant\" : \"assistant\"}:\n        chat_template = \\\n            \"{% if 'role' in messages[0] %}\" + \\\n            chat_template + \\\n            \"{% else %}\" + \\\n            new_chat_template + \\\n            \"{% endif %}\"\n    else:\n        chat_template = new_chat_template\n\n    chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)\n\n    tokenizer.chat_template = chat_template\n\n    # Also fix up other tokens\n    old_pad_token = getattr(old_tokenizer, \"pad_token\", None)\n    old_bos_token = getattr(old_tokenizer, \"bos_token\", None)\n    old_unk_token = getattr(old_tokenizer, \"unk_token\", None)\n    new_pad_token = getattr(tokenizer,     \"pad_token\", None)\n    new_bos_token = getattr(tokenizer,     \"bos_token\", None)\n    new_unk_token = getattr(tokenizer,     \"unk_token\", None)\n    if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token\n    if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token\n    if not same_padding_token:\n        if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token\n\n    # stopping_criteria = create_stopping_criteria(tokenizer, stop_word)\n\n    # Patch saving functions\n    tokenizer = patch_saving_functions(tokenizer)\n\n    # Add Ollama\n    tokenizer._ollama_modelfile = ollama_modelfile\n    tokenizer._system_message   = system_message\n    return tokenizer#, stopping_criteria\n\n\ndef remove_special_tokens(tokenizer, prompt):\n    # Removes double BOS token\n    if prompt.startswith(tokenizer.bos_token):\n        prompt = prompt[len(tokenizer.bos_token):]\n    return prompt\n\n\ndef _parse_combined_prompt(combined_prompt, dataset):\n    # Find {...}\n    possible_columns = re.findall(r\"\\{(.+?)\\}\", combined_prompt)\n    dataset_columns = set(dataset.column_names)\n    for column in possible_columns:\n        if column not in dataset_columns:\n            raise KeyError(\n                f\"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. 
\"\\\n                f\"Only allowed columns are {list(dataset_columns)}\"\n            )\n\n    # Find [[...]]\n    optional_prompts = list(re.finditer(r\"\\[\\[.+?\\]\\]\", combined_prompt, flags = re.DOTALL | re.MULTILINE))\n    optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]\n\n    final_optional_prompts = []\n    if len(optional_prompts) != 0:\n        # Add left\n        left = optional_prompts[0]\n        l = left[0][0]\n        if l != 0: final_optional_prompts.append(combined_prompt[:l])\n\n        # Add in between\n        for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):\n            l, r = left[0][-1], right[0][0]\n            final_optional_prompts.append(left)\n            if l != r: final_optional_prompts.append(combined_prompt[l : r])\n        final_optional_prompts.append(optional_prompts[-1])\n\n        # Add right\n        right = optional_prompts[-1]\n        r = right[0][1]\n        if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])\n    else:\n        # Just add in the entire string\n        final_optional_prompts.append(combined_prompt)\n\n    check_combined = \"\".join(x if type(x) is str else x[1] for x in final_optional_prompts)\n    assert(combined_prompt == check_combined)\n\n    return possible_columns, final_optional_prompts\n\n\ndef _create_formatter(possible_columns, final_optional_prompts, user_column_name):\n    columns = list(dict.fromkeys(possible_columns))\n    merged_prompt_parts = []\n    formatter_templates = []\n\n    for j, optional_prompt in enumerate(final_optional_prompts):\n        if type(optional_prompt) is str:\n            needed_columns = re.findall(r\"\\{(.+?)\\}\", optional_prompt)\n            formatter_templates.append((\"required\", optional_prompt, needed_columns))\n            merged_prompt_parts.append(optional_prompt)\n            continue\n\n        _, prompt = optional_prompt\n        prompt = prompt[2:-2]\n        needed_columns = re.findall(r\"\\{(.+?)\\}\", prompt)\n        if len(needed_columns) == 0:\n            raise IndexError(\"Unsloth: Optional [[...]] blocks must contain at least 1 {column}.\")\n        optional_name = f\"__optional_{j}__\"\n        formatter_templates.append((\"optional\", optional_name, prompt, needed_columns))\n        merged_prompt_parts.append(\"{\" + optional_name + \"}\")\n\n    merged_prompt = \"\".join(merged_prompt_parts)\n\n    def __combined_prompt_processor__(examples):\n        if len(examples) == 0:\n            return {user_column_name: []}\n\n        first_key = next(iter(examples.keys()), None)\n        if first_key is None:\n            return {user_column_name: []}\n        n_rows = len(examples[first_key])\n\n        texts = []\n        for row_idx in range(n_rows):\n            row_values = {column: examples[column][row_idx] for column in columns}\n            formatter_values = {}\n\n            for formatter_template in formatter_templates:\n                if formatter_template[0] == \"required\":\n                    _, _, needed_columns = formatter_template\n                    for column in needed_columns:\n                        formatter_values[column] = row_values[column]\n                    continue\n\n                _, optional_name, prompt, needed_columns = formatter_template\n                if row_values[needed_columns[0]] not in (None, \"\"):\n                    prompt_values = {column: row_values[column] for column in needed_columns}\n                    formatter_values[optional_name] = 
prompt.format(**prompt_values)\n                else:\n                    formatter_values[optional_name] = \"\"\n\n            texts.append(merged_prompt.format(**formatter_values))\n\n        return {user_column_name: texts}\n\n    return __combined_prompt_processor__\n\n\ndef to_sharegpt(\n    dataset,\n    merged_prompt = \"\",\n    merged_column_name = \"instruction\",\n    output_column_name = \"output\",\n    remove_unused_columns = True,\n    conversation_extension = 1,\n    random_state = 3407,\n):\n    \"\"\"\n    Converts a dataset to ShareGPT style.\n    ShareGPT requires only 1 input and 1 output field.\n    This means one has to merge multiple columns into 1 for 1 input field.\n    Use `conversation_extension` to increase the length of each conversation by randomly\n    selecting a few and packing them into 1.\n\n    merged_prompt = \"\",                 Prompt to merge columns into 1 input\n    merged_column_name = \"instruction\", Final column name for the input  field\n    output_column_name = \"output\",      Final column name for the output field\n    remove_unused_columns = True,\n    conversation_extension = 1,         Automatically combines `conversation_extension` convos into 1\n    random_state = 3407,\n    \"\"\"\n    if \"conversations\" in dataset.column_names:\n        convo = dataset[0][\"conversations\"]\n        if type(convo) is list:\n            raise TypeError(\"Unsloth: Your dataset is probably already in ShareGPT format!\")\n\n    possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)\n    formatter = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)\n    dataset = dataset.map(formatter, batched = True, desc = \"Merging columns\")\n\n    def __convert_to_sharegpt__(examples):\n        users      = examples[merged_column_name]\n        assistants = examples[output_column_name]\n        if len(users) != len(assistants):\n            raise ValueError(\n                \"Unsloth: Input and output columns must have matching batch lengths. 
\"\n                f\"Got {len(users)} {merged_column_name} rows and {len(assistants)} {output_column_name} rows.\"\n            )\n        texts = [\n            [\n                {\"from\" : \"human\", \"value\" : str(user)     },\n                {\"from\" : \"gpt\",   \"value\" : str(assistant)},\n            ] \\\n            for user, assistant in zip(users, assistants)\n        ]\n        return { \"conversations\" : texts, }\n\n    dataset = dataset.map(\n        __convert_to_sharegpt__,\n        batched = True,\n        desc = \"Converting to ShareGPT\",\n        # Remove unused columns!\n        remove_columns = dataset.column_names if remove_unused_columns else None,\n    )\n\n    # Randomnly concat conversations to create a long stream!\n    from datasets import concatenate_datasets\n    n_extensions = max(conversation_extension-1, 0)\n    if n_extensions == 0: return dataset\n\n    dataset = dataset.rename_columns({\"conversations\" : \"conversations0\"})\n    all_shuffled = [dataset]\n    for j in range(1, n_extensions+1):\n        shuffled = dataset.shuffle(seed = random_state+j).rename_columns({\"conversations0\" : f\"conversations{j}\"})\n        all_shuffled.append(shuffled)\n    dataset = concatenate_datasets(all_shuffled, axis = 1)\n\n    # Combine them into 1\n    n_extensions += 1\n    conversation_columns = [f\"conversations{j}\" for j in range(n_extensions)]\n    def __combine_conversations__(examples):\n        columns = [examples[column] for column in conversation_columns]\n        convos = []\n        for conversations in zip(*columns):\n            merged_conversation = []\n            for conversation in conversations:\n                merged_conversation.extend(conversation)\n            convos.append(merged_conversation)\n        return {\"conversations\" : convos}\n\n    dataset = dataset.map(\n        __combine_conversations__,\n        batched = True,\n        desc = \"Extending conversations\",\n        # Remove unused columns!\n        remove_columns = dataset.column_names if remove_unused_columns else None,\n    )\n    return dataset\n\n\ndef get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):\n    added_tokens_decoder = tokenizer.added_tokens_decoder.values()\n    added_tokens_decoder = [str(x) for x in added_tokens_decoder]\n\n    # Remove added_tokens_decoder duplicates\n    added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))\n\n    # Remove BOS\n    if getattr(tokenizer, \"bos_token\", None) is not None:\n        added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]\n\n    repeatted_tokens = []\n    # Join all vocab\n    joined_text = \"\\x01\\x00\".join(added_tokens_decoder)\n    for token in added_tokens_decoder:\n        n = len(token)\n        repeatted_counts = joined_text.count(token[:n//2])\n        # Try finding longer than 1/2 of the token in the rest\n        # For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>\n        if repeatted_counts > 2:\n            for j in range(n//2+1, n):\n                if joined_text.count(token[:j]) < repeatted_counts:\n                    j -= 1\n                    # Remove repeatted tokens to reduce search space\n                    joined_text = joined_text.replace(token[:j], \"\")\n                    repeatted_tokens.append(token[:j])\n                    break\n\n    # Remove duplicates\n    splitted = joined_text.split(\"\\x01\\x00\")\n    final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old 
== new]\n    final_eos_tokens += extra_eos_tokens\n    final_eos_tokens += repeatted_tokens\n\n    # Remove new lines, spaces and HTML tags\n    filtered_eos_tokens = []\n    for token in final_eos_tokens:\n        if   token.count(\"\\n\") == len(token): continue\n        elif token.count(\"▁\") == len(token): continue\n        elif token.startswith(\"<\") and len(token) <= 2: continue\n        elif token.startswith(\"</\") and len(token) == 3: continue\n        filtered_eos_tokens.append(token)\n    return filtered_eos_tokens\n\n\ndef construct_chat_template( \\\n\ntokenizer = None,\n\nchat_template = \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|>\"\"\",\n\ndefault_system_message = \\\n    \"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\",\n\nextra_eos_tokens = None,\n):\n    \"\"\"\n    Creates an Ollama modelfile and a HF Jinja template from a custom\n    template. You must provide 2x examples of an input & output.\n    There is an optional system message as well.\n\n    You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.\n    \"\"\"\n    # Strip only the left\n    chat_template = chat_template.lstrip()\n\n    assert(tokenizer is not None)\n\n    if extra_eos_tokens is None: extra_eos_tokens = []\n    elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]\n\n    vocab = tokenizer.get_vocab()\n    for extra_eos in extra_eos_tokens:\n        assert(type(extra_eos) is str)\n        if extra_eos not in vocab:\n            raise ValueError(f\"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.\")\n\n    error_msg = \\\n        \"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} \"\\\n        \"and the assistant output {OUTPUT}\\n\\n\"\\\n        \"For example what is not allowed is just:\\n\"\\\n        \"### Input:\\\\n{INPUT}\\\\n\\\\n### Response:\\\\n{OUTPUT}\\\\n\\n\\n\"\\\n        \"What is required is 2x of this:\\n\"\\\n        \"### Input:\\\\n{INPUT}\\\\n\\\\n### Response:\\\\n{OUTPUT}\\\\n\"\\\n        \"### Input:\\\\n{INPUT}\\\\n\\\\n### Response:\\\\n{OUTPUT}\\\\n\"\n\n    # Check for EOS after {OUTPUT}\n    if tokenizer.eos_token is not None:\n        extra_eos_tokens.insert(0, tokenizer.eos_token)\n    if len(extra_eos_tokens) == 0:\n        raise RuntimeError(\n            \"Unsloth: Your tokenizer does not have an EOS token? 
Please provide one via extra_eos_tokens!\"\n        )\n\n    # Check tokenizer types\n    tokenizer_name = tokenizer.name_or_path.lower()\n    if tokenizer_name.startswith((\"unsloth/llama-3-8b-instruct\", \"unsloth/llama-3-70b-instruct\")):\n        # Add <|eot_id|>\n        extra_eos_tokens.append(\"<|eot_id|>\")\n    elif (\"<|eot_id|>\" in extra_eos_tokens or \"<|eot_id|>\" in chat_template) and \\\n        tokenizer_name.startswith((\"unsloth/llama-3-8b\", \"unsloth/llama-3-70b\")):\n        # Warn\n        logger.warning(\n            \"Unsloth: Base llama-3 models did not train <|eot_id|>.\\n\"\\\n            \"Please use the instruct version or use <|end_of_text|>\"\n        )\n    extra_eos_tokens = list(set(extra_eos_tokens))\n\n    count_eos = 0\n    for eos in extra_eos_tokens:\n        count_eos += len(re.findall(r\"{OUTPUT}\" + re.escape(eos), chat_template))\n\n    # This forces you to provide 2 input and outputs\n    final_combined_check = False\n\n    try:\n        # O(N^2) search finding 2 repeatted pieces of text\n        j = len(chat_template)-1\n        at_least_one = False\n        while j > 0:\n            found = chat_template.rfind(chat_template[j:], 0, j)\n            if found == -1: break\n            j -= 1\n            at_least_one = True\n        if j > 0: j += 1\n        else: raise RuntimeError(error_msg)\n\n        if not at_least_one: raise RuntimeError(error_msg)\n\n        # Must be equivalent to left\n        final_combined_check = True\n\n        # Repeatted text\n        instruction_response = chat_template[j:]\n        if instruction_response.count(\"{INPUT}\") != 1 or instruction_response.count(\"{OUTPUT}\") != 1:\n            raise RuntimeError(error_msg)\n\n        # 1st System, Instruction, Output pair\n        left  = chat_template[:j]\n        # 2nd Instruction, Output pair\n        right = chat_template[j:]\n\n        final_combined_check = left if final_combined_check else chat_template\n\n        # Isolate input\n        extra_eos_tokens_regex = \"|\".join(f\"(?:{re.escape(x)})\" for x in extra_eos_tokens)\n        if len(extra_eos_tokens_regex) != 0:\n            find_end = f\"(?:{extra_eos_tokens_regex})?\"\n        else:\n            find_end = \"\"\n        find_end = r\"\\{INPUT\\}[\\s\\n]{0,}\" + find_end\n        input_end = list(re.finditer(find_end, right))\n        assert(len(input_end) == 1)\n        input_end = input_end[0]\n        input_end = input_end.span(0)[1]\n        input_part = right[:input_end]\n\n        # Isolate output\n        output_part = right[input_end:]\n\n        # Isolate system\n        where_system = left.find(input_part)\n        system_part = left[:where_system if where_system != -1 else len(left)]\n\n        # Check if the user provided a correct prompt\n        combined = system_part + input_part + output_part\n        if combined != final_combined_check:\n            combined_changed = combined            .replace('\\n', '\\\\n')\n            left_changed     = final_combined_check.replace('\\n', '\\\\n')\n            raise RuntimeError(\n                \"Unsloth: The prompt template you provided isn't correct. 
You gave:\\n\"\\\n                f\"{combined_changed}\\n\\n\"\\\n                \"But we require the following:\\n\"\\\n                f\"{left_changed}\"\n            )\n    except:\n        ending = chat_template[chat_template.find(\"{OUTPUT}\") + len(\"{OUTPUT}\"):]\n\n        ending = re.escape(ending)\n        find_text = \"{INPUT}\" + ending + \"(.+?{OUTPUT}\" + ending + \")\"\n        response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)\n        response_part = response_part[0]\n\n        for j in range(1, len(response_part)):\n            try_find = re.escape(response_part[:j])\n            try: found = next(re.finditer(\"(\" + try_find + \").+?\\\\{INPUT\\\\}\", chat_template, flags = re.DOTALL | re.MULTILINE))\n            except: break\n        separator = found.group(1)\n\n        response_start = chat_template.find(response_part)\n        start_instruction = chat_template[:response_start].rfind(separator)\n        if start_instruction == -1: start_instruction = 0\n        instruction_part = chat_template[start_instruction:response_start]\n\n        combined = instruction_part + response_part\n        where = chat_template.find(combined)\n        system_part = chat_template[:where]\n\n        system_part, input_part, output_part = system_part, instruction_part, response_part\n\n    if count_eos == 0:\n        logger.warning(\"Unsloth: We automatically added an EOS token to stop endless generations.\")\n        eos = extra_eos_tokens[0]\n        output_part = output_part + eos\n\n    # Ollama modelfile parts\n\n    # Check bos_token is in system prompt\n    ollama_system = system_part\n    has_bos_token = False\n    always_bos_token = False\n    if tokenizer(\"A\").input_ids[0] == getattr(tokenizer, \"bos_token_id\", None):\n        always_bos_token = True\n        if ollama_system.startswith(tokenizer.bos_token):\n            has_bos_token = True\n            ollama_system = ollama_system[len(tokenizer.bos_token):]\n    # Check system\n    if \"{SYSTEM}\" in ollama_system:\n        system_modelfile = \"{{ if .System }}\" + ollama_system.replace(\"{SYSTEM}\", \"{{ .System }}\") + \"{{ end }}\"\n    else:\n        system_modelfile = ollama_system\n    input_modelfile  = \"{{ if .Prompt }}\" + input_part .replace(\"{INPUT}\",  \"{{ .Prompt }}\") + \"{{ end }}\"\n    output_modelfile = output_part.replace(\"{OUTPUT}\", \"{{ .Response }}\")\n\n    # Ollama EOS\n    ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)\n    ollama_eos = '\\n'.join(f'PARAMETER stop \"{eos}\"' for eos in ollama_eos)\n\n    # Add temperature and min_p to counteract gibberish\n    ollama_eos += \"\\nPARAMETER temperature 1.5\\nPARAMETER min_p 0.1\"\n\n    # Ollama modelfile\n    part = '\"\"\"'\n    modelfile = 'FROM {__FILE_LOCATION__}\\n\\n'\\\n    'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \\\n        part + '\\n\\n' + ollama_eos\n\n    # HF Jinja Chat template\n    def process(part, which, content = \"message['content']\"):\n        if part.endswith(which):\n            part = \"'\" + part[:part.find(which)] + f\"' + {content}\"\n        elif part.startswith(which):\n            part = f\"{content} + '\" + part[part.find(which):] + \"'\"\n        else:\n            part = \"'\" + part.replace(which, f\"' + {content} + '\") + \"'\"\n        if part.startswith(\"'' + \"): part = part[5:]\n        return part\n    input_jinja  = process(input_part,  \"{INPUT}\")\n    output_jinja = process(output_part, \"{OUTPUT}\")\n\n    
jinja_template = \\\n        \"{% for message in loop_messages %}\"\\\n            \"{% if message['role'] == 'user' %}\"\\\n                \"{{ \" + input_jinja + \" }}\"\\\n            \"{% elif message['role'] == 'assistant' %}\"\\\n                \"{{ \" + output_jinja + \" }}\"\\\n            \"{% else %}\"\\\n                \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\\\n            \"{% endif %}\"\\\n        \"{% endfor %}\"\\\n        \"{% if add_generation_prompt %}\"\\\n            \"{{ '\" + output_part[:output_part.find(\"{OUTPUT}\")] + \"' }}\"\\\n        \"{% endif %}\"\n\n    # Now add system prompt to jinja\n    if len(system_part) != 0:\n        partial_system = process(system_part, \"{SYSTEM}\", \"messages[0]['content']\")\n        partial_system = partial_system.replace(\"{SYSTEM}\", \"\")\n\n        if \"{SYSTEM}\" in partial_system:\n            if default_system_message is None:\n                raise RuntimeError(\"Unsloth: Please specify a default system message!\")\n\n        # Separate the BOS\n        if has_bos_token:\n            partial_system = partial_system.replace(tokenizer.bos_token, \"\", 1)\n            system_part    = system_part   .replace(tokenizer.bos_token, \"\", 1)\n\n        partial_system = \\\n            \"{% if messages[0]['role'] == 'system' %}\"\\\n                \"{{ \" + partial_system + \" }}\"\\\n                \"{% set loop_messages = messages[1:] %}\"\n        if default_system_message is not None:\n            full_system = system_part.replace(\"{SYSTEM}\", default_system_message)\n            if \"{SYSTEM}\" in system_part:\n                modelfile += '\\nSYSTEM \"' + default_system_message + '\"'\n            partial_system += \"{% else %}\"\\\n                \"{{ '\" + full_system + \"' }}\"\\\n                \"{% set loop_messages = messages %}\"\\\n            \"{% endif %}\"\n        else:\n            partial_system += \"{% endif %}\"\n\n        jinja_template = partial_system + jinja_template\n\n        if has_bos_token:\n            jinja_template = \"{{ bos_token }}\" + jinja_template\n\n    # Fix missing loop_messages\n    if \"{% set loop_messages = messages %}\" not in jinja_template:\n        jinja_template = jinja_template.replace(\n            \"{% for message in loop_messages %}\",\n            \"{% for message in messages %}\",\n            1, # Only replace the first one\n        )\n\n    # Check if system part is the same!\n    jinja_template = re.sub(\n        r\"\\{\\% if messages\\[0\\]\\['role'\\] \\=\\= 'system' \\%\\}\\{\\{ '(.+?)' \\}\\}\"\\\n        r\"\\{\\% set loop\\_messages \\= messages\\[1\\:\\] \\%\\}\"\\\n        r\"\\{\\% else \\%\\}\\{\\{ '\\1' \\}\\}\\{\\% set loop\\_messages \\= messages \\%\\}\\{\\% endif \\%\\}\"\\\n        r\"\\{\\% for message in loop\\_messages \\%\\}\",\n        r\"{{ '\\1' }}{% for message in messages %}\",\n        jinja_template, flags = re.MULTILINE | re.DOTALL,\n    )\n\n    # Check jinja template for bos\n    if always_bos_token:\n        if not jinja_template.startswith((\"{{ bos_token }}\", \"{{- bos_token }}\")):\n            jinja_template = \"{{ bos_token }}\" + jinja_template\n\n    # Get instruction and output parts for train_on_inputs = False\n    input_part  = input_part [:input_part .find(\"{INPUT}\")]\n    output_part = output_part[:output_part.find(\"{OUTPUT}\")]\n    return modelfile, jinja_template, input_part, output_part\n\n\ndef test_construct_chat_template():\n    token = \"hf_\"\n    from transformers import 
AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B-Instruct\", token = token)\n\n    chat_template = \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|>\"\"\"\n\n    default_system_message = \\\n        \"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\"\n\n    extra_eos_tokens = None\n\n    modelfile, jinja_template, _, _ = construct_chat_template(\n        tokenizer = tokenizer,\n        chat_template = chat_template,\n        extra_eos_tokens = extra_eos_tokens,\n    )\n\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are an assistant\"},\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"It's 4.\"},\n        {\"role\": \"user\", \"content\": \"Ok!\"},\n        {\"role\": \"assistant\", \"content\": \"Anything else?\"},\n        {\"role\": \"user\", \"content\": \"What's 2x2?\"},\n    ]\n    correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n\n    tokenizer.chat_template = jinja_template\n    new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    assert(correct_output == new_output)\n\n\ndef apply_chat_template( \\\n\ndataset,\ntokenizer = None,\n\nchat_template = \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{OUTPUT}<|eot_id|>\"\"\",\n\ndefault_system_message = \\\n    \"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\",\n\nextra_eos_tokens = None,\n\n):\n    \"\"\"\n    Creates an Ollama modelfile and a HF Jinja template from a custom\n    template. 
You must provide 2x examples of an input & output.\n    There is an optional system message as well.\n\n    You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.\n    \"\"\"\n    modelfile, jinja_template, input_part, output_part = construct_chat_template(\n        tokenizer = tokenizer,\n        chat_template = chat_template,\n        default_system_message = default_system_message,\n        extra_eos_tokens = extra_eos_tokens,\n    )\n    def formatting_prompts_func(examples):\n        convos = examples[\"conversations\"]\n        texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]\n        return { \"text\" : texts, }\n\n    tokenizer.chat_template = jinja_template\n    tokenizer._ollama_modelfile = modelfile\n    tokenizer._unsloth_input_part  = input_part\n    tokenizer._unsloth_output_part = output_part\n    if hasattr(tokenizer, \"tokenizer\"):\n        tokenizer.tokenizer.chat_template = jinja_template\n        tokenizer.tokenizer._ollama_modelfile = modelfile\n        tokenizer.tokenizer._unsloth_input_part  = input_part\n        tokenizer.tokenizer._unsloth_output_part = output_part\n\n    return dataset.map(formatting_prompts_func, batched = True,)\n\n\ndef create_stopping_criteria(tokenizer, stop_word = \"eos_token\"):\n    class StoppingCriteriaSub(StoppingCriteria):\n        __slots__ = \"stop_token\", \"single_match\", \"length\",\n\n        def __init__(self, stops = \"eos_token\", device = \"cuda\", encounters = 1):\n            super().__init__()\n            if stops == \"eos_token\":\n                self.stop_token = torch.tensor(tokenizer.eos_token_id, device = \"cuda\")\n                self.length = 1\n            else:\n                self.stop_token = tokenizer([\"\\n\" + stops], add_special_tokens = False, return_tensors = \"pt\")\n                self.stop_token = self.stop_token.input_ids.ravel()[1:].to(\"cuda\")\n                self.length = self.stop_token.shape[0]\n            self.single_match = self.length == 1\n\n        def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:\n            input_ids = input_ids.ravel()\n            last_token = input_ids[-1]\n            if self.single_match and (last_token == self.stop_token): return True\n\n            if input_ids.shape[0] >= self.length and \\\n                (input_ids[-self.length:] == self.stop_token).all(): return True\n            return False\n    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])\n    return stopping_criteria\n\n\ndef test_chat_templates():\n    messages = [\n        {\"role\": \"system\",\"content\": \" You are a friendly chatbot.\",},\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"It's 4.\"},\n        {\"role\": \"user\", \"content\": \"  But 2+2 is equal to 5. \"},\n        {\"role\": \"assistant\", \"content\": \"No I'm sure its 4.\"},\n        {\"role\": \"user\", \"content\": \"  No it's 100% 5! 
\"},\n    ]\n\n    # Zephyr\n    from transformers import AutoTokenizer\n    template = zephyr_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceH4/zephyr-7b-beta\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    # Chatml\n    template = chatml_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"teknium/OpenHermes-2.5-Mistral-7B\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    # Mistral\n    template = mistral_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.2\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    # Llama\n    template = llama_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"unsloth/llama-2-7b-chat\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    # Vicuna\n    try:\n        from fastchat.conversation import get_conv_template\n    except:\n        os.system(\"pip -qqq install git+https://github.com/lm-sys/FastChat.git\")\n        from fastchat.conversation import get_conv_template\n    correct_prompt = get_conv_template(\"vicuna_v1.1\")\n    for j in range(len(messages)-1):\n        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1][\"content\"])\n    correct_prompt.append_message(correct_prompt.roles[1], \"\")\n    correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()\n\n    template = vicuna_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"lmsys/vicuna-7b-v1.5\")\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    try:\n        from fastchat.conversation import get_conv_template\n    except:\n        os.system(\"pip -qqq install git+https://github.com/lm-sys/FastChat.git\")\n        from fastchat.conversation import get_conv_template\n    correct_prompt = get_conv_template(\"zero_shot\")\n    for j in range(len(messages)-1):\n        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1][\"content\"])\n    correct_prompt.append_message(correct_prompt.roles[1], \"\")\n    correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()\n\n    template = vicuna_old_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"lmsys/vicuna-7b-v1.5\")\n    correct_tokenizer.chat_template = template\n    our_prompt = 
correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    # We add </s> ourselves\n    assert(correct_prompt == our_prompt.replace(\"</s>\", \"\"))\n\n    # Gemma\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"unsloth/gemma-7b-it\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = gemma_template\n    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    assert(our_prompt == correct_prompt)\n\n    # Llama-3\n    template = llama3_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"unsloth/llama-3-8b-Instruct\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n    # Phi-3\n    template = phi3_template\n    correct_tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3-mini-4k-instruct\")\n    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    correct_tokenizer.chat_template = template\n    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)\n    assert(correct_prompt == our_prompt)\n\n\ndef test_hf_gguf_equivalence(tokenizer, gguf_model = \"./model-unsloth.F16.gguf\"):\n    \"\"\"\n        Carefully checks the output of GGUF's tokenization and HF.\n        Can catch all tokenization bugs.\n    \"\"\"\n    import subprocess\n    import re\n    messages = [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"It's 4.\"},\n        {\"role\": \"user\", \"content\": \"  But 2+2 is equal to 5. \"},\n        {\"role\": \"assistant\", \"content\": \"No I'm sure its 4.\"},\n        {\"role\": \"user\", \"content\": \"  No it's 100% 5! \"},\n    ]\n\n    prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n\n    ### Instruction:\n    {}\n\n    ### Input:\n    {}\n\n    ### Response:\n    {}\"\"\".format(\n        \"Describe the city given eloquently.\", # instruction\n        \"The lost city of Atlantis.\", # input\n        \"\", # output - leave this blank for generation!\n    )\n    prompts = [ prompt, ]\n\n    if tokenizer.chat_template is not None:\n        prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)\n        prompt = remove_special_tokens(tokenizer, prompt)\n        prompts.append(prompt)\n\n    for prompt in prompts:\n        # Use a list of args with shell=False so prompt content is passed literally.\n        command = [\n            \"./llama.cpp/llama-cli\",\n            \"-m\", gguf_model,\n            \"-n\", \"0\",\n            \"--temp\", \"0.0\",\n            \"--verbose-prompt\",\n            \"--check-tensors\",\n            \"-p\", prompt,\n        ]\n\n        datas = []\n        with subprocess.Popen(command, shell = False, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:\n            for line in sp.stdout:\n                datas.append(line.decode(\"utf-8\", errors = \"replace\"))\n        gguf_tokens = \"\".join(datas)\n\n        # Now extract GGUF tokenization attempt\n        gguf_tokenized = re.findall(r\"([\\d]{1,}) \\-\\> \\'([^\\']{1,})\\'\", gguf_tokens, flags = re.MULTILINE)\n        gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]\n        input_ids = tokenizer(prompt).input_ids\n\n        tokens = tokenizer.batch_decode(input_ids)\n        hf_tokenized = list(zip(input_ids, tokens))\n\n        # Compare to Huggingface\n        for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):\n            if (hf_token[0] != gguf_token[0]):\n                print(\"Failed GGUF != HF at\", j)\n                print(\"HF =\", hf_token)\n                print(\"GGUF =\", gguf_token)\n                print(hf_tokenized)\n                print()\n                print(gguf_tokenized)\n                print()\n                raise RuntimeError(\"Failed comparing GGUF to HF.\")\n    return True\n"
  },
  {
    "path": "unsloth/dataprep/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .synthetic import *\nfrom .raw_text import *\n"
  },
  {
    "path": "unsloth/dataprep/raw_text.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport re\nimport json\nimport csv\nfrom typing import List, Dict, Any, Union, Optional\nfrom datasets import Dataset\nfrom pathlib import Path\n\n__all__ = [\n    \"RawTextDataLoader\",\n    \"TextPreprocessor\",\n]\n\nSUPPORTED_FORMATS = {\n    \".txt\": \"plain_text\",\n    \".md\": \"markdown\",\n    \".json\": \"json_lines\",\n    \".jsonl\": \"json_lines\",\n    \".csv\": \"csv_text_column\",\n}\n\n\nclass RawTextDataLoader:\n    def __init__(self, tokenizer, chunk_size = 2048, stride = 512, return_tokenized = True):\n        if chunk_size <= 0:\n            raise ValueError(f\"chunk_size must be positive, got {chunk_size}\")\n        if stride >= chunk_size:\n            raise ValueError(\n                f\"stride ({stride}) must be smaller than chunk_size ({chunk_size})\"\n            )\n        self.tokenizer = tokenizer\n        self.chunk_size = chunk_size\n        self.stride = stride\n        self.return_tokenized = return_tokenized\n\n    def detect_format(self, file_path):\n        \"\"\"Auto-detect file format and parse accordingly\"\"\"\n        extension = Path(file_path).suffix.lower()\n        return SUPPORTED_FORMATS.get(extension, \"plain_text\")\n\n    def load_from_file(self, file_path, return_tokenized = None):\n        \"\"\"Load raw text and convert to dataset\"\"\"\n        if return_tokenized is None:\n            return_tokenized = self.return_tokenized\n        file_format = self.detect_format(file_path)\n        text_content = self._read_file_by_format(file_path, file_format)\n        if not text_content or not text_content.strip():\n            raise ValueError(f\"File '{file_path}' is empty or contains only whitespace\")\n        chunks = self.smart_chunk_text(\n            text_content, self.chunk_size, self.stride, return_tokenized\n        )\n        return self.create_causal_dataset(chunks)\n\n    def load_from_files(self, file_paths, return_tokenized = None):\n        \"\"\"Load multiple text files\"\"\"\n        if return_tokenized is None:\n            return_tokenized = self.return_tokenized\n        all_chunks = []\n        for file_path in file_paths:\n            file_format = self.detect_format(file_path)\n            text_content = self._read_file_by_format(file_path, file_format)\n            chunks = self.smart_chunk_text(\n                text_content, self.chunk_size, self.stride, return_tokenized\n            )\n            all_chunks.extend(chunks)\n        return self.create_causal_dataset(all_chunks)\n\n    def chunk_text(self, text, return_tokenized = None):\n        \"\"\"Split text into overlapping chunks\"\"\"\n        if return_tokenized is None:\n            return_tokenized = self.return_tokenized\n        return self.smart_chunk_text(\n            text, self.chunk_size, self.stride, return_tokenized\n        )\n\n    def create_causal_dataset(self, chunks):\n        
\"\"\"Create dataset for causal language modeling\"\"\"\n        if chunks and isinstance(chunks[0], dict):\n            # If chunks are already tokenized (dict with input_ids, attention_mask)\n            # Reorganize the data structure for Dataset.from_dict\n            input_ids = [chunk[\"input_ids\"] for chunk in chunks]\n            attention_mask = [chunk[\"attention_mask\"] for chunk in chunks]\n            # Labels are same as input_ids for causal LM training\n            labels = [list(ids) for ids in input_ids]\n            return Dataset.from_dict(\n                {\n                    \"input_ids\": input_ids,\n                    \"attention_mask\": attention_mask,\n                    \"labels\": labels,\n                }\n            )\n        else:\n            # If chunks are text strings (backward compatibility)\n            return Dataset.from_dict({\"text\": chunks})\n\n    def smart_chunk_text(self, text, chunk_size, stride, return_tokenized = True):\n        \"\"\"\n        Intelligent chunking that:\n        1. Respects sentence/paragraph boundaries\n        2. Handles various text formats (.txt, .md, .json, etc.)\n        3. Maintains context with stride overlap\n        4. Returns tokenized chunks directly (more efficient) or text chunks\n        \"\"\"\n        # First pass: tokenize the entire text to get accurate token counts\n        tokenized = self.tokenizer(text, return_tensors = \"pt\", add_special_tokens = False)\n        tokens = tokenized[\"input_ids\"]\n\n        # Handle different tokenizer return formats\n        if hasattr(tokens, \"__len__\") and len(tokens) > 0:\n            # If it's a nested structure, get the first element\n            if hasattr(tokens[0], \"__len__\"):\n                tokens = tokens[0]\n        elif isinstance(tokens, int):\n            # If tokenizer returns just a count, create a simple range\n            tokens = list(range(tokens))\n\n        if len(tokens) <= chunk_size:\n            # Text is small enough to fit in one chunk\n            if return_tokenized:\n                # Add EOS token to the tokens if available\n                eos_token_id = getattr(self.tokenizer, \"eos_token_id\", None)\n                if eos_token_id is not None:\n                    tokens = (\n                        tokens.tolist() if hasattr(tokens, \"tolist\") else list(tokens)\n                    )\n                    tokens.append(eos_token_id)\n\n                # Create attention mask\n                attention_mask = [1] * len(tokens)\n                return [{\"input_ids\": tokens, \"attention_mask\": attention_mask}]\n            else:\n                eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else \"\"\n                return [text + eos_token]\n\n        chunks = []\n        start_idx = 0\n\n        while start_idx < len(tokens):\n            # Calculate end index for this chunk\n            end_idx = min(start_idx + chunk_size, len(tokens))\n\n            # Extract tokens for this chunk\n            chunk_tokens = tokens[start_idx:end_idx]\n\n            if return_tokenized:\n                # Convert to list if it's a tensor\n                chunk_tokens_list = (\n                    chunk_tokens.tolist()\n                    if hasattr(chunk_tokens, \"tolist\")\n                    else list(chunk_tokens)\n                )\n\n                # Add EOS token if it's the last chunk or chunk is complete\n                if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size:\n           
         eos_token_id = getattr(self.tokenizer, \"eos_token_id\", None)\n                    if eos_token_id is not None:\n                        chunk_tokens_list.append(eos_token_id)\n\n                # Create attention mask (all tokens are attended to)\n                attention_mask = [1] * len(chunk_tokens_list)\n\n                chunks.append(\n                    {\"input_ids\": chunk_tokens_list, \"attention_mask\": attention_mask}\n                )\n            else:\n                # Decode back to text (backward compatibility)\n                chunk_text = self.tokenizer.decode(\n                    chunk_tokens, skip_special_tokens = True\n                )\n\n                # Add EOS token if it's the last chunk or chunk is complete\n                if end_idx == len(tokens) or len(chunk_tokens) == chunk_size:\n                    eos_token = (\n                        self.tokenizer.eos_token if self.tokenizer.eos_token else \"\"\n                    )\n                    chunk_text += eos_token\n\n                chunks.append(chunk_text)\n\n            # Move to next chunk with stride overlap\n            if end_idx == len(tokens):\n                break\n            start_idx += chunk_size - stride\n\n        return chunks\n\n    def _read_file_by_format(self, file_path, file_format):\n        \"\"\"Read file content based on detected format.\"\"\"\n        with open(file_path, \"r\", encoding = \"utf-8\") as f:\n            if file_format == \"plain_text\" or file_format == \"markdown\":\n                return f.read()\n            elif file_format == \"json_lines\":\n                lines = []\n                for line in f:\n                    try:\n                        data = json.loads(line.strip())\n                        text = self._extract_text_from_json(data)\n                        if text:\n                            lines.append(text)\n                    except json.JSONDecodeError:\n                        continue\n                return \"\\n\\n\".join(lines)\n            elif file_format == \"csv_text_column\":\n                reader = csv.DictReader(f)\n                texts = []\n                for row in reader:\n                    text = self._extract_text_from_csv_row(row)\n                    if text:\n                        texts.append(text)\n                return \"\\n\\n\".join(texts)\n        return \"\"\n\n    def _extract_text_from_json(self, data):\n        \"\"\"Extract text from JSON object using common field names.\"\"\"\n        text_fields = [\"text\", \"content\", \"message\", \"body\", \"description\", \"prompt\"]\n        for field in text_fields:\n            if field in data and isinstance(data[field], str):\n                return data[field]\n        return \"\"\n\n    def _extract_text_from_csv_row(self, row):\n        \"\"\"Extract text from CSV row using common column names.\"\"\"\n        text_columns = [\"text\", \"content\", \"message\", \"body\", \"description\", \"prompt\"]\n        for column in text_columns:\n            if column in row and row[column]:\n                return row[column]\n        return \"\"\n\n\nclass TextPreprocessor:\n    def clean_text(self, text):\n        \"\"\"Remove unwanted characters, normalize whitespace\"\"\"\n        text = re.sub(r\"\\s+\", \" \", text)\n        text = re.sub(r\"[^\\x20-\\x7E\\n\\t]\", \"\", text)\n        text = text.replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\")\n        text = re.sub(r\"\\n{3,}\", \"\\n\\n\", text)\n        return 
text.strip()\n\n    def extract_sections(self, text, patterns):\n        \"\"\"Extract specific sections (e.g., code blocks, quotes)\"\"\"\n        sections = []\n        for pattern in patterns:\n            matches = re.findall(pattern, text, re.MULTILINE | re.DOTALL)\n            sections.extend(matches)\n        return sections\n\n    def add_structure_tokens(self, text):\n        \"\"\"Add special tokens for structure (chapters, sections)\"\"\"\n        text = re.sub(\n            r\"^# (.+)$\", r\"<|chapter|>\\1<|/chapter|>\", text, flags = re.MULTILINE\n        )\n        text = re.sub(\n            r\"^## (.+)$\", r\"<|section|>\\1<|/section|>\", text, flags = re.MULTILINE\n        )\n        text = re.sub(\n            r\"^### (.+)$\", r\"<|subsection|>\\1<|/subsection|>\", text, flags = re.MULTILINE\n        )\n        text = re.sub(\n            r\"```(\\w*)\\n(.*?)\\n```\", r\"<|code|\\1|>\\2<|/code|>\", text, flags = re.DOTALL\n        )\n        return text\n\n    def validate_dataset(self, dataset):\n        \"\"\"\n        Check for:\n        - Minimum/maximum sequence lengths\n        - Character encoding issues\n        - Repeated content\n        - Empty chunks\n        \"\"\"\n        stats = {\n            \"total_samples\": len(dataset),\n            \"empty_samples\": 0,\n            \"min_length\": float(\"inf\"),\n            \"max_length\": 0,\n            \"avg_length\": 0,\n            \"repeated_content\": 0,\n            \"encoding_issues\": 0,\n            \"warnings\": [],\n        }\n\n        texts = dataset[\"text\"]\n        text_lengths = []\n        seen_texts = set()\n\n        for i, text in enumerate(texts):\n            if not text or len(text.strip()) == 0:\n                stats[\"empty_samples\"] += 1\n                continue\n\n            # Check for encoding issues\n            try:\n                text.encode(\"utf-8\")\n            except UnicodeEncodeError:\n                stats[\"encoding_issues\"] += 1\n\n            # Calculate lengths\n            length = len(text)\n            text_lengths.append(length)\n            stats[\"min_length\"] = min(stats[\"min_length\"], length)\n            stats[\"max_length\"] = max(stats[\"max_length\"], length)\n\n            # Check for repeated content\n            text_hash = hash(text.strip())\n            if text_hash in seen_texts:\n                stats[\"repeated_content\"] += 1\n            else:\n                seen_texts.add(text_hash)\n\n        # Calculate average length\n        if text_lengths:\n            stats[\"avg_length\"] = sum(text_lengths) / len(text_lengths)\n            stats[\"min_length\"] = (\n                stats[\"min_length\"] if stats[\"min_length\"] != float(\"inf\") else 0\n            )\n\n        # Generate warnings\n        if stats[\"empty_samples\"] > 0:\n            stats[\"warnings\"].append(f\"Found {stats['empty_samples']} empty samples\")\n\n        if stats[\"repeated_content\"] > 0:\n            stats[\"warnings\"].append(\n                f\"Found {stats['repeated_content']} repeated samples\"\n            )\n\n        if stats[\"encoding_issues\"] > 0:\n            stats[\"warnings\"].append(\n                f\"Found {stats['encoding_issues']} encoding issues\"\n            )\n\n        if stats[\"min_length\"] < 10:\n            stats[\"warnings\"].append(\"Some samples are very short (< 10 characters)\")\n\n        return stats\n"
  },
  {
    "path": "unsloth/dataprep/synthetic.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"SyntheticDataKit\",\n]\nimport subprocess\nimport threading\nfrom collections import deque\nimport time\nimport os\n\nos.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\nimport requests\nimport torch\nimport gc\nimport time\nimport re\nfrom unsloth_zoo.log import logger\nimport numpy as np\n\nfrom .synthetic_configs import (\n    synthetic_qa_config,\n)\n\n\ndef _load_vllm_utils():\n    from unsloth_zoo.vllm_utils import (\n        load_vllm,\n        patch_vllm,\n        delete_vllm,\n    )\n\n    return load_vllm, patch_vllm, delete_vllm\n\n\ndef terminate_tree(proc: subprocess.Popen, timeout = 15):\n    if proc is None or proc.poll() is not None:\n        return\n\n    try:\n        import psutil\n\n        parent = psutil.Process(proc.pid)\n        for child in parent.children(recursive = True):\n            child.terminate()\n        parent.terminate()\n        parent.wait(timeout = timeout / 2)\n        return\n    except:\n        pass\n\n    if os.name == \"nt\":\n        try:\n            subprocess.run(\n                [\"taskkill\", \"/T\", \"/F\", \"/PID\", str(proc.pid)],\n                capture_output = True,\n                timeout = 5,\n            )\n            proc.wait(timeout = 1)\n            return\n        except:\n            pass\n\n    proc.kill()\n    try:\n        proc.wait(timeout = 5)\n    except:\n        pass\n\n\nclass PipeCapture:\n    \"\"\"Non blocking pipe capture\"\"\"\n\n    def __init__(\n        self,\n        pipe,\n        keep_lines = 2000,\n        echo = False,\n        name = \"\",\n        text = True,\n        encoding = \"utf-8\",\n        errors = \"replace\",\n        ready_regex = None,\n    ):\n        self.pipe = pipe\n        self.buf = deque(maxlen = keep_lines)\n        self.lock = threading.Lock()\n        self.echo = echo\n        self.name = name\n        self.text = text\n        self.encoding = encoding\n        self.errors = errors\n\n        self.ready_event = threading.Event()\n        self.closed_event = threading.Event()\n\n        self.ready_regex = None\n        if ready_regex is not None:\n            if not hasattr(ready_regex, \"search\"):\n                ready_regex = re.compile(ready_regex)\n            self.ready_regex = ready_regex\n\n        self.t = threading.Thread(target = self._reader, daemon = True)\n        self.t.start()\n\n    def _reader(self):\n        try:\n            sentinel = \"\" if self.text else b\"\"\n            for raw_line in iter(self.pipe.readline, sentinel):\n                if not self.text:\n                    line = raw_line.decode(self.encoding, self.errors)\n                else:\n                    line = raw_line\n                line = line.rstrip(\"\\r\\n\")\n                if self.echo:\n                    if \"platform is\" not in line:\n                        print(f\"{self.name}: {line}\")\n\n      
          with self.lock:\n                    self.buf.append(line)\n\n                if self.ready_regex is not None and self.ready_regex.search(line):\n                    self.ready_event.set()\n\n        finally:\n            try:\n                self.pipe.close()\n            except Exception:\n                pass\n            self.closed_event.set()\n\n    def wait_for_ready(self, timeout = None):\n        return self.ready_event.wait(timeout)\n\n    def has_closed(self):\n        return self.closed_event.is_set()\n\n    def wait_until_closed(self, timeout = None):\n        return self.closed_event.wait(timeout)\n\n    def tail(self, n = 200):\n        with self.lock:\n            return \"\\n\".join(list(self.buf)[-n:])\n\n\nclass SyntheticDataKit:\n    def __init__(\n        self,\n        model_name = \"unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        max_seq_length = 2048,\n        gpu_memory_utilization = 0.98,\n        float8_kv_cache = False,\n        conservativeness = 1.0,\n        token = None,\n        timeout = 1200,  # maybe this is not enough for large models if we need to download\n        **kwargs,\n    ):\n        assert type(model_name) is str\n        assert type(max_seq_length) is int\n        assert type(gpu_memory_utilization) is float\n        assert type(float8_kv_cache) is bool\n        assert type(conservativeness) is float\n        assert token is None or type(token) is str\n\n        self.model_name = model_name\n        self.max_seq_length = max_seq_length\n\n        from transformers import AutoConfig, AutoTokenizer\n\n        self.config = AutoConfig.from_pretrained(\n            model_name,\n            token = token,\n        )\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            model_name,\n            token = token,\n        )\n        load_vllm, patch_vllm, delete_vllm = _load_vllm_utils()\n        self._delete_vllm = delete_vllm\n        patch_vllm(debug = False)\n        engine_args = load_vllm(\n            model_name = model_name,\n            config = self.config,\n            gpu_memory_utilization = gpu_memory_utilization,\n            max_seq_length = max_seq_length,\n            disable_log_stats = True,\n            float8_kv_cache = float8_kv_cache,\n            conservativeness = conservativeness,\n            return_args = True,\n            enable_lora = False,\n            use_bitsandbytes = False,\n            compilation_config = 3,\n            **kwargs,\n        )\n        if \"dtype\" in engine_args:\n            dtype_val = engine_args[\"dtype\"]\n            if dtype_val == torch.float16:\n                dtype_val = \"float16\"\n            elif dtype_val == torch.bfloat16:\n                dtype_val = \"bfloat16\"\n            elif dtype_val == torch.float32:\n                dtype_val = \"float32\"\n            engine_args[\"dtype\"] = dtype_val\n            # Convert torch.bfloat16, torch.float16, etc. 
to valid CLI string\n            if hasattr(dtype_val, \"name\"):\n                engine_args[\"dtype\"] = dtype_val.name\n            elif isinstance(dtype_val, str) and dtype_val.startswith(\"torch.\"):\n                engine_args[\"dtype\"] = dtype_val.split(\".\")[-1]\n            # Only allow valid vLLM choices\n            valid_dtypes = {\"auto\", \"bfloat16\", \"float\", \"float16\", \"float32\", \"half\"}\n            if engine_args[\"dtype\"] not in valid_dtypes:\n                engine_args[\"dtype\"] = \"auto\"\n        if \"device\" in engine_args:\n            del engine_args[\"device\"]\n        if \"model\" in engine_args:\n            del engine_args[\"model\"]\n\n        subprocess_commands = [\n            \"vllm\",\n            \"serve\",\n            str(model_name),\n        ]\n        for key, value in engine_args.items():\n            flag = key.replace(\"_\", \"-\")\n            if key == \"compilation_config\":\n                # [TODO] Unsure why subprocess doesn't process json properly\n                # Also -O3 breaks on T4!\n                # subprocess_commands += [\"-O3\",]\n                continue\n            which = str(value).replace(\"torch.\", \"\")\n            if which == \"True\":\n                # Ignore --enforce-eager True\n                subprocess_commands += [\n                    \"--\" + flag,\n                ]\n            elif which == \"False\":\n                # Ignore flag\n                pass\n            elif which == \"None\":\n                # Ignore flag\n                pass\n            else:\n                subprocess_commands += [\n                    \"--\" + flag,\n                    which,\n                ]\n        logger.info(subprocess_commands)\n        vllm_process = subprocess.Popen(\n            subprocess_commands,\n            stdout = subprocess.PIPE,\n            stderr = subprocess.PIPE,\n            start_new_session = True,\n        )\n        ready_re = re.compile(r\"Starting vLLM API server(?:\\s+\\d+)?\\s+on\\b\")\n        self.vllm_process = vllm_process\n        self.stdout_capture = PipeCapture(\n            vllm_process.stdout,\n            keep_lines = 1000,\n            echo = True,\n            name = \"vLLM STDOUT\",\n            ready_regex = ready_re,\n            text = False,\n        )\n        self.stderr_capture = PipeCapture(\n            vllm_process.stderr,\n            keep_lines = 2000,\n            echo = False,\n            name = \"vLLM STDERR\",\n            ready_regex = None,\n            text = False,\n        )\n        # we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines\n\n        ready = self.stdout_capture.wait_for_ready(timeout = timeout)\n        if not ready:\n            if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:\n                print(\"Stdout stream ended before readiness message detected.\")\n                print(\"\\n--- stdout tail ---\\n\", self.stdout_capture.tail(50))\n                print(\"\\n--- stderr tail ---\\n\", self.stderr_capture.tail(50))\n            else:\n                print(f\"Unsloth: vllm_process failed to load! 
(timeout={timeout})\")\n                print(\"\\n--- stdout tail ---\\n\", self.stdout_capture.tail(50))\n                print(\"\\n--- stderr tail ---\\n\", self.stderr_capture.tail(50))\n            terminate_tree(self.vllm_process)\n            return\n        else:\n            print(\"vLLM Server Ready Detected\")\n\n        trial = 0\n        while not self.check_vllm_status():\n            if trial >= 100:\n                print(\"Unsloth: vllm_process failed to load!\")\n                print(\"\\n--- stdout tail ---\\n\", self.stdout_capture.tail(50))\n                print(\"\\n--- stderr tail ---\\n\", self.stderr_capture.tail(50))\n                terminate_tree(self.vllm_process)\n                return\n            trial += 1\n            time.sleep(1)\n        return\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        max_seq_length = 2048,\n        gpu_memory_utilization = 0.9,\n        float8_kv_cache = False,\n        conservativeness = 1.0,\n        token = None,\n        **kwargs,\n    ):\n        return SyntheticDataKit(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            gpu_memory_utilization = gpu_memory_utilization,\n            float8_kv_cache = float8_kv_cache,\n            conservativeness = conservativeness,\n            token = token,\n            **kwargs,\n        )\n\n    @staticmethod\n    def check_vllm_status():\n        try:\n            response = requests.get(\"http://localhost:8000/metrics\")\n            if response.status_code == 200:\n                return True\n        except requests.exceptions.ConnectionError:\n            return False\n\n    def cleanup(self):\n        if not hasattr(self, \"vllm_process\"):\n            return\n\n        vllm_process = self.vllm_process\n        print(\"Attempting to terminate the VLLM server gracefully...\")\n        try:\n            vllm_process.terminate()\n            vllm_process.wait(timeout = 10)\n            print(\"Server terminated gracefully.\")\n        except subprocess.TimeoutExpired:\n            print(\n                \"Server did not terminate gracefully after 10 seconds. 
Forcing kill...\"\n            )\n            vllm_process.kill()\n            vllm_process.wait()\n            print(\"Server killed forcefully.\")\n        except Exception as e:\n            print(f\"An error occurred while trying to stop the process: {e}\")\n            try:\n                if vllm_process.poll() is None:\n                    print(\"Attempting forceful kill due to error...\")\n                    vllm_process.kill()\n                    vllm_process.wait()\n                    print(\"Server killed forcefully after error.\")\n            except Exception as kill_e:\n                print(f\"Error during forceful kill: {kill_e}\")\n        for _ in range(10):\n            torch.cuda.empty_cache()\n            gc.collect()\n\n        # Delete vLLM module as well\n        if hasattr(self, \"_delete_vllm\"):\n            self._delete_vllm(llm = None)\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *exc):\n        self.cleanup()\n\n    def __del__(self):\n        self.cleanup()\n\n    def chunk_data(self, filename = None):\n        # Chunks data by max tokens and generation length\n        assert filename is not None\n        assert os.path.exists(filename)\n        assert hasattr(self, \"tokenizer\")\n        if not hasattr(self, \"max_seq_length\"):\n            raise RuntimeError(\n                \"Please use SyntheticDataKit.from_pretrained(...) first!\"\n            )\n        if not hasattr(self, \"overlap\") or not hasattr(self, \"max_generation_tokens\"):\n            raise RuntimeError(\"Please use prepare_qa_generation first!\")\n\n        with open(filename, \"r\", encoding = \"utf-8\") as f:\n            text = f.read()\n\n        max_tokens = (\n            self.max_seq_length - self.max_generation_tokens * 2 - 128\n        )  # -128 to reduce errors\n        if max_tokens <= 5:\n            raise RuntimeError(\"Generation length is way too long!\")\n        input_ids = self.tokenizer(text, add_special_tokens = False).input_ids\n\n        # Get left and right boundaries\n        length = len(input_ids)\n        n_chunks = int(np.ceil(length / (max_tokens - self.overlap)))\n        boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(\n            int\n        )\n        boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T\n        boundaries = np.minimum(boundaries, length).tolist()\n\n        # Get extension of filename like .txt\n        filename, extension = os.path.splitext(filename)\n        if filename.endswith(\"/\"):\n            filename = filename[:-1]\n\n        all_filenames = []\n        for i, (left, right) in enumerate(boundaries):\n            chunked_text = self.tokenizer.decode(input_ids[left:right])\n            new_filename = f\"{filename}_{i}{extension}\"\n            all_filenames.append(new_filename)\n            with open(new_filename, \"w\", encoding = \"utf-8\") as f:\n                f.write(chunked_text)\n        return all_filenames\n\n    def prepare_qa_generation(\n        self,\n        output_folder = \"data\",\n        max_generation_tokens = 512,\n        temperature = 0.7,\n        top_p = 0.95,\n        overlap = 64,\n        default_num_pairs = 25,\n        cleanup_threshold = 1.0,\n        cleanup_batch_size = 4,\n        cleanup_temperature = 0.3,\n    ):\n        assert hasattr(self, \"model_name\")\n        assert hasattr(self, \"max_seq_length\")\n        assert max_generation_tokens < self.max_seq_length\n\n        locations = 
\"pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final\"\n        locations = locations.split(\",\")\n        for path in locations:\n            os.makedirs(os.path.join(output_folder, path), exist_ok = True)\n\n        self.max_generation_tokens = max_generation_tokens\n\n        config = (\n            synthetic_qa_config.replace(\"{data_output_location}\", str(output_folder))\n            .replace(\"{model_name}\", str(self.model_name))\n            .replace(\"{temperature}\", str(temperature))\n            .replace(\"{top_p}\", str(top_p))\n            .replace(\n                \"{chunk_size}\", str(self.max_seq_length - max_generation_tokens * 2 - 2)\n            )\n            .replace(\"{overlap}\", str(overlap))\n            .replace(\"{max_tokens}\", str(max_generation_tokens))\n            .replace(\"{default_num_pairs}\", str(default_num_pairs))\n            .replace(\"{cleanup_threshold}\", str(cleanup_threshold))\n            .replace(\"{cleanup_batch_size}\", str(cleanup_batch_size))\n            .replace(\"{cleanup_temperature}\", str(cleanup_temperature))\n        )\n\n        with open(\"synthetic_data_kit_config.yaml\", \"w\", encoding = \"utf-8\") as f:\n            f.write(config)\n\n        self.overlap = overlap\n"
  },
  {
    "path": "unsloth/dataprep/synthetic_configs.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nsynthetic_qa_config = \"\"\"\\\n# Master configuration file for Synthetic Data Kit\n\n# Global paths configuration\npaths:\n  # Input data locations\n  input:\n    pdf: \"{data_output_location}/pdf\"\n    html: \"{data_output_location}/html\"\n    youtube: \"{data_output_location}/youtube\"\n    docx: \"{data_output_location}/docx\"\n    ppt: \"{data_output_location}/ppt\"\n    txt: \"{data_output_location}/txt\"\n\n  # Output locations\n  output:\n    parsed: \"{data_output_location}/output\"      # Where parsed text files are saved\n    generated: \"{data_output_location}/generated\" # Where generated content is saved\n    cleaned: \"{data_output_location}/cleaned\"     # Where cleaned content is saved\n    final: \"{data_output_location}/final\"         # Where final formatted content is saved\n\n# VLLM server configuration\nvllm:\n  api_base: \"http://localhost:8000/v1\" # Base URL for VLLM API\n  port: 8000                           # Port for VLLM server\n  model: \"{model_name}\"                # Default model to use\n  max_retries: 3                       # Number of retries for API calls\n  retry_delay: 1.0                     # Initial delay between retries (seconds)\n\n# Ingest configuration\ningest:\n  default_format: \"txt\"  # Default output format for parsed files\n  youtube_captions: \"auto\"  # Options: \"auto\", \"manual\" - caption preference\n\n# LLM generation parameters\ngeneration:\n  temperature: {temperature}     # Higher = more creative, lower = more deterministic\n  top_p: {top_p}                 # Nucleus sampling parameter\n  chunk_size: {chunk_size}       # Size of text chunks for processing\n  overlap: {overlap}             # Overlap between chunks to maintain context\n  max_tokens: {max_tokens}       # Maximum tokens in LLM responses\n  num_pairs: {default_num_pairs} # Default number of QA pairs to generate\n\n# Content cleanup parameters\ncleanup:\n  threshold: {cleanup_threshold}       # Default quality threshold (1-10)\n  batch_size: {cleanup_batch_size}     # Number of items per batch for rating\n  temperature: {cleanup_temperature}   # Temperature for rating (lower = more consistent)\n\n# Format conversion parameters\nformat:\n  default: \"jsonl\"   # Default output format\n  include_metadata: true  # Include metadata in output files\n  pretty_json: true  # Use indentation in JSON output\n\n# Prompts for different tasks\nprompts:\n  # Summary generation prompt\n  summary: |\n    Summarize this document in 3-5 sentences, focusing on the main topic and key concepts.\n\n  # QA pair generation prompt\n  qa_generation: |\n    Create {num_pairs} question-answer pairs from this text for LLM training.\n\n    Rules:\n    1. Questions must be about important facts in the text\n    2. Answers must be directly supported by the text\n    3. 
Return JSON format only:\n\n    [\n      {{\n        \"question\": \"Question 1?\",\n        \"answer\": \"Answer 1.\"\n      }},\n      {{\n        \"question\": \"Question 2?\",\n        \"answer\": \"Answer 2.\"\n      }}\n    ]\n\n    Text:\n    {text}\n\n  # QA pair rating prompt\n  qa_rating: |\n    Rate each of these question-answer pairs for quality and return exactly this JSON format:\n\n    [\n      {{\"question\": \"same question text\", \"answer\": \"same answer text\", \"rating\": n}}\n    ]\n\n    Where n is a number from 1-10.\n\n    DO NOT include any text outside of the JSON array, just return valid JSON:\n\n    {pairs}\"\"\"\n"
  },
  {
    "path": "unsloth/device_type.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"is_hip\",\n    \"get_device_type\",\n    \"DEVICE_TYPE\",\n    \"DEVICE_TYPE_TORCH\",\n    \"DEVICE_COUNT\",\n    \"ALLOW_PREQUANTIZED_MODELS\",\n    \"ALLOW_BITSANDBYTES\",\n]\n\nimport torch\nimport functools\nimport inspect\nfrom unsloth_zoo.utils import Version\n\n\n@functools.cache\ndef is_hip():\n    return bool(getattr(getattr(torch, \"version\", None), \"hip\", None))\n\n\n@functools.cache\ndef get_device_type():\n    if hasattr(torch, \"cuda\") and torch.cuda.is_available():\n        if is_hip():\n            return \"hip\"\n        return \"cuda\"\n    elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n        return \"xpu\"\n    # Check torch.accelerator\n    if hasattr(torch, \"accelerator\"):\n        if not torch.accelerator.is_available():\n            raise NotImplementedError(\n                \"Unsloth cannot find any torch accelerator? You need a GPU.\"\n            )\n        accelerator = str(torch.accelerator.current_accelerator())\n        if accelerator in (\"cuda\", \"xpu\", \"hip\"):\n            raise RuntimeError(\n                f\"Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\\n\"\n                f\"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\\n\"\n                f\"Please reinstall torch - it's most likely broken :(\"\n            )\n    raise NotImplementedError(\n        \"Unsloth currently only works on NVIDIA, AMD and Intel GPUs.\"\n    )\n\n\nDEVICE_TYPE: str = get_device_type()\n# HIP fails for autocast and other torch functions. 
Use CUDA instead\nDEVICE_TYPE_TORCH = DEVICE_TYPE\nif DEVICE_TYPE_TORCH == \"hip\":\n    DEVICE_TYPE_TORCH = \"cuda\"\n\n\n@functools.cache\ndef get_device_count():\n    if DEVICE_TYPE in (\"cuda\", \"hip\"):\n        return torch.cuda.device_count()\n    elif DEVICE_TYPE == \"xpu\":\n        return torch.xpu.device_count()\n    else:\n        return 1\n\n\nDEVICE_COUNT: int = get_device_count()\n\n# 4-bit quantization requires a block size of 64\n# | Device Type     | Warp Size | Block Size |\n# |-----------------|-----------|------------|\n# | CUDA            |    32     |     32     |\n# | Radeon (Navi)   |    32     |     32     |\n# | Instinct (MI)   |    64     |     32     |\n#\n# Since bitsandbytes 0.49.0, pre-quantized models with 64 blockwise now works\n# on Radeon GPUs, but not Instinct MI300x for eg\n# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748\n#\n# Since bitsandbytes 0.49.2, blocksize=64 4-bit quantization is supported on\n# CDNA (MI Instinct / gfx9xx) GPUs as well\n# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1856\n\nALLOW_PREQUANTIZED_MODELS: bool = True\n# HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB\nALLOW_BITSANDBYTES: bool = True\nif DEVICE_TYPE == \"hip\":\n    try:\n        import bitsandbytes\n    except:\n        print(\n            \"Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works.\"\n        )\n        ALLOW_PREQUANTIZED_MODELS = False\n        ALLOW_BITSANDBYTES = False\n    if ALLOW_BITSANDBYTES:\n        ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version(\"0.48.2.dev0\")\n        if Version(bitsandbytes.__version__) >= Version(\"0.49.2\"):\n            pass\n        elif Version(bitsandbytes.__version__) >= Version(\"0.49.0\"):\n            try:\n                # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU\n                from bitsandbytes.cextension import ROCM_WARP_SIZE_64\n\n                ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64\n            except Exception as e:\n                print(\n                    \"Unsloth: Checking `from bitsandbytes.cextension import ROCM_WARP_SIZE_64` had error = \\n\"\n                    f\"{str(e)}\\n\"\n                    \"4bit QLoRA disabled for now, but 16bit and full finetuning works.\"\n                )\n                ALLOW_PREQUANTIZED_MODELS = False\n                ALLOW_BITSANDBYTES = False\n        elif ALLOW_BITSANDBYTES:\n            from bitsandbytes.nn.modules import Params4bit\n\n            if \"blocksize = 64 if not HIP_ENVIRONMENT else 128\" in inspect.getsource(\n                Params4bit\n            ):\n                ALLOW_PREQUANTIZED_MODELS = False\n"
  },
  {
    "path": "unsloth/import_fixes.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport importlib.abc\nimport importlib.machinery\nimport importlib.util\nfrom pathlib import Path\nfrom importlib.metadata import version as importlib_version\nfrom packaging.version import Version as TrueVersion\nimport re\nimport logging\nimport textwrap\nimport warnings\nimport sys\nimport functools\n\n# We cannot do from unsloth_zoo.log import logger since FBGEMM might cause seg faults.\nUNSLOTH_ENABLE_LOGGING = os.environ.get(\"UNSLOTH_ENABLE_LOGGING\", \"0\") in (\n    \"1\",\n    \"True\",\n    \"true\",\n)\nlogger = logging.getLogger(__name__)\nif UNSLOTH_ENABLE_LOGGING:\n    logging.basicConfig(\n        level = logging.INFO, format = \"[%(name)s|%(levelname)s]%(message)s\"\n    )\n    logger.setLevel(logging.INFO)\nelse:\n    logging.basicConfig(\n        level = logging.WARNING, format = \"[%(name)s|%(levelname)s]%(message)s\"\n    )\n    logger.setLevel(logging.WARNING)\n\n_AMDGPU_IDS_MISSING_TEXT = \"amdgpu.ids: No such file or directory\"\n\n\ndef Version(version):\n    try:\n        new_version = str(version)\n        new_version = re.match(r\"[0-9\\.]{1,}\", new_version)\n        if new_version is None:\n            raise Exception(str(e))\n        new_version = new_version.group(0).rstrip(\".\")\n        if new_version != version:\n            new_version += \".1\"  # Add .1 for dev / alpha / beta / rc\n        return TrueVersion(new_version)\n    except:\n        from inspect import getframeinfo, stack\n\n        caller = getframeinfo(stack()[1][0])\n        raise RuntimeError(\n            f\"Unsloth: Could not get version for `{version}`\\n\"\n            f\"File name = [{caller.filename}] Line number = [{caller.lineno}]\"\n        )\n\n\n# Ignore logging messages\nclass HideLoggingMessage(logging.Filter):\n    __slots__ = (\"text\",)\n\n    def __init__(self, text):\n        self.text = text\n\n    def filter(self, x):\n        return not (self.text in x.getMessage())\n\n\nclass HidePrintMessage:\n    def __init__(self, original_stream):\n        self._original_stream = original_stream\n        self._hidden_texts = []\n\n    def add_filter(self, text):\n        self._hidden_texts.append(text)\n\n    def write(self, message):\n        if not any(text in message for text in self._hidden_texts):\n            self._original_stream.write(message)\n\n    def flush(self):\n        self._original_stream.flush()\n\n    def __getattr__(self, name):\n        return getattr(self._original_stream, name)\n\n\nimport contextlib\nimport ctypes\n\ntry:\n    _libc = ctypes.CDLL(None)\nexcept Exception:\n    _libc = None\n\n\n@contextlib.contextmanager\ndef suppress_cuda_printf():\n    \"\"\"Suppress CUDA device-side printf by redirecting stdout/stderr fds to /dev/null.\n\n    CUDA device printf (eg CUTLASS \"Arch conditional MMA\" errors on Blackwell)\n    writes to stdout fd 1 at the C level, bypassing Python sys.stdout 
entirely.\n    The existing HidePrintMessage filter on sys.stderr cannot catch these since\n    they go to a different fd at a different layer. This context manager redirects\n    both fd 1 and fd 2 at the OS level, syncs CUDA, then restores them.\n    \"\"\"\n    sys.stdout.flush()\n    sys.stderr.flush()\n    saved_fds = {}\n    try:\n        for fd in (1, 2):\n            saved_fds[fd] = os.dup(fd)\n            devnull = os.open(os.devnull, os.O_WRONLY)\n            os.dup2(devnull, fd)\n            os.close(devnull)\n        yield\n    finally:\n        try:\n            import torch\n\n            if torch.cuda.is_available():\n                torch.cuda.synchronize()\n        except Exception:\n            pass\n        if _libc is not None:\n            try:\n                _libc.fflush(None)\n            except Exception:\n                pass\n        for fd, saved in saved_fds.items():\n            os.dup2(saved, fd)\n            os.close(saved)\n\n\nif not UNSLOTH_ENABLE_LOGGING:\n    import sys\n\n    # Apply to stderr for FBGEMM and CUTLASS errors\n    sys.stderr = HidePrintMessage(sys.stderr)\n    # https://github.com/pytorch/FBGEMM/blob/d99cd96490ec4aabac2ee95b1e76ea4dcfcfa628/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py#L43-L52\n    sys.stderr.add_filter(\"TMA benchmarks will be running\")\n    # CUTLASS/FBGEMM MMA instruction error on SM90 vs SM100 (Blackwell) GPUs\n    # https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp\n    sys.stderr.add_filter(\"Arch conditional MMA instruction used without targeting\")\n    # CUTLASS arch conditional errors for various architectures\n    sys.stderr.add_filter(\"CUTE_INVALID_CONTROL_PATH\")\n    # CUTLASS TMA-related errors when not targeting correct architecture\n    sys.stderr.add_filter(\"Trying to use tma without CUTE_ARCH_TMA\")\n    # Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0\n    logging.getLogger(\"torchao\").setLevel(logging.ERROR)\n    # Also filter torchao print to stderr about cpp extensions\n    sys.stderr.add_filter(\"Skipping import of cpp extensions\")\n    # SyntaxWarning: invalid escape sequence '\\.'\n    warnings.filterwarnings(\n        \"ignore\", message = \"invalid escape sequence\", category = SyntaxWarning\n    )\n    # PYTORCH_CUDA_ALLOC_CONF is deprecated warning from torch\n    warnings.filterwarnings(\"ignore\", message = \"PYTORCH_CUDA_ALLOC_CONF is deprecated\")\n    # TF32 precision deprecation warning from torch\n    warnings.filterwarnings(\n        \"ignore\", message = \"Please use the new API settings to control TF32\"\n    )\n    # Deprecation warnings from torchao\n    warnings.filterwarnings(\"ignore\", message = \"`int4_weight_only` is deprecated\")\n    warnings.filterwarnings(\"ignore\", message = \"`int8_weight_only` is deprecated\")\n\n    # TorchAO deprecated import paths (https://github.com/pytorch/ao/issues/2752)\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"Importing.*from torchao\\.dtypes.*is deprecated\",\n        category = DeprecationWarning,\n    )\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"Importing BlockSparseLayout from torchao\\.dtypes is deprecated\",\n        category = DeprecationWarning,\n    )\n\n    # SWIG builtin type warnings (from bitsandbytes/triton SWIG bindings)\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"builtin type Swig.*has no __module__ attribute\",\n  
      category = DeprecationWarning,\n    )\n\n    # Triton autotuner deprecation (https://github.com/triton-lang/triton/pull/4496)\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"warmup, rep, and use_cuda_graph parameters are deprecated\",\n        category = DeprecationWarning,\n    )\n\n    # Python 3.12+ multiprocessing fork warning in multi-threaded processes\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\".*multi-threaded.*use of fork\\(\\) may lead to deadlocks\",\n        category = DeprecationWarning,\n    )\n\n    # Resource warnings from internal socket/file operations\n    warnings.filterwarnings(\n        \"ignore\", message = r\"unclosed.*socket\", category = ResourceWarning\n    )\n    warnings.filterwarnings(\n        \"ignore\", message = r\"unclosed file.*dev/null\", category = ResourceWarning\n    )\n\n    # torch 2.9+ pin_memory/is_pinned device arg deprecation\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"The `device` argument is deprecated\",\n        category = DeprecationWarning,\n    )\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\".*pin_memory.*device.*deprecated\",\n        category = DeprecationWarning,\n    )\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\".*is_pinned.*device.*deprecated\",\n        category = DeprecationWarning,\n    )\n\n    # vllm \"Level is deprecated\" stderr noise\n    sys.stderr.add_filter(\"Level is deprecated\")\n\n    # PydanticSerializationUnexpectedValue warning\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\".*PydanticSerializationUnexpectedValue\",\n    )\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"Expected.*but got.*with value.*is not.*subclass\",\n    )\n\n    # Triton \"df: No such file or directory\" stderr noise\n    sys.stderr.add_filter(\"df: No such file\")\n    # ROCm/libdrm missing ids table stderr noise on some AMD setups\n    sys.stderr.add_filter(_AMDGPU_IDS_MISSING_TEXT)\n    # Apex ROCm fused RoPE backend selection warning when Aiter is enabled.\n    warnings.filterwarnings(\n        \"ignore\",\n        message = r\"^Aiter backend is selected for fused RoPE\\.?\",\n        category = UserWarning,\n        module = r\"^apex\\.transformer\\.functional\\.fused_rope$\",\n    )\n\n\n# Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'\n# MUST do this at the start primarily due to tensorflow causing issues\ndef fix_message_factory_issue():\n    try:\n        import google.protobuf.message_factory\n\n        class MessageFactory:\n            def CreatePrototype(self, *args, **kwargs):\n                return\n\n            def GetMessages(self, *args, **kwargs):\n                return\n\n            def GetPrototype(self, *args, **kwargs):\n                return\n\n        if not hasattr(google.protobuf.message_factory, \"MessageFactory\"):\n            logger.info(\"Unsloth: Patching protobuf.MessageFactory as it doesn't exist\")\n            google.protobuf.message_factory.MessageFactory = MessageFactory\n        elif (\n            hasattr(google.protobuf.message_factory, \"MessageFactory\")\n            and not hasattr(\n                google.protobuf.message_factory.MessageFactory, \"GetPrototype\"\n            )\n            and not hasattr(google.protobuf.message_factory, \"GetMessageClass\")\n        ):\n            google.protobuf.message_factory.MessageFactory = MessageFactory\n            
logger.info(\"Unsloth: Patching protobuf.MessageFactory as it doesn't exist\")\n        elif (\n            hasattr(google.protobuf.message_factory, \"MessageFactory\")\n            and not hasattr(\n                google.protobuf.message_factory.MessageFactory, \"GetPrototype\"\n            )\n            and hasattr(google.protobuf.message_factory, \"GetMessageClass\")\n        ):\n            GetMessageClass = google.protobuf.message_factory.GetMessageClass\n\n            def GetPrototype(self, descriptor):\n                return GetMessageClass(descriptor)\n\n            google.protobuf.message_factory.MessageFactory.GetPrototype = GetPrototype\n            logger.info(\"Unsloth: Patching protobuf.MessageFactory.GetPrototype\")\n        pass\n    except:\n        pass\n\n\n# Fix Xformers performance issues since 0.0.25\ndef fix_xformers_performance_issue():\n    spec = importlib.util.find_spec(\"xformers\")\n    if spec is None:\n        return\n    xformers_version = importlib_version(\"xformers\")\n    if Version(xformers_version) < Version(\"0.0.29\"):\n        xformers_location = spec.origin\n        if xformers_location is None:\n            xformers_location = spec.submodule_search_locations[0]\n        else:\n            xformers_location = os.path.split(xformers_location)[0]\n        cutlass = Path(xformers_location) / \"ops\" / \"fmha\" / \"cutlass.py\"\n        try:\n            if cutlass.exists():\n                with open(cutlass, \"r+\", encoding = \"utf-8\") as f:\n                    text = f.read()\n                    # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591\n                    if \"num_splits_key=-1,\" in text:\n                        text = text.replace(\n                            \"num_splits_key=-1,\",\n                            \"num_splits_key=None,\",\n                        )\n                        f.seek(0)\n                        f.write(text)\n                        f.truncate()\n                        logger.info(\n                            \"Unsloth: Patching Xformers to fix some performance issues.\"\n                        )\n        except Exception as e:\n            logger.info(f\"Unsloth: Failed patching Xformers with error = {str(e)}\")\n\n\ndef patch_vllm_for_notebooks():\n    import sys\n\n    ipython = None\n    try:\n        from IPython import get_ipython as _get_ipython\n    except Exception:\n        _get_ipython = None\n\n    if _get_ipython is not None:\n        try:\n            ipython = _get_ipython()\n        except Exception:\n            ipython = None\n\n    if ipython is None:\n        try:\n            import builtins\n\n            _get_ipython = getattr(builtins, \"get_ipython\", None)\n            if callable(_get_ipython):\n                ipython = _get_ipython()\n        except Exception:\n            ipython = None\n\n    if ipython is None:\n        return\n\n    try:\n        shell = ipython.__class__.__name__\n        is_notebook = shell == \"ZMQInteractiveShell\" or \"google.colab\" in str(\n            type(ipython)\n        )\n    except Exception:\n        return\n\n    if not is_notebook:\n        return\n\n    if not hasattr(sys.stdout, \"fileno\"):\n        return\n\n    needs_patch = False\n    try:\n        fd = sys.stdout.fileno()\n        if not isinstance(fd, int) or fd < 0:\n            needs_patch = True\n    except Exception:\n        needs_patch = True\n\n    if not needs_patch:\n        return\n\n    logger.info(\n        \"Unsloth: Notebook 
detected - Patching sys.stdout.fileno for newer `vllm>=0.12.0` versions\"\n    )\n    sys.stdout.fileno = lambda: 1\n\n\n# ValueError: 'aimv2' is already used by a Transformers config, pick another name.\ndef fix_vllm_aimv2_issue():\n    spec = importlib.util.find_spec(\"vllm\")\n    if spec is None:\n        return\n    vllm_version = importlib_version(\"vllm\")\n    if Version(vllm_version) < Version(\"0.10.1\"):\n        vllm_location = spec.origin\n        if vllm_location is None:\n            vllm_location = spec.submodule_search_locations[0]\n        else:\n            vllm_location = os.path.split(vllm_location)[0]\n        ovis_config = Path(vllm_location) / \"transformers_utils\" / \"configs\" / \"ovis.py\"\n        try:\n            if ovis_config.exists():\n                with open(ovis_config, \"r+\", encoding = \"utf-8\") as f:\n                    text = f.read()\n                    # See https://github.com/vllm-project/vllm-ascend/issues/2046\n                    if 'AutoConfig.register(\"aimv2\", AIMv2Config)' in text:\n                        text = text.replace(\n                            'AutoConfig.register(\"aimv2\", AIMv2Config)',\n                            \"\",\n                        )\n                        text = text.replace(\n                            \"\"\"backbone_config.pop('model_type')\n                backbone_config = AutoConfig.for_model(model_type,\n                                                       **backbone_config)\"\"\",\n                            \"\"\"if model_type != \"aimv2\":\n                    backbone_config.pop('model_type')\n                    backbone_config = AutoConfig.for_model(model_type, **backbone_config)\n                else:\n                    backbone_config = AIMv2Config(**backbone_config)\"\"\",\n                        )\n                        f.seek(0)\n                        f.write(text)\n                        f.truncate()\n                        logger.info(\n                            \"Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`\"\n                        )\n        except Exception as e:\n            logger.info(f\"Unsloth: Failed patching vLLM with error = {str(e)}\")\n\n\ndef fix_vllm_guided_decoding_params():\n    def _maybe_raise_vllm_transformers_mismatch(error):\n        error_text = str(error)\n        if (\n            \"ALLOWED_LAYER_TYPES\" in error_text\n            or \"transformers.configuration_utils\" in error_text\n        ):\n            try:\n                vllm_version = importlib_version(\"vllm\")\n            except Exception:\n                vllm_version = \"unknown\"\n            raise RuntimeError(\n                \"Unsloth: vLLM with version \"\n                f\"{vllm_version} does not yet support transformers>=5.0.0. \"\n                \"Please downgrade to transformers==4.57.3 via \"\n                'pip install --force-reinstall \"transformers==4.57.3\". '\n                f\"Original error: {error}\"\n            ) from error\n\n    if importlib.util.find_spec(\"vllm\") is None:\n        return\n    # GuidedDecodingParams is renamed to StructuredOutputsParams in vLLM\n    # https://github.com/vllm-project/vllm/pull/22772/files\n    # trl still wants to use GuidedDecodingParams. 
This is a temporary patch till trl updates\n    try:\n        import vllm\n    except (ImportError, OSError) as e:\n        _maybe_raise_vllm_transformers_mismatch(e)\n        if disable_broken_vllm(e):\n            return\n        raise\n\n    try:\n        from vllm.sampling_params import GuidedDecodingParams\n    except (ImportError, OSError) as e:\n        _maybe_raise_vllm_transformers_mismatch(e)\n        if disable_broken_vllm(e):\n            return\n        if not hasattr(vllm, \"sampling_params\") or not hasattr(\n            vllm.sampling_params, \"StructuredOutputsParams\"\n        ):\n            raise\n        vllm.sampling_params.GuidedDecodingParams = (\n            vllm.sampling_params.StructuredOutputsParams\n        )\n\n\ndef ignore_logger_messages():\n    # Ignore Environment variable `HF_TOKEN` is set\n    try:\n        from huggingface_hub._login import logger as huggingface_hub_logger\n\n        huggingface_hub_logger.addFilter(HideLoggingMessage(\"`HF_TOKEN`\"))\n        del huggingface_hub_logger\n    except:\n        pass\n\n\ndef patch_ipykernel_hf_xet():\n    # HF-XET == 1.1.10 and ipykernel == 7.0.0 / 7.0.1 causes issues\n    # See https://github.com/huggingface/xet-core/issues/526\n    # 2025-10-13T20:37:33.028737Z ERROR  Python exception updating progress:, error: PyErr { type: <class 'LookupError'>, value: LookupError(<ContextVar name='shell_parent' at 0x7535b4cebd80>), traceback: Some(<traceback object at 0x753408489f40>) }, caller: \"src/progress_update.rs:313\"\n    # at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28\n    if importlib.util.find_spec(\"hf_xet\") is None:\n        return\n    if importlib.util.find_spec(\"ipykernel\") is None:\n        return\n    if importlib.util.find_spec(\"huggingface_hub\") is None:\n        return\n\n    ipykernel_version = Version(importlib_version(\"ipykernel\"))\n    if (\n        (Version(importlib_version(\"hf_xet\")) == Version(\"1.1.10\"))\n        and (\n            (ipykernel_version == Version(\"7.0.0\"))\n            or (\n                ipykernel_version == Version(\"7.0.1\")\n            )  # 7.0.1 seems to also break with LookupError: <ContextVar name='shell_parent' at 0x7a9775143ec0>\n        )\n    ):\n        print(\n            \"#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` or `ipykernel==7.0.1` breaks progress bars. 
Using ASCII progress bars.\\n\"\n            \"#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>=7.1.0` or wait for a fix to\\n\"\n            \"https://github.com/huggingface/xet-core/issues/526\"\n        )\n        from huggingface_hub.utils import disable_progress_bars\n\n        disable_progress_bars()\n\n\ndef patch_trackio():\n    # Set some environment variables to customize the Trackio dashboard for experiment tracking\n    # See https://github.com/unslothai/notebooks/pull/110\n    os.environ[\"TRACKIO_LOGO_LIGHT_URL\"] = (\n        \"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png\"\n    )\n    os.environ[\"TRACKIO_LOGO_DARK_URL\"] = (\n        \"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png\"\n    )\n    os.environ[\"TRACKIO_PLOT_ORDER\"] = \"train/reward\"\n\n\ndef patch_datasets():\n    # Datasets 4.4.0 and 4.4.1 have some weird `_thread.RLock_recursion_count` issues\n    if importlib.util.find_spec(\"datasets\") is None:\n        return\n\n    datasets_version = Version(importlib_version(\"datasets\"))\n    if (datasets_version <= Version(\"4.5.0\")) and (\n        datasets_version >= Version(\"4.4.0\")\n    ):\n        raise NotImplementedError(\n            f\"#### Unsloth: Using `datasets = {str(datasets_version)}` will cause recursion errors.\\n\"\n            \"Please downgrade datasets to `datasets==4.3.0`.\"\n        )\n\n\ndef check_fbgemm_gpu_version():\n    if importlib.util.find_spec(\"fbgemm_gpu\") is None:\n        return\n    try:\n        fbgemm_gpu_version = importlib_version(\"fbgemm_gpu_genai\")\n    except:\n        return\n    # We noticed some SegFault or bad alloc errors on lower versions of fbgemm_gpu.\n    # Instead of raising an error, disable FBGEMM and fall back to Triton kernels.\n    if Version(fbgemm_gpu_version) < Version(\"1.4.0\"):\n        os.environ[\"UNSLOTH_HAS_FBGEMM\"] = \"0\"\n        logger.info(\n            f\"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} is old and may cause issues. 
\"\n            f\"Disabling FBGEMM - using Triton kernels instead.\"\n        )\n        return\n\n    logger.info(f\"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.\")\n\n\ndef patch_enable_input_require_grads():\n    \"\"\"\n    Patch transformers PreTrainedModel.enable_input_require_grads to handle vision models\n    that raise NotImplementedError from get_input_embeddings().\n\n    \"\"\"\n    import inspect\n    from transformers import PreTrainedModel\n\n    # Check if the original function iterates over self.modules() instead of just returning the enable_input_require_grads\n    # Ref: https://github.com/huggingface/transformers/pull/41993/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL1979-R1996\n    try:\n        original_source = inspect.getsource(PreTrainedModel.enable_input_require_grads)\n    except:\n        return\n\n    # Only patch if the new pattern exists (iterating over self.modules())\n    if \"for module in self.modules()\" not in original_source:\n        return\n\n    def _patched_enable_input_require_grads(self):\n        def make_inputs_require_grads(module, input, output):\n            output.requires_grad_(True)\n\n        hooks = []\n        seen_modules = set()\n\n        for module in self.modules():\n            if not (\n                isinstance(module, PreTrainedModel)\n                and hasattr(module, \"get_input_embeddings\")\n            ):\n                continue\n\n            try:\n                input_embeddings = module.get_input_embeddings()\n            except NotImplementedError:\n                # Vision models may not implement get_input_embeddings - skip them\n                # For GLM V4.6 for example, this skips only `self.visual`\n                continue\n\n            if input_embeddings is None:\n                continue\n\n            embedding_id = id(input_embeddings)\n            if embedding_id in seen_modules:\n                continue\n\n            seen_modules.add(embedding_id)\n            hooks.append(\n                input_embeddings.register_forward_hook(make_inputs_require_grads)\n            )\n\n        self._require_grads_hooks = hooks\n        if hooks:\n            self._require_grads_hook = hooks[0]\n\n    PreTrainedModel.enable_input_require_grads = _patched_enable_input_require_grads\n\n    logger.info(\n        \"Unsloth: Patched enable_input_require_grads for vision model compatibility\"\n    )\n\n\ndef _is_custom_torch_build(raw_version_str):\n    \"\"\"Check if a raw version string indicates a custom or source build.\n    Must operate on the raw string from importlib_version(), not the parsed\n    Version object, since our custom Version() strips local identifiers.\n\n    Standard PyTorch releases use: +cu124, +rocm6.3, +cpu, +xpu\n    Source/custom builds use: +gitXXXXXXX, +HEXHASH, or other suffixes.\n    \"\"\"\n    if \"+\" not in raw_version_str:\n        return False\n    local = raw_version_str.split(\"+\", 1)[1]\n    if not local:\n        return False\n    # Use fullmatch so the entire local identifier must match, not just a prefix.\n    # cu/rocm require a trailing digit (e.g. cu124, rocm6.3). 
cpu/xpu are exact.\n    # Case-insensitive since some builds may use uppercase.\n    return not re.fullmatch(r\"cu\\d[\\d.]*|rocm\\d[\\d.]*|cpu|xpu\", local, re.IGNORECASE)\n\n\ndef _infer_required_torchvision(torch_major, torch_minor):\n    \"\"\"Infer the minimum required torchvision minor version from torch version.\n\n    The torch -> torchvision minor version mapping follows a consistent formula:\n      torch 1.x  ->  torchvision 0.(x + 1)   (verified: torch 1.7 through 1.13)\n      torch 2.x  ->  torchvision 0.(x + 15)  (verified: torch 2.0 through 2.9)\n\n    Returns (tv_major, tv_minor) or None if the major version is unrecognized.\n    \"\"\"\n    if torch_major == 1 and torch_minor >= 7:\n        return (0, torch_minor + 1)\n    if torch_major == 2:\n        return (0, torch_minor + 15)\n    return None\n\n\ndef torchvision_compatibility_check():\n    # Allow skipping via environment variable for custom environments\n    if os.environ.get(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"0\").lower() in (\"1\", \"true\"):\n        return\n\n    if importlib.util.find_spec(\"torch\") is None:\n        raise ImportError(\"Unsloth: torch not found. Please install torch first.\")\n    if importlib.util.find_spec(\"torchvision\") is None:\n        return\n\n    try:\n        torch_version_raw = importlib_version(\"torch\")\n        torchvision_version_raw = importlib_version(\"torchvision\")\n    except Exception:\n        return\n\n    try:\n        torch_v = Version(torch_version_raw)\n        tv_v = Version(torchvision_version_raw)\n    except Exception:\n        return\n\n    # Known compatibility table (ground truth, takes precedence over formula).\n    # See https://pytorch.org/get-started/previous-versions/\n    TORCH_TORCHVISION_COMPAT = {\n        (2, 9): (0, 24),\n        (2, 8): (0, 23),\n        (2, 7): (0, 22),\n        (2, 6): (0, 21),\n        (2, 5): (0, 20),\n        (2, 4): (0, 19),\n    }\n\n    # Extract major.minor from the parsed version\n    torch_release = torch_v.release\n    if len(torch_release) < 2:\n        return\n    torch_major, torch_minor = torch_release[0], torch_release[1]\n\n    # Try known table first, then fall back to formula for forward compatibility\n    required = TORCH_TORCHVISION_COMPAT.get((torch_major, torch_minor))\n\n    if required is None:\n        required = _infer_required_torchvision(torch_major, torch_minor)\n\n    if required is None:\n        return\n\n    required_tv_str = f\"{required[0]}.{required[1]}.0\"\n\n    if tv_v >= Version(required_tv_str):\n        logger.info(\n            f\"Unsloth: torch=={torch_version_raw} and \"\n            f\"torchvision=={torchvision_version_raw} are compatible.\"\n        )\n        return\n\n    # Version mismatch detected\n    message = (\n        f\"Unsloth: torch=={torch_version_raw} requires \"\n        f\"torchvision>={required_tv_str}, \"\n        f\"but found torchvision=={torchvision_version_raw}. \"\n        f'Try updating torchvision via `pip install --upgrade \"torchvision>={required_tv_str}\"`. 
'\n        f\"Please refer to https://pytorch.org/get-started/previous-versions/ \"\n        f\"for more information.\"\n    )\n\n    is_custom = _is_custom_torch_build(torch_version_raw) or _is_custom_torch_build(\n        torchvision_version_raw\n    )\n\n    # Detect nightly/dev/alpha/beta/rc builds from the raw version string.\n    # These often have version mismatches that are expected.\n    _pre_tags = (\".dev\", \"a0\", \"b0\", \"rc\", \"alpha\", \"beta\", \"nightly\")\n    is_prerelease = any(t in torch_version_raw for t in _pre_tags) or any(\n        t in torchvision_version_raw for t in _pre_tags\n    )\n\n    # Only downgrade to warning for custom/source or prerelease builds.\n    # Stable mismatches should fail fast to prevent runtime operator errors.\n    if is_custom or is_prerelease:\n        reason = \"custom/source build\" if is_custom else \"pre-release build\"\n        logger.warning(\n            f\"{message}\\n\"\n            f\"Detected a {reason}. \"\n            f\"Continuing with a warning. \"\n            f\"Set UNSLOTH_SKIP_TORCHVISION_CHECK=1 to silence this.\"\n        )\n        return\n\n    raise ImportError(message)\n\n\n# Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined\ndef fix_openenv_no_vllm():\n    spec = importlib.util.find_spec(\"trl\")\n    if spec is None:\n        return\n    trl_location = spec.origin\n    if trl_location is None:\n        trl_location = spec.submodule_search_locations[0]\n    else:\n        trl_location = os.path.split(trl_location)[0]\n    openenv = Path(trl_location) / \"experimental\" / \"openenv\" / \"utils.py\"\n    if not openenv.exists():\n        return\n\n    try:\n        with open(openenv, \"r+\", encoding = \"utf-8\") as f:\n            text = f.read()\n            bad = (\n                \"if is_vllm_available():\\n\"\n                \"    from vllm import SamplingParams\\n\"\n                \"    from vllm.sampling_params import GuidedDecodingParams\\n\"\n            )\n            replace_with = bad + (\n                \"else:\\n\"\n                \"    from typing import Any\\n\"\n                \"    SamplingParams = Any\\n\"\n                \"    GuidedDecodingParams = Any\\n\"\n                \"\\n\"\n            )\n            if bad + \"\\n\" + \"\\n\" in text and replace_with not in text:\n                text = text.replace(bad + \"\\n\" + \"\\n\", replace_with)\n                f.seek(0)\n                f.write(text)\n                f.truncate()\n                logger.info(\n                    \"Unsloth: Patching TRL OpenEnv to fix SamplingParams not defined\"\n                )\n    except Exception as e:\n        logger.info(f\"Unsloth: Failed patching TRL OpenEnv with error = {str(e)}\")\n\n\n# Fix Executorch needing get_mapped_key\ndef fix_executorch():\n    spec = importlib.util.find_spec(\"executorch\")\n    if spec is None:\n        return\n    executorch_location = spec.origin\n    if executorch_location is None:\n        executorch_location = spec.submodule_search_locations[0]\n    else:\n        executorch_location = os.path.split(executorch_location)[0]\n    executorch = Path(executorch_location) / \"examples\" / \"models\" / \"__init__.py\"\n    if not executorch.exists():\n        return\n\n    try:\n        what = r\"\"\"\n        import sys\n        import types\n        import re\n        from typing import Any, Optional\n        def get_mapped_key(key: str, mapping_dict: dict[str, str]) -> str:\n            try:\n                # Checks if there is a layer 
# in the key\n                if any(k.isdigit() for k in key.split(\".\")):\n                    # Replace layer number with \"{}\" to create key for lookup\n                    abstract_key = re.sub(r\"(\\.\\d+)\", \".{}\", key)\n                    layer_num = re.search(r\"\\d+\", key).group(0)\n                    new_key = mapping_dict[abstract_key]\n                    new_key = new_key.format(layer_num)\n                else:\n                    new_key = mapping_dict[key]\n            except KeyError as e:\n                raise Exception(\n                    f'Error converting the state dict. Found unexpected key: \"{key}\". '\n                    \"Please make sure you're loading a checkpoint with the right format. \"\n                ) from e\n\n            return new_key\n\n        torchtune = types.ModuleType(\"torchtune\")\n        torchtune.__path__ = []\n        models = types.ModuleType(\"torchtune.models\")\n        models.__path__ = []\n        convert_weights = types.ModuleType(\"torchtune.models.convert_weights\")\n        convert_weights.get_mapped_key = get_mapped_key\n        torchtune.models = models\n        models.convert_weights = convert_weights\n        sys.modules[\"torchtune\"] = torchtune\n        sys.modules[\"torchtune.models\"] = models\n        sys.modules[\"torchtune.models.convert_weights\"] = convert_weights\n        \"\"\"\n        what = textwrap.dedent(what)\n\n        with open(executorch, \"r+\", encoding = \"utf-8\") as f:\n            text = f.read()\n            bad = \"from enum import Enum\\n\"\n            if bad in text and what not in text:\n                text = text.replace(bad + \"\\n\", bad + \"\\n\" + what)\n                f.seek(0)\n                f.write(text)\n                f.truncate()\n                logger.info(\"Unsloth: Patching Executorch to fix get_mapped_key\")\n    except Exception as e:\n        logger.info(f\"Unsloth: Failed Executorch with error = {str(e)}\")\n\n\ndef fix_diffusers_warnings():\n    # Silence Flax classes are deprecated and will be removed in Diffusers v1.0.0.\n    os.environ[\"DIFFUSERS_VERBOSITY\"] = \"error\"\n\n\ndef fix_huggingface_hub():\n    # huggingface_hub.is_offline_mode got removed, so add it back\n    import huggingface_hub\n\n    if not hasattr(huggingface_hub, \"is_offline_mode\"):\n        huggingface_hub.is_offline_mode = (\n            lambda: huggingface_hub.constants.HF_HUB_OFFLINE\n        )\n\n\ndef fix_triton_compiled_kernel_missing_attrs():\n    \"\"\"\n    Triton 3.6.0+ removed direct `num_ctas` and `cluster_dims` attributes from\n    CompiledKernel, but torch 2.9.x Inductor still expects them in\n    torch/_inductor/runtime/triton_heuristics.py make_launcher() (line ~1757).\n\n    The scope dict eagerly evaluates:\n        binary.metadata.num_ctas, *binary.metadata.cluster_dims\n    when hasattr(binary, \"metadata\") is True, but metadata lacks cluster_dims.\n    This crashes before reaching the new launch path that doesn't need cta_args.\n\n    Upstream fix: pytorch/pytorch@97bd4db added hasattr guards.\n    We monkey-patch CompiledKernel.__init__ to inject the missing attributes\n    so the older hasattr(binary, \"num_ctas\") branch succeeds instead.\n    \"\"\"\n    try:\n        import torch\n    except (ImportError, ModuleNotFoundError):\n        return\n\n    try:\n        import triton\n        import triton.compiler.compiler as triton_compiler\n    except (ImportError, ModuleNotFoundError):\n        return\n\n    # Only needed when the CompiledKernel class 
lacks num_ctas as a direct attr\n    # but has metadata (triton >= 3.6.0 with torch < 2.10)\n    _ck_cls = triton_compiler.CompiledKernel\n    if hasattr(_ck_cls, \"num_ctas\"):\n        return  # Old triton with direct attrs -- no patch needed\n\n    _orig_init = _ck_cls.__init__\n\n    def _patched_init(self, *args, **kwargs):\n        _orig_init(self, *args, **kwargs)\n        if not hasattr(self, \"num_ctas\"):\n            self.num_ctas = getattr(self.metadata, \"num_ctas\", 1)\n        if not hasattr(self, \"cluster_dims\") and not hasattr(self, \"clusterDims\"):\n            self.cluster_dims = (1, 1, 1)\n\n    _ck_cls.__init__ = _patched_init\n    logger.info(\n        \"Unsloth: Patched triton CompiledKernel with num_ctas/cluster_dims \"\n        \"for torch.compile compatibility.\"\n    )\n\n\ndef patch_trunc_normal_precision_issue():\n    \"\"\"\n    Patch torch.nn.init.trunc_normal_ for low precision tensors to run init in fp32.\n\n    torch.nn.init.trunc_normal_ can saturate at truncation bounds in fp16/bf16 on\n    some versions/backends. This was observed in TorchTitan investigations where\n    low-precision truncation produced boundary-heavy initialization behavior:\n    https://github.com/pytorch/torchtitan/pull/2342\n\n    To avoid that failure mode, initialize into a temporary fp32 tensor, then copy\n    back to the original dtype.\n    \"\"\"\n    try:\n        import torch\n    except (ImportError, ModuleNotFoundError):\n        return\n\n    if getattr(torch.nn.init, \"_unsloth_trunc_normal_patched\", False):\n        return\n\n    original_trunc_normal = torch.nn.init.trunc_normal_\n    if getattr(original_trunc_normal, \"__unsloth_trunc_normal_patched__\", False):\n        torch.nn.init._unsloth_trunc_normal_patched = True\n        return\n\n    low_precision_dtypes = {torch.float16, torch.bfloat16}\n\n    def _call_original(target, mean, std, a, b, generator):\n        if generator is None:\n            return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)\n        try:\n            return original_trunc_normal(\n                target, mean = mean, std = std, a = a, b = b, generator = generator\n            )\n        except TypeError as exc:\n            # Older torch versions may not accept a generator keyword argument.\n            msg = str(exc).lower()\n            if \"unexpected keyword argument\" in msg and \"generator\" in msg:\n                return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)\n            raise\n\n    try:\n        from torch.distributed._tensor import DTensor\n    except Exception:\n        DTensor = None\n\n    @torch.no_grad()\n    def _patched_trunc_normal_(\n        tensor,\n        mean: float = 0.0,\n        std: float = 1.0,\n        a: float = -2.0,\n        b: float = 2.0,\n        generator = None,\n    ):\n        if DTensor is not None and isinstance(tensor, DTensor):\n            local_tensor = getattr(tensor, \"_local_tensor\", None)\n            if local_tensor is None:\n                return _call_original(tensor, mean, std, a, b, generator)\n            if local_tensor.dtype in low_precision_dtypes:\n                local_fp32 = local_tensor.float()\n                _call_original(local_fp32, mean, std, a, b, generator)\n                local_tensor.copy_(local_fp32.to(dtype = local_tensor.dtype))\n                return tensor\n            return _call_original(tensor, mean, std, a, b, generator)\n\n        if tensor.dtype in low_precision_dtypes:\n            tensor_fp32 
= tensor.float()\n            _call_original(tensor_fp32, mean, std, a, b, generator)\n            tensor.copy_(tensor_fp32.to(dtype = tensor.dtype))\n            return tensor\n\n        return _call_original(tensor, mean, std, a, b, generator)\n\n    _patched_trunc_normal_.__unsloth_trunc_normal_patched__ = True\n    _patched_trunc_normal_._unsloth_original = original_trunc_normal\n    torch.nn.init._unsloth_trunc_normal_original = original_trunc_normal\n    torch.nn.init.trunc_normal_ = _patched_trunc_normal_\n    torch.nn.init._unsloth_trunc_normal_patched = True\n    logger.info(\"Unsloth: Patched torch.nn.init.trunc_normal_ for fp16/bf16 stability.\")\n\n\ndef check_vllm_torch_sm100_compatibility():\n    \"\"\"\n    Check for incompatible vLLM + torch < 2.9.0 + SM100 (Blackwell) combination.\n\n    vLLM's distributed module (device_communicators) crashes with std::bad_alloc\n    when imported on SM100 GPUs (B200/B100) with torch < 2.9.0. This is due to\n    C++ code in vLLM's NCCL/distributed layer being incompatible with older\n    torch versions on the newer Blackwell architecture.\n\n    This check runs early (before vLLM import) to provide a helpful error message\n    instead of a cryptic std::bad_alloc crash.\n    \"\"\"\n    # Check if vLLM is installed (without importing it)\n    if importlib.util.find_spec(\"vllm\") is None:\n        return\n\n    # Check torch version\n    try:\n        torch_version = Version(importlib_version(\"torch\"))\n        if torch_version >= Version(\"2.9.0\"):\n            return  # torch >= 2.9.0 is compatible\n    except Exception:\n        return  # Can't determine torch version, skip check\n\n    # Check if any CUDA GPU is SM100 (Blackwell)\n    try:\n        import torch\n\n        if not torch.cuda.is_available():\n            return\n\n        has_sm100 = False\n        sm100_gpu_name = None\n        for i in range(torch.cuda.device_count()):\n            major, minor = torch.cuda.get_device_capability(i)\n            if major == 10:\n                has_sm100 = True\n                sm100_gpu_name = torch.cuda.get_device_name(i)\n                break\n\n        if not has_sm100:\n            return\n    except Exception:\n        return\n\n    # Get vLLM version for the error message\n    try:\n        vllm_version = importlib_version(\"vllm\")\n    except Exception:\n        vllm_version = \"unknown\"\n\n    # Incompatible combination detected - raise helpful error\n    raise RuntimeError(\n        f\"Unsloth: Incompatible configuration detected.\\n\\n\"\n        f\"  GPU: {sm100_gpu_name} (SM100 / Blackwell architecture)\\n\"\n        f\"  torch version: {torch_version}\\n\"\n        f\"  vLLM version: {vllm_version}\\n\\n\"\n        f\"vLLM's distributed module crashes with std::bad_alloc on SM100 GPUs \"\n        f\"(B200/B100/Blackwell) when using torch < 2.9.0.\\n\\n\"\n        f\"To fix this, please upgrade torch:\\n\"\n        f\"  pip install --upgrade torch>=2.9.0\\n\\n\"\n        f\"Alternatively, if you don't need vLLM:\\n\"\n        f\"  pip uninstall vllm\"\n    )\n\n\ndef fix_vllm_pdl_blackwell():\n    \"\"\"\n    Fix vLLM PDL (Programmatic Dependent Launch) bug on Blackwell GPUs (SM100).\n\n    The issue: vLLM's LoRA Triton kernels use tl.extra.cuda.gdc_wait() for PDL\n    optimization on SM90+ GPUs. 
This fails on SM100 (B200/B100) during CUDA graph\n    capture because Triton's pipeliner can't handle gdc_wait in complex kernels.\n\n    See: https://github.com/vllm-project/vllm/issues/30872\n    \"\"\"\n    if importlib.util.find_spec(\"vllm\") is None:\n        return\n\n    # Check if any CUDA GPU is SM100 (Blackwell)\n    try:\n        import torch\n\n        if not torch.cuda.is_available():\n            return\n\n        # Scan all GPUs for SM100 - fix applies globally via env var and monkey-patch\n        has_sm100 = False\n        sm100_gpu_name = None\n        for i in range(torch.cuda.device_count()):\n            major, minor = torch.cuda.get_device_capability(i)\n            if major == 10:\n                has_sm100 = True\n                sm100_gpu_name = torch.cuda.get_device_name(i)\n                break\n\n        if not has_sm100:\n            return\n    except Exception:\n        return\n\n    # Helper to check if module spec exists\n    def _spec_exists(name):\n        try:\n            return importlib.util.find_spec(name) is not None\n        except (ImportError, OSError, ModuleNotFoundError, ValueError):\n            return False\n\n    # Check if vLLM has the PDL-related modules before doing internet check\n    has_utils = _spec_exists(\"vllm.lora.ops.triton_ops.utils\")\n    has_expand_op = _spec_exists(\"vllm.lora.ops.triton_ops.lora_expand_op\")\n    has_shrink_op = _spec_exists(\"vllm.lora.ops.triton_ops.lora_shrink_op\")\n\n    if not has_utils and not has_expand_op and not has_shrink_op:\n        # Old vLLM version without PDL support - nothing to patch\n        return\n\n    # Check if vLLM version includes the fix\n    VLLM_PDL_FIX_VERSION = \"0.15.0\"\n    try:\n        vllm_version = Version(importlib_version(\"vllm\"))\n        if vllm_version >= Version(VLLM_PDL_FIX_VERSION):\n            logger.info(\n                f\"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} \"\n                f\"should include PDL fix - skipping workaround\"\n            )\n            return\n    except Exception as e:\n        logger.debug(\n            f\"Unsloth: vLLM version check failed ({e}), applying PDL workaround.\"\n        )\n\n    # Apply the PDL fix\n    os.environ[\"TRITON_DISABLE_PDL\"] = \"1\"\n\n    def fake_supports_pdl(*args, **kwargs):\n        return False\n\n    patched = []\n    patched_names = set()\n\n    def _record_patch(name):\n        if name not in patched_names:\n            patched.append(name)\n            patched_names.add(name)\n\n    # First, patch the source module (utils.py) where supports_pdl is defined.\n    # This is critical because supports_pdl uses @lru_cache - we must clear the\n    # cache to prevent stale cached results from the original function.\n    try:\n        utils_module = importlib.import_module(\"vllm.lora.ops.triton_ops.utils\")\n        if hasattr(utils_module, \"supports_pdl\"):\n            original_fn = utils_module.supports_pdl\n            if hasattr(original_fn, \"cache_clear\"):\n                original_fn.cache_clear()\n            utils_module.supports_pdl = fake_supports_pdl\n            _record_patch(\"utils\")\n    except (ImportError, ModuleNotFoundError, AttributeError):\n        pass\n\n    # Also patch the consumer modules that import supports_pdl from utils.\n    # This ensures the patched function is used even if the module was already\n    # imported before this fix runs.\n    consumer_modules = {\n        \"lora_expand_op\": \"vllm.lora.ops.triton_ops.lora_expand_op\",\n   
     \"lora_shrink_op\": \"vllm.lora.ops.triton_ops.lora_shrink_op\",\n        \"fused_moe_lora_op\": \"vllm.lora.ops.triton_ops.fused_moe_lora_op\",\n    }\n    for name, path in consumer_modules.items():\n        try:\n            module = importlib.import_module(path)\n            if hasattr(module, \"supports_pdl\"):\n                module.supports_pdl = fake_supports_pdl\n                _record_patch(name)\n        except (ImportError, ModuleNotFoundError, AttributeError):\n            pass\n\n    # Patch any additional already-loaded triton ops consumers that expose supports_pdl.\n    for module_name, module in tuple(sys.modules.items()):\n        if not module_name.startswith(\"vllm.lora.ops.triton_ops.\"):\n            continue\n        if module is None or not hasattr(module, \"supports_pdl\"):\n            continue\n        module.supports_pdl = fake_supports_pdl\n        _record_patch(module_name.rsplit(\".\", 1)[-1])\n\n    if patched:\n        logger.info(\n            f\"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - \"\n            f\"patched: {', '.join(patched)}\"\n        )\n    else:\n        # Just set the env var - vLLM might be an older version without supports_pdl\n        logger.info(f\"Unsloth: Set TRITON_DISABLE_PDL=1 for SM100 ({sm100_gpu_name})\")\n\n\ndef patch_openspiel_env_async():\n    \"\"\"Apply nest_asyncio for OpenEnv EnvClient async compatibility.\n\n    OpenEnv's EnvClient uses async methods (reset/step). In Jupyter notebooks\n    these work via top-level await, but converted scripts need\n    asyncio.get_event_loop().run_until_complete() wrappers. Applying nest_asyncio\n    ensures nested event loop calls work in all contexts without replacing the\n    original async methods (which would break scripts that already have their own\n    sync wrappers).\n    \"\"\"\n    try:\n        import inspect\n        from openenv.core.env_client import EnvClient\n\n        if not inspect.iscoroutinefunction(EnvClient.reset):\n            return  # Already sync, nothing to do\n\n        try:\n            import nest_asyncio\n\n            nest_asyncio.apply()\n            logger.info(\n                \"Unsloth: Applied nest_asyncio for OpenEnv EnvClient async compatibility\"\n            )\n        except ImportError:\n            logger.info(\n                \"Unsloth: nest_asyncio not installed, OpenEnv async methods may need manual wrapping\"\n            )\n    except (ImportError, AttributeError):\n        pass  # openenv not installed\n\n\ndef patch_torchcodec_audio_decoder():\n    \"\"\"Call unsloth_zoo's AudioDecoder patch.\"\"\"\n    try:\n        from unsloth_zoo.dataset_utils import patch_torchcodec_audio_decoder as _patch\n\n        _patch()\n    except (ImportError, AttributeError, RuntimeError):\n        pass\n\n\ndef disable_torchcodec_if_broken():\n    \"\"\"Disable torchcodec in transformers if it cannot actually load.\n\n    transformers checks if torchcodec is installed via importlib.util.find_spec(),\n    but this returns True even when torchcodec cannot load its native libraries\n    (e.g., when FFmpeg is missing). 
This causes runtime errors when transformers\n    tries to use torchcodec for audio loading.\n\n    This function tests if torchcodec can actually load and if not, patches\n    transformers to think torchcodec is unavailable so it falls back to librosa.\n    \"\"\"\n    try:\n        import importlib.util\n\n        if importlib.util.find_spec(\"torchcodec\") is None:\n            return  # torchcodec not installed, nothing to do\n\n        # Test if torchcodec can actually load\n        from torchcodec.decoders import AudioDecoder\n    except (ImportError, RuntimeError, OSError):\n        # torchcodec cannot load - disable it in transformers\n        try:\n            import transformers.utils.import_utils as tf_import_utils\n\n            tf_import_utils._torchcodec_available = False\n        except (ImportError, AttributeError):\n            pass\n\n\ndef disable_broken_wandb():\n    \"\"\"Disable wandb if it's installed but cannot actually import.\n\n    wandb can fail to import when there's a protobuf version mismatch\n    (e.g., wandb < 0.19.11 with protobuf >= 6.0). This causes cascading\n    import failures through trl -> transformers/accelerate -> wandb that\n    crash unsloth's import chain.\n\n    There are two separate is_wandb_available() functions used by trl:\n      - transformers.integrations.integration_utils.is_wandb_available\n        (used by most trl trainers)\n      - accelerate.utils.imports.is_wandb_available\n        (used by trl/trainer/callbacks.py)\n\n    Both must be patched to fully prevent broken wandb imports.\n    \"\"\"\n    if importlib.util.find_spec(\"wandb\") is None:\n        return  # wandb not installed, nothing to do\n\n    try:\n        import wandb\n    except Exception:\n        # wandb is installed but broken - patch all checkers to skip it\n        logger.info(\n            \"Unsloth: wandb is installed but broken (likely a protobuf version mismatch). \"\n            \"Disabling wandb to prevent import errors. 
To fix, run: pip install --upgrade wandb\"\n        )\n        _wandb_false = lambda: False\n        # Patch transformers' is_wandb_available (used by most trl trainers)\n        try:\n            import transformers.integrations.integration_utils as tf_integration\n\n            tf_integration.is_wandb_available = _wandb_false\n        except (ImportError, AttributeError):\n            pass\n        # Patch accelerate's is_wandb_available (used by trl/trainer/callbacks.py).\n        # Must patch both the source module AND the re-export namespace since\n        # `from accelerate.utils import is_wandb_available` reads from\n        # accelerate.utils, not accelerate.utils.imports.\n        try:\n            import accelerate.utils.imports as acc_imports\n\n            acc_imports.is_wandb_available = _wandb_false\n        except (ImportError, AttributeError):\n            pass\n        try:\n            import accelerate.utils as acc_utils\n\n            acc_utils.is_wandb_available = _wandb_false\n        except (ImportError, AttributeError):\n            pass\n        # Set env var as additional fallback\n        os.environ[\"WANDB_DISABLED\"] = \"true\"\n\n\nCAUSAL_CONV1D_BROKEN = False\n_CAUSAL_CONV1D_PREFIX = \"causal_conv1d\"\n_CAUSAL_CONV1D_BLOCKER_SENTINEL = \"_unsloth_causal_conv1d_blocker\"\nVLLM_BROKEN = False\n_VLLM_PREFIX = \"vllm\"\n_VLLM_BLOCKER_SENTINEL = \"_unsloth_vllm_blocker\"\n_ROCM_ENV_HINT_KEYS = (\n    \"ROCM_PATH\",\n    \"ROCM_HOME\",\n    \"HIP_PATH\",\n    \"HSA_PATH\",\n    \"HIP_VISIBLE_DEVICES\",\n    \"ROCR_VISIBLE_DEVICES\",\n)\n_ROCM_PATH_HINTS = (\n    Path(\"/opt/rocm\"),\n    Path(\"/dev/kfd\"),\n    Path(\"/sys/module/amdgpu\"),\n)\n_AMDGPU_ASIC_ID_TABLE_PATH_ENV = \"AMDGPU_ASIC_ID_TABLE_PATH\"\n_AMDGPU_ASIC_ID_CANDIDATE_PATHS = (\n    Path(\"/usr/share/libdrm/amdgpu.ids\"),\n    Path(\"/usr/local/share/libdrm/amdgpu.ids\"),\n    Path(\"/opt/rocm/share/libdrm/amdgpu.ids\"),\n    Path(\"/opt/amdgpu/share/libdrm/amdgpu.ids\"),\n)\n\n\ndef _log_rocm_detection(message):\n    if UNSLOTH_ENABLE_LOGGING:\n        logger.info(message)\n\n\n@functools.lru_cache(1)\ndef _is_rocm_torch_build() -> bool:\n    # Most official ROCm wheels include a local version suffix like +rocmX.Y.\n    # Some custom/source builds do not, so we fall back to runtime hints.\n    try:\n        torch_version_raw = str(importlib_version(\"torch\")).lower()\n        if \"rocm\" in torch_version_raw:\n            _log_rocm_detection(\n                \"Unsloth: ROCm detection matched torch version tag (+rocm).\"\n            )\n            return True\n    except Exception:\n        pass\n\n    # Environment hints commonly present on ROCm runtimes.\n    for key in _ROCM_ENV_HINT_KEYS:\n        value = os.environ.get(key, \"\")\n        if isinstance(value, str) and value.strip():\n            _log_rocm_detection(\n                f\"Unsloth: ROCm detection matched environment key `{key}`.\"\n            )\n            return True\n\n    # Filesystem / driver hints for ROCm stacks.\n    for path in _ROCM_PATH_HINTS:\n        try:\n            if path.exists():\n                _log_rocm_detection(\n                    f\"Unsloth: ROCm detection matched filesystem hint `{path}`.\"\n                )\n                return True\n        except Exception:\n            continue\n\n    _log_rocm_detection(\"Unsloth: ROCm detection did not match any known hints.\")\n    return False\n\n\ndef _iter_amdgpu_asic_id_table_candidates():\n    # Try torch-adjacent ids table paths first without 
importing torch.\n    try:\n        torch_spec = importlib.util.find_spec(\"torch\")\n    except Exception:\n        torch_spec = None\n\n    roots = []\n    if torch_spec is not None:\n        if torch_spec.origin:\n            roots.append(Path(torch_spec.origin).resolve().parent)\n        if torch_spec.submodule_search_locations:\n            for location in torch_spec.submodule_search_locations:\n                roots.append(Path(location).resolve())\n\n    seen = set()\n    for root in roots:\n        for candidate in (\n            root / \"share\" / \"libdrm\" / \"amdgpu.ids\",\n            root.parent / \"share\" / \"libdrm\" / \"amdgpu.ids\",\n            root.parent.parent / \"share\" / \"libdrm\" / \"amdgpu.ids\",\n        ):\n            candidate_str = str(candidate)\n            if candidate_str in seen:\n                continue\n            seen.add(candidate_str)\n            yield candidate\n\n    for candidate in _AMDGPU_ASIC_ID_CANDIDATE_PATHS:\n        candidate_str = str(candidate)\n        if candidate_str in seen:\n            continue\n        seen.add(candidate_str)\n        yield candidate\n\n\ndef configure_amdgpu_asic_id_table_path():\n    # Honor an existing valid user-provided path.\n    configured = os.environ.get(_AMDGPU_ASIC_ID_TABLE_PATH_ENV, \"\").strip()\n    if configured:\n        configured_path = Path(configured)\n        try:\n            if configured_path.is_file():\n                return str(configured_path)\n        except Exception:\n            pass\n\n    # Only attempt this on ROCm-like environments.\n    if not _is_rocm_torch_build():\n        return None\n\n    for candidate in _iter_amdgpu_asic_id_table_candidates():\n        try:\n            if candidate.is_file():\n                os.environ[_AMDGPU_ASIC_ID_TABLE_PATH_ENV] = str(candidate)\n                if UNSLOTH_ENABLE_LOGGING:\n                    logger.info(\n                        f\"Unsloth: Set {_AMDGPU_ASIC_ID_TABLE_PATH_ENV}={candidate}\"\n                    )\n                return str(candidate)\n        except Exception:\n            continue\n\n    return None\n\n\ndef _is_causal_conv1d_name(module_name: str) -> bool:\n    return module_name == _CAUSAL_CONV1D_PREFIX or module_name.startswith(\n        _CAUSAL_CONV1D_PREFIX + \".\"\n    )\n\n\ndef _is_vllm_name(module_name: str) -> bool:\n    return module_name == _VLLM_PREFIX or module_name.startswith(_VLLM_PREFIX + \".\")\n\n\ndef _resolve_module_name(module_name, package):\n    if not isinstance(module_name, str):\n        return module_name\n    if module_name.startswith(\".\"):\n        try:\n            return importlib.util.resolve_name(module_name, package)\n        except Exception:\n            return module_name\n    return module_name\n\n\ndef _is_broken_causal_conv1d_error(error) -> bool:\n    checked = set()\n    current = error\n    while current is not None and id(current) not in checked:\n        checked.add(id(current))\n        message = str(current).lower()\n        if (\n            (\"causal_conv1d_cuda\" in message and \"undefined symbol\" in message)\n            or (\"_zn3c103hip28c10_hip_check_implementation\" in message)\n            or (\"causal_conv1d\" in message and \"undefined symbol\" in message)\n        ):\n            return True\n        current = getattr(current, \"__cause__\", None) or getattr(\n            current, \"__context__\", None\n        )\n    return False\n\n\ndef _is_broken_vllm_error(error) -> bool:\n    checked = set()\n    current = error\n    while current is 
not None and id(current) not in checked:\n        checked.add(id(current))\n        message = str(current).lower()\n        if (\n            (\"vllm/_c\" in message or \"vllm._c\" in message)\n            and (\n                \"undefined symbol\" in message\n                or \"cannot open shared object file\" in message\n                or \".so:\" in message\n            )\n        ) or (\"vllm\" in message and \"undefined symbol\" in message):\n            return True\n        # Also catch CUDA shared library mismatches during vllm import\n        # e.g. \"libcudart.so.12: cannot open shared object file\"\n        if (\n            \"libcudart\" in message or \"libcublas\" in message or \"libnvrtc\" in message\n        ) and \"cannot open shared object file\" in message:\n            return True\n        current = getattr(current, \"__cause__\", None) or getattr(\n            current, \"__context__\", None\n        )\n    return False\n\n\ndef _get_vllm_cuda_mismatch_message(error):\n    \"\"\"If the error is a CUDA version mismatch, return a helpful install message.\"\"\"\n    import re as _re\n\n    checked = set()\n    current = error\n    wanted_cuda = None\n    while current is not None and id(current) not in checked:\n        checked.add(id(current))\n        message = str(current)\n        # Extract the CUDA version vllm was built for, e.g. \"libcudart.so.12\"\n        match = _re.search(r\"libcudart\\.so\\.(\\d+)\", message)\n        if match:\n            wanted_cuda = match.group(1)\n            break\n        current = getattr(current, \"__cause__\", None) or getattr(\n            current, \"__context__\", None\n        )\n    if wanted_cuda is None:\n        return None\n\n    # Detect what CUDA version is actually available on the system\n    system_cuda_display = None  # Human-readable, e.g. \"13.0\"\n    system_cuda_tag = None  # For wheel URL, e.g. \"130\"\n    try:\n        import torch\n\n        cuda_version = torch.version.cuda  # e.g. \"13.0\" or \"12.8\"\n        if cuda_version:\n            system_cuda_display = cuda_version\n            system_cuda_tag = cuda_version.replace(\".\", \"\")[:3]  # \"130\" or \"128\"\n    except Exception:\n        pass\n\n    if system_cuda_tag is None or system_cuda_tag.startswith(wanted_cuda):\n        return None  # Not a mismatch or can't determine\n\n    try:\n        vllm_version = importlib_version(\"vllm\").split(\"+\")[0]\n    except Exception:\n        vllm_version = \"VLLM_VERSION\"\n\n    cpu_arch = \"x86_64\"\n    try:\n        import platform\n\n        cpu_arch = platform.machine()\n    except Exception:\n        pass\n\n    return (\n        f\"Unsloth: vLLM was built for CUDA {wanted_cuda} but this system has \"\n        f\"CUDA {system_cuda_display}. 
Please reinstall vLLM with the correct CUDA version:\\n\"\n        f\"\\n\"\n        f\"  uv pip install https://github.com/vllm-project/vllm/releases/download/\"\n        f\"v{vllm_version}/vllm-{vllm_version}+cu{system_cuda_tag}-cp38-abi3-\"\n        f\"manylinux_2_35_{cpu_arch}.whl\"\n    )\n\n\nclass _CausalConv1dImportBlockerLoader(importlib.abc.Loader):\n    __slots__ = (\"module_name\",)\n\n    def __init__(self, module_name):\n        self.module_name = module_name\n\n    def create_module(self, spec):\n        return None\n\n    def exec_module(self, module):\n        raise ModuleNotFoundError(f\"No module named '{self.module_name}'\")\n\n\nclass _CausalConv1dImportBlockerFinder(importlib.abc.MetaPathFinder):\n    __slots__ = (_CAUSAL_CONV1D_BLOCKER_SENTINEL,)\n\n    def __init__(self):\n        setattr(self, _CAUSAL_CONV1D_BLOCKER_SENTINEL, True)\n\n    def find_spec(self, fullname, path = None, target = None):\n        if not CAUSAL_CONV1D_BROKEN or not _is_causal_conv1d_name(fullname):\n            return None\n        return importlib.machinery.ModuleSpec(\n            name = fullname,\n            loader = _CausalConv1dImportBlockerLoader(fullname),\n            is_package = fullname == _CAUSAL_CONV1D_PREFIX,\n        )\n\n\nclass _VllmImportBlockerLoader(importlib.abc.Loader):\n    __slots__ = (\"module_name\",)\n\n    def __init__(self, module_name):\n        self.module_name = module_name\n\n    def create_module(self, spec):\n        return None\n\n    def exec_module(self, module):\n        raise ModuleNotFoundError(f\"No module named '{self.module_name}'\")\n\n\nclass _VllmImportBlockerFinder(importlib.abc.MetaPathFinder):\n    __slots__ = (_VLLM_BLOCKER_SENTINEL,)\n\n    def __init__(self):\n        setattr(self, _VLLM_BLOCKER_SENTINEL, True)\n\n    def find_spec(self, fullname, path = None, target = None):\n        if not VLLM_BROKEN or not _is_vllm_name(fullname):\n            return None\n        return importlib.machinery.ModuleSpec(\n            name = fullname,\n            loader = _VllmImportBlockerLoader(fullname),\n            is_package = fullname == _VLLM_PREFIX,\n        )\n\n\ndef _patch_find_spec_for_causal_conv1d():\n    current_find_spec = importlib.util.find_spec\n    if getattr(current_find_spec, \"_unsloth_causal_conv1d_find_spec_patch\", False):\n        return\n\n    def _blocked_find_spec(name, package = None):\n        resolved_name = _resolve_module_name(name, package)\n        if CAUSAL_CONV1D_BROKEN and isinstance(resolved_name, str):\n            if _is_causal_conv1d_name(resolved_name):\n                return None\n        return current_find_spec(name, package)\n\n    _blocked_find_spec._unsloth_causal_conv1d_find_spec_patch = True\n    _blocked_find_spec._unsloth_original_find_spec = current_find_spec\n    importlib.util.find_spec = _blocked_find_spec\n\n\ndef _patch_find_spec_for_vllm():\n    current_find_spec = importlib.util.find_spec\n    if getattr(current_find_spec, \"_unsloth_vllm_find_spec_patch\", False):\n        return\n\n    def _blocked_find_spec(name, package = None):\n        resolved_name = _resolve_module_name(name, package)\n        if VLLM_BROKEN and isinstance(resolved_name, str):\n            if _is_vllm_name(resolved_name):\n                return None\n        return current_find_spec(name, package)\n\n    _blocked_find_spec._unsloth_vllm_find_spec_patch = True\n    _blocked_find_spec._unsloth_original_find_spec = current_find_spec\n    importlib.util.find_spec = _blocked_find_spec\n\n\ndef 
_install_causal_conv1d_blocker():\n    _patch_find_spec_for_causal_conv1d()\n    for finder in sys.meta_path:\n        if getattr(finder, _CAUSAL_CONV1D_BLOCKER_SENTINEL, False):\n            return\n    sys.meta_path.insert(0, _CausalConv1dImportBlockerFinder())\n\n\ndef _install_vllm_blocker():\n    _patch_find_spec_for_vllm()\n    for finder in sys.meta_path:\n        if getattr(finder, _VLLM_BLOCKER_SENTINEL, False):\n            return\n    sys.meta_path.insert(0, _VllmImportBlockerFinder())\n\n\ndef _clear_causal_conv1d_modules():\n    for module_name in list(sys.modules):\n        if _is_causal_conv1d_name(module_name):\n            sys.modules.pop(module_name, None)\n\n\ndef _clear_vllm_modules():\n    for module_name in list(sys.modules):\n        if _is_vllm_name(module_name):\n            sys.modules.pop(module_name, None)\n\n\ndef disable_broken_vllm(error = None):\n    \"\"\"Disable vLLM dynamically when its shared library is ABI-broken.\"\"\"\n    global VLLM_BROKEN\n    if VLLM_BROKEN:\n        _install_vllm_blocker()\n        return True\n\n    failure = error\n    if failure is None:\n        try:\n            if importlib.util.find_spec(\"vllm\") is None:\n                return False\n        except Exception:\n            return False\n\n        try:\n            import vllm  # noqa: F401\n\n            return False\n        except Exception as import_error:\n            failure = import_error\n\n    if not _is_broken_vllm_error(failure):\n        return False\n\n    VLLM_BROKEN = True\n    _clear_vllm_modules()\n    _install_vllm_blocker()\n    cuda_msg = _get_vllm_cuda_mismatch_message(failure)\n    if cuda_msg:\n        logger.warning(cuda_msg)\n    else:\n        logger.warning(\n            \"Unsloth: Detected broken vLLM binary extension; \"\n            \"disabling vLLM imports and continuing import.\\n\"\n            \"Please reinstall via `uv pip install unsloth vllm torchvision torchaudio \"\n            \"--torch-backend=auto`.\"\n        )\n    return True\n\n\ndef _disable_transformers_causal_conv1d():\n    try:\n        import transformers.utils.import_utils as tf_import_utils\n    except Exception:\n        return\n\n    if hasattr(tf_import_utils, \"is_causal_conv1d_available\"):\n        tf_import_utils.is_causal_conv1d_available = lambda: False\n\n    for attr_name in (\n        \"_causal_conv1d_available\",\n        \"_is_causal_conv1d_available\",\n    ):\n        if hasattr(tf_import_utils, attr_name):\n            setattr(tf_import_utils, attr_name, False)\n\n\ndef disable_broken_causal_conv1d():\n    \"\"\"Disable causal_conv1d dynamically when its shared library is ABI-broken.\n\n    This mirrors Unsloth's FlashAttention fallback behavior: if importing causal_conv1d\n    fails with a known binary symbol error, we disable it at startup so model imports do\n    not hard-fail.\n    \"\"\"\n    global CAUSAL_CONV1D_BROKEN\n    if CAUSAL_CONV1D_BROKEN:\n        _install_causal_conv1d_blocker()\n        _disable_transformers_causal_conv1d()\n        return\n\n    try:\n        if importlib.util.find_spec(\"causal_conv1d\") is None:\n            return\n    except Exception:\n        return\n\n    try:\n        import causal_conv1d  # noqa: F401\n\n        return\n    except Exception as error:\n        if not _is_broken_causal_conv1d_error(error):\n            return\n\n    CAUSAL_CONV1D_BROKEN = True\n    _clear_causal_conv1d_modules()\n    _install_causal_conv1d_blocker()\n    _disable_transformers_causal_conv1d()\n    print(\n        \"Unsloth: 
Detected broken causal_conv1d binary; \"\n        \"disabling causal_conv1d fast path and continuing import.\"\n    )\n"
  },
  {
    "path": "unsloth/kernels/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .cross_entropy_loss import (\n    fast_cross_entropy_loss,\n    post_patch_loss_function,\n    patch_loss_functions,\n)\nfrom .rms_layernorm import (\n    fast_rms_layernorm,\n    patch_rms_layernorm,\n    unpatch_rms_layernorm,\n)\nfrom .layernorm import (\n    fast_layernorm,\n    patch_layernorm,\n)\nfrom .rope_embedding import fast_rope_embedding, inplace_rope_embedding\nfrom .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel\nfrom .geglu import (\n    geglu_exact_forward_kernel,\n    geglu_exact_backward_kernel,\n    geglu_approx_forward_kernel,\n    geglu_approx_backward_kernel,\n)\nfrom .fast_lora import (\n    get_lora_parameters,\n    get_lora_parameters_bias,\n    apply_lora_mlp_swiglu,\n    apply_lora_mlp_geglu_exact,\n    apply_lora_mlp_geglu_approx,\n    apply_lora_qkv,\n    apply_lora_o,\n    fast_lora_forward,\n)\nfrom .fp8 import *  # This step is to ensure that we patch the FbgmemFP8Linear and FP8Linear's forward functions before the execution of model creation so that this applies to compiled non fast inference models as well\nfrom .utils import (\n    fast_dequantize,\n    fast_gemv,\n    QUANT_STATE,\n    fast_linear_forward,\n    matmul_lora,\n)\n\nfrom .flex_attention import (\n    HAS_FLEX_ATTENTION,\n    slow_attention_softcapping,\n    slow_inference_attention_softcapping,\n    create_flex_attention_causal_mask,\n    create_flex_attention_sliding_window_mask,\n)\n\nimport os\n\nif \"UNSLOTH_ZOO_IS_PRESENT\" not in os.environ:\n    try:\n        print(\n            \"🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\"\n        )\n    except:\n        print(\"Unsloth: Will patch your computer to enable 2x faster free finetuning.\")\ndel os\n"
  },
  {
    "path": "unsloth/kernels/cross_entropy_loss.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom .utils import (\n    calculate_settings,\n    MAX_FUSED_SIZE,\n    triton_tanh,\n    triton_cast,\n    torch_gpu_device,\n    is_cdna,\n)\nfrom transformers.models.llama.modeling_llama import logger\nfrom unsloth_zoo.utils import Version\n\nfrom unsloth_zoo.loss_utils import (\n    patch_loss_functions as _patch_loss_functions,\n    post_patch_loss_function,\n)\n\n\ndef _cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_SOFTCAPPING: tl.constexpr,\n    SOFTCAP: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: tl.constexpr,\n):\n    \"\"\"\n    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n    Pi = exp(xi) / sum(exp(xi))\n    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]\n         = -y [ x - log[sum(exp(x))] ]\n         = y * (log[sum(exp(x))] - x)\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    logsumexp is also stable\n    Take    y =         log[sum(exp(x))]\n       exp(y) =             sum(exp(x))\n       exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x\n       exp(y) =      exp(c)*sum(exp(x - c))\n           y  = log(exp(c)*sum(exp(x - c)))\n           y  = c + log[sum(exp(x - c))]\n    This means we can set c = max(x) to make sure\n    exp(x - c) always is exp(x - max(x)).\n    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.\n    \"\"\"\n    row_idx = tl.program_id(0)\n    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx\n    labels_ptr += row_idx\n\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(\n        tl.float32\n    )\n\n    # Go logit scaling for Cohere: t * x\n    if DO_LOGIT_SCALING:\n        logits = LOGIT_SCALE * logits\n    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)\n    if DO_SOFTCAPPING:\n        logits = SOFTCAP * triton_tanh(logits / SOFTCAP)\n\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if label_idx != -100:\n        x = tl.load(logits_ptr + label_idx).to(tl.float32)\n        # Go logit scaling for Cohere: t * x\n        if DO_LOGIT_SCALING:\n            x = LOGIT_SCALE * x\n        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)\n        if DO_SOFTCAPPING:\n            x = SOFTCAP * triton_tanh(x / SOFTCAP)\n        loss = logsumexp - x\n    else:\n        loss = 0.0\n    tl.store(logsumexp_ptr, logsumexp)\n    tl.store(loss_ptr, loss)\n\n\n_cross_entropy_forward = triton.jit(_cross_entropy_forward)\n_cross_entropy_forward = 
triton.heuristics(\n    {\n        \"DO_SOFTCAPPING\": lambda args: bool(args[\"DO_SOFTCAPPING\"]),\n        \"DO_LOGIT_SCALING\": lambda args: bool(args[\"DO_LOGIT_SCALING\"]),\n    }\n)(_cross_entropy_forward)\n\n\ndef _chunked_cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride: tl.constexpr,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    N_CHUNKS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_SOFTCAPPING: tl.constexpr,\n    SOFTCAP: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: tl.constexpr,\n):\n    \"\"\"\n    256K vocab divided in 4 chunks\n\n    |-65536-| |-65536-| |-65536-| |-65536-|\n    |-------| |-------| |-------| |-------|\n    |-------| |-------| |-------| |-------|\n\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    Notice we can do logsumexp for each chunk and then\n    logsumexp[chunk_sum(logsumexp)] == logsumexp\n\n    chunk_sum = log[chunk_sum(logsumexp)]\n              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]\n              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]\n              = log[sum(exp(a)) + ... + sum(exp(z))]\n              = logsumexp(x)\n\n    This means we can perform a logsumexp for each chunk, then do a\n    final logsumexp reduction!\n\n    Ie do: logsumexp(chunked_logsumexp) - x\n    \"\"\"\n    row_idx = tl.program_id(0)\n    chunk_idx = tl.program_id(1)\n    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n    labels_ptr += row_idx\n\n    col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(\n        tl.float32\n    )\n\n    # Go logit scaling for Cohere: t * x\n    if DO_LOGIT_SCALING:\n        logits = LOGIT_SCALE * logits\n    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)\n    if DO_SOFTCAPPING:\n        logits = SOFTCAP * triton_tanh(logits / SOFTCAP)\n\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if chunk_idx == 0:\n        # logsumexp(chunked_logsumexp) - x\n        # Do the -x separately\n        if label_idx != -100:\n            x = tl.load(logits_ptr + label_idx).to(tl.float32)\n            # Go logit scaling for Cohere: t * x\n            if DO_LOGIT_SCALING:\n                x = LOGIT_SCALE * x\n            # Do logit softcapping for Gemma 2: t * tanh(1/t * x)\n            if DO_SOFTCAPPING:\n                x = SOFTCAP * triton_tanh(x / SOFTCAP)\n            loss = -1.0 * x\n        else:\n            loss = 0.0\n        tl.store(loss_ptr, loss)\n    tl.store(logsumexp_ptr, logsumexp)\n\n\n_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)\n_chunked_cross_entropy_forward = triton.heuristics(\n    {\n        \"DO_SOFTCAPPING\": lambda args: bool(args[\"DO_SOFTCAPPING\"]),\n        \"DO_LOGIT_SCALING\": lambda args: bool(args[\"DO_LOGIT_SCALING\"]),\n    }\n)(_chunked_cross_entropy_forward)\n\n\ndef _cross_entropy_backward(\n    logits_ptr,\n    logits_row_stride: tl.constexpr,\n    dloss_ptr,\n    dloss_row_stride: tl.constexpr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_SOFTCAPPING: tl.constexpr,\n    SOFTCAP: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: 
tl.constexpr,\n):\n    \"\"\"\n    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)\n\n    From https://en.wikipedia.org/wiki/LogSumExp\n    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)\n\n    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)\n    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick\n    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)\n\n    If y == 0: dC/dx = 0\n    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1\n    If y == 1 and x != label: dC/dx     = exp[x - logsumexp]\n    \"\"\"\n    row_idx = tl.program_id(0)\n    block_idx = tl.program_id(1)\n\n    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)\n    dloss_ptr += row_idx * dloss_row_stride\n    col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n    if label_idx != -100:\n        dloss = tl.load(dloss_ptr)\n    else:\n        dloss = 0.0\n\n    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n\n    # Do logit scaling for Cohere\n    if DO_LOGIT_SCALING:\n        # d/dx [s * x] = s\n        x = x * LOGIT_SCALE\n\n    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)\n    partial = x\n    if DO_SOFTCAPPING:\n        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)\n        partial = triton_tanh(x / SOFTCAP)\n        x = SOFTCAP * partial\n\n    logsumexp = tl.load(logsumexp_ptr + row_idx)\n    y = tl.exp(x - logsumexp)\n    y = tl.where(\n        col_offsets == label_idx,\n        y - 1.0,  # exp(x - logsumexp) - 1\n        y,  # exp(x - logsumexp)\n    )\n\n    if DO_LOGIT_SCALING:\n        # d/dx [s * x] = s\n        y = y * LOGIT_SCALE\n\n    if DO_SOFTCAPPING:\n        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)\n        y = y * (1.0 - partial * partial)\n\n    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.\n    tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)\n\n\n_cross_entropy_backward = triton.jit(_cross_entropy_backward)\n_cross_entropy_backward = triton.heuristics(\n    {\n        \"DO_SOFTCAPPING\": lambda args: bool(args[\"DO_SOFTCAPPING\"]),\n        \"DO_LOGIT_SCALING\": lambda args: bool(args[\"DO_LOGIT_SCALING\"]),\n    }\n)(_cross_entropy_backward)\n\n\nMAX_FUSED_SIZE = 65536  # 2**16\n\n\nclass Fast_CrossEntropyLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx, logits, labels, logit_softcapping: float = 0, logit_scaling: float = 0\n    ):\n        n_rows: int\n        vocab_size: int\n        n_rows, vocab_size = logits.shape\n        device = logits.device\n        labels = labels.to(device)\n\n        div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n        n_chunks: int = div + (mod != 0)\n        losses = torch.empty(n_rows, dtype = torch.float32, device = device)\n\n        DO_SOFTCAPPING: bool = bool(logit_softcapping != 0)\n        DO_LOGIT_SCALING: bool = bool(logit_scaling != 0)\n\n        BLOCK_SIZE: int\n        num_warps: int\n        if n_chunks == 1:\n            # For small vocabs <= 65336 like Llama, Mistral\n            BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n            if is_cdna():\n                num_warps = num_warps // 2\n            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)\n\n            with torch_gpu_device(device):\n                _cross_entropy_forward[(n_rows,)](\n                    logits,\n        
            logits.stride(0),\n                    losses,\n                    logsumexp,\n                    labels,\n                    VOCAB_SIZE = vocab_size,\n                    BLOCK_SIZE = BLOCK_SIZE,\n                    DO_SOFTCAPPING = DO_SOFTCAPPING,\n                    SOFTCAP = logit_softcapping,\n                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,\n                    LOGIT_SCALE = logit_scaling,\n                    num_warps = num_warps,\n                )\n        else:\n            # For large vocabs > 65336 like Gemma 256K\n            logsumexp = torch.empty(\n                (\n                    n_rows,\n                    n_chunks,\n                ),\n                dtype = torch.float32,\n                device = device,\n            )\n\n            with torch_gpu_device(device):\n                _chunked_cross_entropy_forward[\n                    (\n                        n_rows,\n                        n_chunks,\n                    )\n                ](\n                    logits,\n                    logits.stride(0),\n                    losses,\n                    logsumexp,\n                    labels,\n                    VOCAB_SIZE = vocab_size,\n                    N_CHUNKS = n_chunks,\n                    BLOCK_SIZE = MAX_FUSED_SIZE,\n                    DO_SOFTCAPPING = DO_SOFTCAPPING,\n                    SOFTCAP = logit_softcapping,\n                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,\n                    LOGIT_SCALE = logit_scaling,\n                    num_warps = 32 if not is_cdna() else 16,\n                )\n            # logsumexp(chunked_logsumexp) - x\n            # Do the -x separately\n            logsumexp = torch.logsumexp(logsumexp, dim = 1)  # Row sum\n            losses += logsumexp\n            losses.masked_fill_(labels == -100, 0)  # Don't forget to mask padding out!\n\n        ctx.save_for_backward(logits, logsumexp, labels)\n        ctx.DO_SOFTCAPPING = DO_SOFTCAPPING\n        ctx.logit_softcapping = logit_softcapping\n        ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING\n        ctx.logit_scaling = logit_scaling\n        return losses\n\n    @staticmethod\n    def backward(ctx, dlosses):\n        logits, logsumexp, labels = ctx.saved_tensors\n        n_rows: int\n        vocab_size: int\n        n_rows, vocab_size = logits.shape\n\n        BLOCK_SIZE: int = 4096\n        div: int\n        mod: int\n        div, mod = divmod(vocab_size, BLOCK_SIZE)\n        n_blocks: int = div + (mod != 0)\n\n        with torch_gpu_device(dlosses.device):\n            _cross_entropy_backward[\n                (\n                    n_rows,\n                    n_blocks,\n                )\n            ](\n                logits,\n                logits.stride(0),\n                dlosses,\n                dlosses.stride(0),\n                logsumexp,\n                labels,\n                VOCAB_SIZE = vocab_size,\n                BLOCK_SIZE = BLOCK_SIZE,\n                DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,\n                SOFTCAP = ctx.logit_softcapping,\n                DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,\n                LOGIT_SCALE = ctx.logit_scaling,\n                num_warps = 8,\n            )\n        return (\n            logits,\n            None,\n            None,\n            None,\n        )\n\n\ndef fast_cross_entropy_loss(\n    logits,\n    labels,\n    logit_softcapping = 0,\n    logit_scaling = 0,\n    n_items = None,\n):\n    \"\"\"\n    Arguments:\n        logits: (batch, seq_len, vocab_size)\n    
    labels: (batch, seq_len,)\n    Returns:\n        losses: float\n    \"\"\"\n    batch, seq_len, d = logits.shape\n    assert labels.shape == (batch, seq_len)\n\n    device = logits.device\n    loss = Fast_CrossEntropyLoss.apply(\n        logits.view(batch * seq_len, d),\n        labels.view(-1),\n        logit_softcapping,\n        logit_scaling,\n    )\n    if n_items is None:\n        n_items = torch.count_nonzero(labels != -100)\n    if torch.is_tensor(n_items):\n        n_items = n_items.to(device)\n    return loss.sum() / n_items\n\n\nif (Version(torch.__version__) < Version(\"2.4.0\")) and not hasattr(\n    fast_cross_entropy_loss, \"__wrapped__\"\n):\n    fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)\n\n\n# Patch CE Losses in transformers\ndef patch_loss_functions(torch_compile = True):\n    _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)\n"
  },
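  {
    "path": "examples/editor_sketches/chunked_logsumexp_check.py",
    "content": "\"\"\"Editor's illustrative sketch (not part of the original Unsloth repository; this file path is hypothetical).\n\nChecks, in plain PyTorch on CPU, the chunked logsumexp identity that the\n_chunked_cross_entropy_forward kernel in unsloth/kernels/cross_entropy_loss.py\nrelies on: taking logsumexp over each vocabulary chunk and then logsumexp over\nthose per-chunk results reproduces the full logsumexp. Shapes below are\nillustrative assumptions.\n\"\"\"\nimport torch\n\ntorch.manual_seed(0)\nn_rows, vocab_size, n_chunks = 4, 1024, 8\nlogits = torch.randn(n_rows, vocab_size, dtype = torch.float64)\n\n# Full logsumexp over the whole vocabulary\nfull = torch.logsumexp(logits, dim = -1)\n\n# Per-chunk logsumexp, then a final logsumexp reduction over the chunks\nchunked = torch.logsumexp(logits.view(n_rows, n_chunks, -1), dim = -1)\nreduced = torch.logsumexp(chunked, dim = -1)\n\nassert torch.allclose(full, reduced)\nprint(\"logsumexp(chunked_logsumexp) == logsumexp:\", torch.allclose(full, reduced))\n"
  },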
  {
    "path": "unsloth/kernels/fast_lora.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom .utils import (\n    _maybe_fake_quantize_activations,\n    fast_dequantize,\n    QUANT_STATE,\n    get_lora_parameters,\n    get_lora_parameters_bias,\n    matmul_lora,\n    torch_amp_custom_fwd,\n    torch_amp_custom_bwd,\n)\n\n\nclass LoRA_MLP(torch.autograd.Function):\n    \"\"\"\n    ### LoRA weights\n    G = G + Ag @ Bg\n    U = U + Au @ Bu\n    W = W + Aw @ Bw\n\n    ### SwiGLU(X)\n    e = X @ G\n    f = e * sigmoid(e)\n    g = X @ U\n    h = f * g\n    i = h @ W\n\n    ### Backpropagation chain rule\n    See our blog post for more details\n\n    df = sigmoid(e) * (1 - f) + f\n    dC/dW = h.T @ dY\n    dC/dU = X.T @ (D @ W.T * f)\n    dC/dG = X.T @ (D @ W.T * df * g)\n\n    ### Down projection LoRA weights\n    dC/dAw = dC/dW @ B.T\n    dC/dBw = A.T @ dC/dW\n    dC/dAw =       h.T @ dY @ B.T\n    dC/dBw = A.T @ h.T @ dY\n\n    ### Up projection LoRA weights\n    dC/dAu =       X.T @ (D @ W.T * f) @ B.T\n    dC/dBu = A.T @ X.T @ (D @ W.T * f)\n\n    ### Gate projection LoRA weights\n    dC/dAg =       X.T @ (D @ W.T * df * g) @ B.T\n    dC/dBg = A.T @ X.T @ (D @ W.T * df * g)\n\n    Don't forget to see our blog post for more details!\n    \"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(\n        ctx,\n        X: torch.Tensor,\n        gateW,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        _forward_function,\n        _backward_function,\n        inplace = True,\n    ):\n        dtype = X.dtype\n\n        e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)\n        g = matmul_lora(X, upW, upW_quant, upA, upB, upS)\n        h = _forward_function(e, g)\n        i = matmul_lora(h, downW, downW_quant, downA, downB, downS)\n\n        ctx.custom_saved_tensors = (\n            gateW,\n            gateW_quant,\n            gateS,\n            upW,\n            upW_quant,\n            upS,\n            downW,\n            downW_quant,\n            downS,\n            _backward_function,\n        )\n        ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g)\n        ctx.inplace = inplace\n        return i\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(ctx, dY: torch.Tensor):\n        (\n            gateW,\n            gateW_quant,\n            gateS,\n            upW,\n            upW_quant,\n            upS,\n            downW,\n            downW_quant,\n            downS,\n            _backward_function,\n        ) = ctx.custom_saved_tensors\n        gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors\n\n        batch, seq_len, hd = X.shape\n        dY = dY.view(-1, dY.shape[-1])\n        X = X.view(-1, X.shape[-1])\n        e = e.view(-1, e.shape[-1])\n 
       g = g.view(-1, g.shape[-1])\n        dtype = X.dtype\n\n        gateA, gateB, upA, upB, downA, downB = (\n            gateA.to(dtype),\n            gateB.to(dtype),\n            upA.to(dtype),\n            upB.to(dtype),\n            downA.to(dtype),\n            downB.to(dtype),\n        )\n\n        gateA, gateB, upA, upB, downA, downB = (\n            gateA.t(),\n            gateB.t(),\n            upA.t(),\n            upB.t(),\n            downA.t(),\n            downB.t(),\n        )\n\n        DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)\n        DW, e, g = _backward_function(DW, e, g)\n        h, df, de = DW, e, g\n\n        d_downA = torch.empty_like(downA)\n        d_downB = torch.empty_like(downB)\n        d_gateA = torch.empty_like(gateA)\n        d_gateB = torch.empty_like(gateB)\n        d_upA = torch.empty_like(upA)\n        d_upB = torch.empty_like(upB)\n\n        # Down projection LoRA weights\n        # d_downA = h.t() @ (dY @ downB.t())\n        # d_downB = (downA.t() @ h.t()) @ dY\n        # d_downA *= downS\n        # d_downB *= downS\n        d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)\n        d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)\n\n        # Up projection LoRA weights\n        # d_upA   = X.t() @ (df @ upB.t())\n        # d_upB   = (upA.t() @ X.t()) @ df\n        # d_upA  *= upS\n        # d_upB  *= upS\n        d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)\n        d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)\n\n        # Gate projection LoRA weights\n        # d_gateA = X.t() @ (de @ gateB.t())\n        # d_gateB = (gateA.t() @ X.t()) @ de\n        # d_gateA *= gateS\n        # d_gateB *= gateS\n        d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)\n        d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)\n\n        # dX  = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)\n        # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)\n        upW = fast_dequantize(upW.t(), upW_quant)\n        dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)\n        del upW\n        # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())\n        dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)\n\n        gateW = fast_dequantize(gateW.t(), gateW_quant)\n        # dX += de @ gateW.t()\n        dX.addmm_(de, gateW.t())\n        del gateW\n        # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())\n        dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)\n\n        # gateW, gateW_quant, gateA, gateB, gateS,\n        #  upW,    upW_quant,   upA,   upB,   upS,\n        # downW, downW_quant, downA, downB, downS,\n        return (\n            dX.view(batch, seq_len, hd),\n            None,\n            None,\n            d_gateA.t(),\n            d_gateB.t(),\n            None,\n            None,\n            None,\n            d_upA.t(),\n            d_upB.t(),\n            None,\n            None,\n            None,\n            d_downA.t(),\n            d_downB.t(),\n            None,\n            None,\n            None,\n            None,\n        )  # _backward and _forward and inplace\n\n\nfrom .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel\n\n\ndef apply_lora_mlp_swiglu(self, X, inplace = True):\n    X = _maybe_fake_quantize_activations(X, self.gate_proj)\n    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)\n    upW, upW_quant, upA, upB, upS = 
get_lora_parameters(self.up_proj)\n    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)\n    out = LoRA_MLP.apply(\n        X,\n        gateW,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        swiglu_fg_kernel,\n        swiglu_DWf_DW_dfg_kernel,\n        inplace,\n    )\n    return out\n\n\nfrom .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel\n\n\ndef apply_lora_mlp_geglu_exact(self, X, inplace = True):\n    X = _maybe_fake_quantize_activations(X, self.gate_proj)\n    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)\n    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)\n    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)\n    out = LoRA_MLP.apply(\n        X,\n        gateW,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        geglu_exact_forward_kernel,\n        geglu_exact_backward_kernel,\n        inplace,\n    )\n    return out\n\n\nfrom .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel\n\n\ndef apply_lora_mlp_geglu_approx(self, X):\n    X = _maybe_fake_quantize_activations(X, self.gate_proj)\n    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)\n    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)\n    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)\n    out = LoRA_MLP.apply(\n        X,\n        gateW,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        geglu_approx_forward_kernel,\n        geglu_approx_backward_kernel,\n    )\n    return out\n\n\nclass LoRA_QKV(torch.autograd.Function):\n    \"\"\"\n    ### LoRA weights\n    Wq = Wq + Aq @ Bq\n    Wk = Wk + Ak @ Bk\n    Wv = Wv + Av @ Bv\n    Q = X @ Wq = X @ Wq + X @ Aq @ Bq\n    K = X @ Wk = X @ Wk + X @ Ak @ Bk\n    V = X @ Wv = X @ Wv + X @ Av @ Bv\n\n    ### Backpropagation chain rule\n    See our blogpost for more details.\n\n    dC/dWq = X.T @ D(Wq)\n    dC/dWk = X.T @ D(Wk)\n    dC/dWv = X.T @ D(Wv)\n    We then sum them all find dC/dX\n\n    ### Q projection LoRA weights\n    dC/dAq =       X.T @ D(Wq) @ B.T\n    dC/dBq = A.T @ X.T @ D(Wq)\n\n    ### K projection LoRA weights\n    dC/dAk =       X.T @ D(Wk) @ B.T\n    dC/dBk = A.T @ X.T @ D(Wk)\n\n    ### V projection LoRA weights\n    dC/dAv =       X.T @ D(Wv) @ B.T\n    dC/dBv = A.T @ X.T @ D(Wv)\n    \"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(\n        ctx,\n        X: torch.Tensor,\n        QW,\n        QW_quant,\n        QA,\n        QB,\n        QS,\n        KW,\n        KW_quant,\n        KA,\n        KB,\n        KS,\n        VW,\n        VW_quant,\n        VA,\n        VB,\n        VS,\n        inplace = True,\n    ):\n        dtype = X.dtype\n\n        # bitsandbytes 8-bit matmul expects 2D inputs.\n        # TorchInductor/AOTAutograd fails on 3D tensors during backward,\n        # so we explicitly flatten the sequence dimension.\n        orig_shape = 
X.shape\n        X_for_matmul = X\n        if X.dim() == 3:\n            X_for_matmul = X.view(-1, X.shape[-1])\n        Q = matmul_lora(X_for_matmul, QW, QW_quant, QA, QB, QS)\n        K = matmul_lora(X_for_matmul, KW, KW_quant, KA, KB, KS)\n        V = matmul_lora(X_for_matmul, VW, VW_quant, VA, VB, VS)\n\n        # Restore original shape after matmul\n        if len(orig_shape) == 3:\n            Q = Q.view(orig_shape[0], orig_shape[1], -1)\n            K = K.view(orig_shape[0], orig_shape[1], -1)\n            V = V.view(orig_shape[0], orig_shape[1], -1)\n\n        ctx.custom_saved_tensors = (\n            QW,\n            QW_quant,\n            QS,\n            KW,\n            KW_quant,\n            KS,\n            VW,\n            VW_quant,\n            VS,\n        )\n        ctx.save_for_backward(\n            X,\n            QA,\n            QB,\n            KA,\n            KB,\n            VA,\n            VB,\n        )\n        ctx.inplace = inplace\n        return Q, K, V\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(ctx, dQ, dK, dV):\n        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = ctx.custom_saved_tensors\n        (\n            X,\n            QA,\n            QB,\n            KA,\n            KB,\n            VA,\n            VB,\n        ) = ctx.saved_tensors\n\n        batch, seq_len, hd = X.shape\n        dQ = dQ.view(-1, dQ.shape[-1])\n        dK = dK.reshape(-1, dK.shape[-1])  # view doesn't work on K.T\n        dV = dV.view(-1, dV.shape[-1])\n        X = X.view(-1, X.shape[-1])\n        dtype = X.dtype\n\n        QA, QB, KA, KB, VA, VB = (\n            QA.to(dtype),\n            QB.to(dtype),\n            KA.to(dtype),\n            KB.to(dtype),\n            VA.to(dtype),\n            VB.to(dtype),\n        )\n\n        QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()\n\n        ### Weight projection LoRA weights\n        # See our blogpost for more details.\n        d_QA = torch.empty_like(QA)\n        d_QB = torch.empty_like(QB)\n        d_KA = torch.empty_like(KA)\n        d_KB = torch.empty_like(KB)\n        d_VA = torch.empty_like(VA)\n        d_VB = torch.empty_like(VB)\n\n        # Q Projection\n        # d_QA = X.t() @ (dQ @ QB.t())\n        # d_QB = (QA.t() @ X.t()) @ dQ\n        # d_QA *= QS\n        # d_QB *= QS\n        d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)\n        d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)\n\n        # K Projection\n        # d_KA = X.t() @ (dK @ KB.t())\n        # d_KB = (KA.t() @ X.t()) @ dK\n        # d_KA *= KS\n        # d_KB *= KS\n        d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)\n        d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)\n\n        # V Projection\n        # d_VA = X.t() @ (dV @ VB.t())\n        # d_VB = (VA.t() @ X.t()) @ dV\n        # d_VA *= VS\n        # d_VB *= VS\n        d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)\n        d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)\n\n        # Combine derivatives to find dX\n        # dQ\n        QW = fast_dequantize(QW.t(), QW_quant)\n        dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)\n        del QW\n        # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))\n        dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)\n\n        # dK\n        KW = fast_dequantize(KW.t(), KW_quant)\n        # dX += dK @ KW.t()\n        dX.addmm_(dK, KW.t())\n        del KW\n        # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())\n        
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)\n\n        # dV\n        VW = fast_dequantize(VW.t(), VW_quant)\n        # dX += dV @ VW.t()\n        dX.addmm_(dV, VW.t())\n        del VW\n        # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())\n        dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)\n\n        # QW, QW_quant, QA, QB, QS,\n        # KW, KW_quant, KA, KB, KS,\n        # VW, VW_quant, VA, VB, VS,\n        return (\n            dX.view(batch, seq_len, hd),\n            None,\n            None,\n            d_QA.t(),\n            d_QB.t(),\n            None,\n            None,\n            None,\n            d_KA.t(),\n            d_KB.t(),\n            None,\n            None,\n            None,\n            d_VA.t(),\n            d_VB.t(),\n            None,\n            None,\n        )\n\n\ndef apply_lora_qkv(self, X, inplace = True):\n    X = _maybe_fake_quantize_activations(X, self.q_proj)\n    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)\n    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)\n    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)\n    Q, K, V = LoRA_QKV.apply(\n        X,\n        QW,\n        QW_quant,\n        QA,\n        QB,\n        QS,\n        KW,\n        KW_quant,\n        KA,\n        KB,\n        KS,\n        VW,\n        VW_quant,\n        VA,\n        VB,\n        VS,\n        inplace,\n    )\n    return Q, K, V\n\n\nclass LoRA_W(torch.autograd.Function):\n    \"\"\"\n    ### LoRA weights\n    Wq = Wq + Aq @ Bq\n    Wk = Wk + Ak @ Bk\n    Wv = Wv + Av @ Bv\n    Q = X @ Wq = X @ Wq + X @ Aq @ Bq\n    K = X @ Wk = X @ Wk + X @ Ak @ Bk\n    V = X @ Wv = X @ Wv + X @ Av @ Bv\n\n    ### Backpropagation chain rule\n    dC/dWq = X.T @ D(Wq)\n    dC/dWk = X.T @ D(Wk)\n    dC/dWv = X.T @ D(Wv)\n\n    ### Q projection LoRA weights\n    dC/dAq =       X.T @ D(Wq) @ B.T\n    dC/dBq = A.T @ X.T @ D(Wq)\n\n    ### K projection LoRA weights\n    dC/dAk =       X.T @ D(Wk) @ B.T\n    dC/dBk = A.T @ X.T @ D(Wk)\n\n    ### V projection LoRA weights\n    dC/dAv =       X.T @ D(Wv) @ B.T\n    dC/dBv = A.T @ X.T @ D(Wv)\n    \"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(ctx, X: torch.Tensor, W, W_quant, A, B, S):\n        dtype = X.dtype\n        XW = matmul_lora(X, W, W_quant, A, B, S)\n        ctx.custom_saved_tensors = (\n            W,\n            W_quant,\n            S,\n        )\n        ctx.save_for_backward(A, B, X)\n        return XW\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(ctx, dY: torch.Tensor):\n        W, W_quant, S = ctx.custom_saved_tensors\n        A, B, X = ctx.saved_tensors\n\n        batch, seq_len, hd = X.shape\n        dY = dY.reshape(-1, dY.shape[-1])  # Must be reshape\n        X = X.reshape(-1, X.shape[-1])  # Must be reshape\n        dtype = X.dtype\n\n        A, B = A.to(dtype), B.to(dtype)\n\n        A, B = A.t(), B.t()\n\n        d_A = torch.empty_like(A)\n        d_B = torch.empty_like(B)\n\n        ### Weight projection LoRA weights\n        # Weight projection\n        # d_A = X.t() @ (dY @ B.t())\n        # d_B = (A.t() @ X.t()) @ dY\n        # d_A *= S\n        # d_B *= S\n        d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)\n        d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)\n\n        # Get derivative for dX\n        W = fast_dequantize(W.t(), W_quant)\n        dX = dY @ W.t()\n        del W\n        # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())\n        dX.addmm_(dY @ B.t(), A.t(), alpha = S)\n\n        # W, W_quant, A, 
B, S\n        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None\n\n\ndef apply_lora_o(self, X):\n    X = _maybe_fake_quantize_activations(X, self.o_proj)\n    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)\n    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)\n    return O\n\n\nIDENTITY_DROPOUT = torch.nn.Identity\n\n\n@torch._disable_dynamo\ndef fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n    raise NotImplementedError(\n        \"Unsloth: Currently not supported yet - reshaping done incorrectly\"\n    )\n    self._check_forward_args(x, *args, **kwargs)\n    adapter_names = kwargs.pop(\"adapter_names\", None)\n\n    if self.disable_adapters:\n        if self.merged:\n            self.unmerge()\n        result = self.base_layer(x, *args, **kwargs)\n    elif adapter_names is not None:\n        result = self._mixed_batch_forward(\n            x, *args, adapter_names = adapter_names, **kwargs\n        )\n    elif self.merged:\n        result = self.base_layer(x, *args, **kwargs)\n    else:\n        # Fastpath\n        if len(self.active_adapters) == 1:\n            active_adapter = self.active_adapters[0]\n            if active_adapter not in self.lora_A.keys():\n                return self.base_layer(x, *args, **kwargs)\n\n            dropout = self.lora_dropout[active_adapter]\n            if (\n                isinstance(dropout, IDENTITY_DROPOUT)\n                and not self.use_dora[active_adapter]\n            ):\n                lora_A = self.lora_A[active_adapter].weight\n                lora_B = self.lora_B[active_adapter].weight\n                scaling = self.scaling[active_adapter]\n                W = self.base_layer.weight\n                return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)\n            pass\n        pass\n\n        result = self.base_layer(x, *args, **kwargs)\n        # As per Tim Dettmers, for 4bit, we need to defensively clone here.\n        # The reason is that in some cases, an error can occur that backprop\n        # does not work on a manipulated view. 
This issue may be solved with\n        # newer PyTorch versions but this would need extensive testing to be\n        # sure.\n        result = result.clone()\n\n        for active_adapter in self.active_adapters:\n            if active_adapter not in self.lora_A.keys():\n                continue\n            lora_A = self.lora_A[active_adapter]\n            lora_B = self.lora_B[active_adapter]\n            dropout = self.lora_dropout[active_adapter]\n            scaling = self.scaling[active_adapter]\n\n            requires_conversion = not torch.is_autocast_enabled()\n            if requires_conversion:\n                expected_dtype = result.dtype\n                x = x.to(lora_A.weight.dtype)\n\n            if not self.use_dora[active_adapter]:\n                result = result + lora_B(lora_A(dropout(x))) * scaling\n            else:\n                if isinstance(dropout, torch.nn.Identity) or not self.training:\n                    base_result = result\n                else:\n                    x = dropout(x)\n                    base_result = None\n\n                result = result + self.lora_magnitude_vector[active_adapter](\n                    x,\n                    lora_A = lora_A,\n                    lora_B = lora_B,\n                    scaling = scaling,\n                    base_layer = self.get_base_layer(),\n                    base_result = base_result,\n                )\n            if requires_conversion:\n                result = result.to(expected_dtype)\n\n    return result\n"
  },
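  {
    "path": "examples/editor_sketches/swiglu_derivative_check.py",
    "content": "\"\"\"Editor's illustrative sketch (not part of the original Unsloth repository; this file path is hypothetical).\n\nVerifies, in plain PyTorch on CPU, the SwiGLU derivative identity quoted in the\nLoRA_MLP docstring of unsloth/kernels/fast_lora.py:\n    f  = e * sigmoid(e)\n    df = sigmoid(e) * (1 - f) + f\nThe tensor size below is an illustrative assumption.\n\"\"\"\nimport torch\n\ntorch.manual_seed(0)\ne = torch.randn(16, dtype = torch.float64, requires_grad = True)\n\n# SwiGLU gate activation f = e * sigmoid(e) (i.e. SiLU)\nf = e * torch.sigmoid(e)\n\n# Autograd derivative of f with respect to e\n(df_autograd,) = torch.autograd.grad(f.sum(), e)\n\n# Closed form from the docstring: df = sigmoid(e) * (1 - f) + f\nf = f.detach()\ndf_closed = torch.sigmoid(e.detach()) * (1.0 - f) + f\n\nassert torch.allclose(df_autograd, df_closed)\nprint(\"df identity holds:\", torch.allclose(df_autograd, df_closed))\n"
  },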
  {
    "path": "unsloth/kernels/flex_attention.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom functools import lru_cache\nfrom transformers.models.llama.modeling_llama import logger\nimport os\n\ntorch_compile_options = {\n    \"epilogue_fusion\": True,\n    \"max_autotune\": True,\n    \"shape_padding\": True,\n    \"trace.enabled\": os.environ.get(\"UNSLOTH_COMPILE_DEBUG\", \"0\") == \"1\",\n    \"triton.cudagraphs\": False,\n}\n\n# Flex Attention supported from torch 2.5 onwards only\ntry:\n    from torch.nn.attention.flex_attention import (\n        flex_attention as _flex_attention,\n        create_block_mask as _create_block_mask,\n    )\n\n    _flex_attention = torch.compile(\n        _flex_attention, dynamic = True, options = torch_compile_options\n    )\n    HAS_FLEX_ATTENTION = False\nexcept:\n    HAS_FLEX_ATTENTION = False\n\n\nif not HAS_FLEX_ATTENTION:\n    # Logit softcapping\n    @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)\n    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):\n        n_heads = self.config.num_attention_heads\n        head_dim = self.head_dim\n        n_kv_heads = self.config.num_key_value_heads\n        n_groups = self.num_key_value_groups\n\n        # Grouped query attention\n        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)\n        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)\n        K = K.reshape(bsz, n_heads, q_len, head_dim)\n        V = V.reshape(bsz, n_heads, q_len, head_dim)\n\n        # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e\n        # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below\n        # We default to using the config file itself\n        # s = self.config.hidden_size // self.config.num_attention_heads\n        s = self.config.query_pre_attn_scalar\n        t = self.config.attn_logit_softcapping\n\n        Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype)  # Follow Keras exactly\n        A = torch.matmul(Q, K.transpose(2, 3))\n        A = t * torch.tanh(A / t)  # Logit softcapping\n        A += causal_mask[:q_len, :q_len]\n        # Much slower in torch compile!\n        # A.masked_fill_(causal_mask[:q_len, :q_len], -float(\"inf\"))\n        A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)\n        A = torch.matmul(A, V)\n        A = A.transpose(1, 2).contiguous()\n        A = A.reshape(bsz, q_len, n_heads * head_dim)\n        return A\n\n    create_flex_attention_causal_mask = None\n    create_flex_attention_sliding_window_mask = None\nelse:\n    # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb\n    # for more examples\n    # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al\n    import functools, math\n\n    def generate_tanh_softcap(t):\n        def tanh_softcap(x, b, h, q_idx, kv_idx):\n            return t * torch.tanh(x / t)\n\n        return tanh_softcap\n\n    def causal_masker(b, h, q_idx, kv_idx):\n        return q_idx >= kv_idx\n\n    @functools.lru_cache\n    def sliding_window_masker(size = 4096):\n        def sliding_window(b, h, q_idx, kv_idx):\n            causal_mask = q_idx >= kv_idx\n            window_mask = q_idx - kv_idx <= size\n            return causal_mask & window_mask\n\n        return sliding_window\n\n    @functools.lru_cache\n    def create_block_mask(mask, n = 128):\n        return _create_block_mask(\n            mask,\n            1,\n            1,\n            n,\n            n,\n            BLOCK_SIZE = 128,\n            _compile = True,\n        )\n\n    def create_flex_attention_causal_mask(max_seq_length = 8192):\n        causal_mask = create_block_mask(causal_masker, max_seq_length)\n        return causal_mask\n\n    def create_flex_attention_sliding_window_mask(\n        max_seq_length = 8192, sliding_window = 4096\n    ):\n        sliding_masker = sliding_window_masker(sliding_window)\n        causal_mask = create_block_mask(sliding_masker, max_seq_length)\n        return causal_mask\n\n    @functools.lru_cache\n    def flex_attention(s, t):\n        scale = 1.0 / math.sqrt(s)\n        score_mod = generate_tanh_softcap(t)\n        return functools.partial(\n            _flex_attention,\n            score_mod = score_mod,\n            scale = scale,\n            enable_gqa = True,\n        )\n\n    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):\n        n_heads = self.config.num_attention_heads\n        head_dim = self.head_dim\n        s = self.config.query_pre_attn_scalar\n        t = self.config.attn_logit_softcapping\n        fx = flex_attention(s, t)\n        A = fx(query = Q, key = K, value = V, block_mask = causal_mask)\n        A = A.transpose(1, 2).contiguous()\n        A = A.reshape(bsz, q_len, n_heads * head_dim)\n        return A\n\n\ntorch_matmul = torch.matmul\ntorch_tanh = torch.tanh\ntorch_nn_functional_softmax = torch.nn.functional.softmax\n\n\ndef slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):\n    n_heads = self.config.num_attention_heads\n    head_dim = self.head_dim\n    n_kv_heads = self.config.num_key_value_heads\n    n_groups = 
self.num_key_value_groups\n\n    # Grouped query attention\n    K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)\n    V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)\n    K = K.reshape(bsz, n_heads, q_len, head_dim)\n    V = V.reshape(bsz, n_heads, q_len, head_dim)\n\n    # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e\n    # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below\n    # We default to using the config file itself\n    # s = self.config.hidden_size // self.config.num_attention_heads\n    s = self.config.query_pre_attn_scalar\n    t = self.config.attn_logit_softcapping\n\n    Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype)  # Follow Keras exactly\n    A = torch_matmul(Q, K.transpose(2, 3))\n\n    # Logit softcapping\n    A /= t\n    torch_tanh(A, out = A)\n    A *= t\n    A += causal_mask[:q_len, :q_len]\n    # Much slower in torch compile!\n    # A.masked_fill_(causal_mask[:q_len, :q_len], -float(\"inf\"))\n    A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)\n    A = torch_matmul(A, V)\n    A = A.transpose(1, 2).contiguous()\n    A = A.reshape(bsz, q_len, n_heads * head_dim)\n    return A\n"
  },
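  {
    "path": "examples/editor_sketches/gqa_expand_reshape_check.py",
    "content": "\"\"\"Editor's illustrative sketch (not part of the original Unsloth repository; this file path is hypothetical).\n\nShows, in plain PyTorch on CPU, that the grouped query attention expand + reshape\nused by slow_attention_softcapping in unsloth/kernels/flex_attention.py is\nequivalent to repeating each KV head n_groups times along the head dimension.\nSizes below are illustrative assumptions.\n\"\"\"\nimport torch\n\ntorch.manual_seed(0)\nbsz, n_kv_heads, n_groups, q_len, head_dim = 2, 4, 3, 5, 8\nn_heads = n_kv_heads * n_groups\nK = torch.randn(bsz, n_kv_heads, q_len, head_dim)\n\n# Expand a new group axis, then merge it into the head axis (as in the kernel file)\nK_expanded = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)\nK_expanded = K_expanded.reshape(bsz, n_heads, q_len, head_dim)\n\n# Equivalent formulation: repeat each KV head n_groups times\nK_repeated = torch.repeat_interleave(K, n_groups, dim = 1)\n\nassert torch.equal(K_expanded, K_repeated)\nprint(\"GQA expand/reshape matches repeat_interleave:\", torch.equal(K_expanded, K_repeated))\n"
  },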
  {
    "path": "unsloth/kernels/fp8.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl\nfrom torch.nn import functional as F\nimport math\nfrom unsloth_zoo.utils import Version\nfrom unsloth_zoo.log import logger\nfrom unsloth_zoo.temporary_patches.common import torch_compile\n\ntorch_matmul = torch.matmul\n\ntry:\n    from transformers.integrations.finegrained_fp8 import FP8Linear\nexcept:\n    FP8Linear = None\n    logger.info(\n        \"Unsloth: FP8 models need importing FP8Linear from `transformers.integrations.finegrained_fp8` but we don't see it.\"\n    )\n\ntry:\n    from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear\nexcept:\n    FbgemmFp8Linear = None\n    logger.info(\n        \"Unsloth: FP8 models need importing FbgemmFP8Linear from `transformers.integrations.fbgemm_fp8` but we don't see it.\"\n    )\n\ntry:\n    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (\n        triton_quantize_fp8_block,\n    )\nexcept:\n    triton_quantize_fp8_block = None\n    logger.info(\n        \"Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block\"\n    )\n\ntry:\n    from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (\n        blockwise_fp8_gemm as torchao_blockwise_gemm,\n    )\nexcept:\n    torchao_blockwise_gemm = None\n    logger.info(\n        \"Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm\"\n    )\n\n\n@triton.jit\ndef weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n    pid_m = tl.program_id(axis = 0)\n    pid_n = tl.program_id(axis = 1)\n    n = tl.cdiv(N, BLOCK_SIZE)\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs = offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    x = tl.load(x_ptr + offs, mask = mask).to(tl.float32)\n    s = tl.load(s_ptr + pid_m * n + pid_n)\n    y = x * s\n    tl.store(y_ptr + offs, y, mask = mask)\n\n\ndef weight_dequant_block(\n    x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16\n) -> torch.Tensor:\n    if not x.is_contiguous():\n        x = x.contiguous()\n    if not s.is_contiguous():\n        s = s.contiguous()\n    assert x.dim() == 2 and s.dim() == 2\n    M, N = x.size()\n    y = torch.empty_like(x, dtype = dtype)\n    grid = lambda meta: (\n        triton.cdiv(M, meta[\"BLOCK_SIZE\"]),\n        triton.cdiv(N, meta[\"BLOCK_SIZE\"]),\n    )\n    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)\n    return y\n\n\ndef weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):\n    # Per-tensor scale: single value for entire weight matrix\n    if s.numel() == 1:\n        return x.to(dtype) * s.view(1, 1).to(dtype)\n    # 
Row quantized weight: scale shape is (m, 1) or (n, 1)\n    elif s.ndim == 2 and s.shape[1] == 1:\n        if x.shape[0] == s.shape[0]:\n            y = x.to(dtype) * s.to(dtype)\n        elif x.shape[1] == s.shape[0]:\n            # sometimes, this is called with the transpose of the weight. Adjust for that.\n            y = x.t().to(dtype) * s.to(dtype)\n            y = y.t()\n        else:\n            raise ValueError(f\"Incompatible shapes {x.shape = }, {s.shape = }\")\n        return y\n    # Block quantized weight: scale shape is (ceil(m/block_m), ceil(n/block_n))\n    else:\n        return weight_dequant_block(x, s, dtype = dtype)\n\n\n# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py\n@triton.jit\ndef act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):\n    pid = tl.program_id(axis = 0)\n    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    x = tl.load(x_ptr + offs).to(tl.float32)\n    s = tl.max(tl.abs(x)) / 448.0\n    # For a row of all zeros, lets return zeros as is\n    # for LoRA, there are cases where dY has 0 in it and we should not let it be NaN\n    # this is a deviation from the original implementation.\n    s = 1.0 if s == 0 else s\n    y = x / s\n    y = y.to(y_ptr.dtype.element_ty)\n    tl.store(y_ptr + offs, y)\n    tl.store(s_ptr + pid, s)\n\n\ndef act_quant(\n    x: torch.Tensor, block_size: int = 128\n) -> tuple[torch.Tensor, torch.Tensor]:\n    if not x.is_contiguous():\n        x = x.contiguous()\n    assert x.shape[-1] % block_size == 0\n    y = torch.empty_like(x, dtype = torch.float8_e4m3fn)\n    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32)\n\n    def grid(meta):\n        return (triton.cdiv(x.numel(), meta[\"BLOCK_SIZE\"]),)\n\n    act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)\n    return y, s\n\n\n# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py\n@triton.jit\ndef _w8a8_block_fp8_matmul(\n    # Pointers to inputs and output\n    A,\n    B,\n    C,\n    As,\n    Bs,\n    # Shape for matmul\n    M,\n    N,\n    K,\n    # Block size for block-wise quantization\n    group_n,\n    group_k,\n    # Stride for inputs and output\n    stride_am,\n    stride_ak,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    stride_As_m,\n    stride_As_k,\n    stride_Bs_k,\n    stride_Bs_n,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"Triton-accelerated function used to perform linear operations (dot\n    product) on input tensors `A` and `B` with block-wise quantization, and\n    store the result in output tensor `C`.\n    \"\"\"\n\n    pid = tl.program_id(axis = 0)\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + (pid % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * 
stride_bn)\n\n    As_ptrs = As + offs_am * stride_As_m\n    offs_bsn = offs_bn // group_n\n    Bs_ptrs = Bs + offs_bsn * stride_Bs_n\n\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32)\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        a = tl.load(a_ptrs, mask = offs_k[None, :] < K - k * BLOCK_SIZE_K, other = 0.0)\n        b = tl.load(b_ptrs, mask = offs_k[:, None] < K - k * BLOCK_SIZE_K, other = 0.0)\n\n        k_start = k * BLOCK_SIZE_K\n        offs_ks = k_start // group_k\n        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)\n        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)\n\n        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    if C.dtype.element_ty == tl.bfloat16:\n        c = accumulator.to(tl.bfloat16)\n    elif C.dtype.element_ty == tl.float16:\n        c = accumulator.to(tl.float16)\n    else:\n        c = accumulator.to(tl.float32)\n\n    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, c, mask = c_mask)\n\n\ndef w8a8_block_fp8_matmul_triton(\n    A: torch.Tensor,\n    B: torch.Tensor,\n    As: torch.Tensor,\n    Bs: torch.Tensor,\n    block_size: list[int],\n    output_dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    \"\"\"Block-wise FP8 matmul.\"\"\"\n    if block_size is None:\n        block_n, block_k = 128, 128\n    else:\n        assert len(block_size) == 2\n        block_n, block_k = block_size[0], block_size[1]\n\n    N, K = B.shape\n    assert A.shape[-1] == B.shape[-1]\n    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()\n    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]\n    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2\n    assert triton.cdiv(N, block_n) == Bs.shape[0]\n    assert triton.cdiv(K, block_k) == Bs.shape[1]\n\n    M = A.numel() // A.shape[-1]\n    C_shape = A.shape[:-1] + (N,)\n    C = A.new_empty(C_shape, dtype = output_dtype)\n\n    BLOCK_SIZE_M = 128\n    if M < BLOCK_SIZE_M:\n        BLOCK_SIZE_M = max(triton.next_power_of_2(M), 16)\n    BLOCK_SIZE_K, BLOCK_SIZE_N = block_k, block_n\n\n    def grid(META):\n        return (\n            triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n        )\n\n    _w8a8_block_fp8_matmul[grid](\n        A,\n        B,\n        C,\n        As,\n        Bs,\n        M,\n        N,\n        K,\n        block_n,\n        block_k,\n        A.stride(-2),\n        A.stride(-1),\n        B.stride(1),\n        B.stride(0),\n        C.stride(-2),\n        C.stride(-1),\n        As.stride(-2),\n        As.stride(-1),\n        Bs.stride(1),\n        Bs.stride(0),\n        BLOCK_SIZE_M = BLOCK_SIZE_M,\n        BLOCK_SIZE_N = BLOCK_SIZE_N,\n        BLOCK_SIZE_K = BLOCK_SIZE_K,\n        GROUP_SIZE_M = 8,\n    )\n    return C\n\n\ndef torchao_block_matmul(\n    act_q: torch.Tensor,\n    weight_q: torch.Tensor,\n    act_scale: torch.Tensor,\n    weight_scale: torch.Tensor,\n    block_size: tuple[int, int],\n    output_dtype: torch.dtype = torch.bfloat16,\n):\n    out = torchao_blockwise_gemm(\n        act_q.contiguous(),\n        act_scale.contiguous(),\n        weight_q.contiguous(),\n        weight_scale.contiguous(),\n        block_size = block_size[1],\n    )\n    
return out.to(output_dtype)\n\n\n# Note that older versions of fbgemm (<=1.3.0) cause numerical imprecisions resulting in NaNs especially when X has high values in it.\n# So our preference order is fbgemm (>=1.4.0) > torchao > triton. All of these have similar outputs/losses. Never use fbgemm (<=1.3.0) for block quantized FP8 matmul.\n# This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though torchao is 15-30% slower than fbgemm implementation (on H100 GPUs).\nfp8_block_matmul = (\n    torchao_block_matmul\n    if torchao_blockwise_gemm is not None\n    else w8a8_block_fp8_matmul_triton\n)\n\n\nclass FP8BlockQuantLinear(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, weight, weight_scale):\n        m, n = weight.shape\n\n        # Save original scale for backward (before any transformation)\n        original_weight_scale = weight_scale\n\n        # Handle per-tensor quantization: expand scalar to block scale shape\n        if weight_scale.numel() == 1:\n            block_size = [128, 128]\n            # Expand scalar to (ceil(m/128), ceil(n/128)) - same value for all blocks\n            num_blocks_m = triton.cdiv(m, block_size[0])\n            num_blocks_n = triton.cdiv(n, block_size[1])\n            weight_scale = weight_scale.expand(num_blocks_m, num_blocks_n).contiguous()\n        else:\n            # Block quantization path\n            p, q = weight_scale.shape\n            block_size = getattr(weight, \"block_size\", None) or getattr(\n                weight_scale, \"block_size\", [128, 128]\n            )\n            assert block_size is not None, \"block_size is not set\"\n            if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:\n                if (\n                    triton.cdiv(m, block_size[0]) == q\n                    and triton.cdiv(n, block_size[1]) == p\n                ):\n                    weight_scale = weight_scale.T\n                    original_weight_scale = weight_scale  # Update for transposed case\n                else:\n                    raise ValueError(\n                        f\"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}\"\n                    )\n\n        if not weight.is_contiguous():\n            weight = weight.contiguous()\n\n        # Quantize input and run FP8 matmul\n        qinput, scale = act_quant(X, block_size[1])\n        output = fp8_block_matmul(\n            qinput,\n            weight,\n            scale,\n            weight_scale,\n            block_size,\n            output_dtype = X.dtype,\n        )\n        ctx.weight = weight\n        ctx.weight_scale = original_weight_scale  # Save original for backward\n        return output.to(X.dtype)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        W_deq = weight_dequant(ctx.weight, ctx.weight_scale)\n        grad_X = torch_matmul(grad_output, W_deq)\n        del W_deq\n        return grad_X, None, None\n\n\n@torch_compile\ndef fp8_torch_block_quant_forward(X, weight, weight_scale):\n    return FP8BlockQuantLinear.apply(X, weight, weight_scale)\n\n\nclass FbgemmFp8Linear_matmul(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, weight, weight_scale, bias = None):\n        if weight.shape[0] == weight_scale.shape[0] and (\n            weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0\n        ):\n            # Edit: The kernel seems to expect that the weight has dimensions divisible by 8. 
Otherwise it throws `RuntimeError: cutlass cannot implement`\n            # One thing we can do is to pad the weight and weight scale to multiple of 8 and perform a F8F8BF16 operation.\n            # I tried benchmarking that for speed but observed that dequantize+bf16 matmul is significantly faster than padding+f8f8bf16 matmul. So we'll go that route.\n            # So essentially, f8f8bf16_rowise only happens when shapes are proper (no transposes) and divisible by 8.\n\n            # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here\n            output_shape = (*x.shape[:-1], -1)\n            # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.\n            # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45\n            x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(\n                x.view(-1, x.shape[-1]).contiguous(),\n                scale_ub = getattr(weight, \"input_scale_ub\", None),\n            )\n            # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works\n            # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)\n\n            # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight\n            weight_scale_float32 = weight_scale.to(torch.float32)\n\n            if not weight.is_contiguous():\n                weight = weight.contiguous()\n            if not weight_scale.is_contiguous():\n                weight_scale = weight_scale.contiguous()\n\n            output = torch.ops.fbgemm.f8f8bf16_rowwise(\n                x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum = True\n            )\n            output = output + bias if bias is not None else output\n            # Hacky for now, we have the output to the device of x\n            output = output.to(x.device, x.dtype)\n            output = output.reshape(output_shape)\n            del x_quantized, x_scale\n        elif (\n            weight.shape[0] != weight_scale.shape[0]\n            and weight.shape[1] == weight_scale.shape[0]\n        ) or (weight.shape[0] // 8 != 0 or weight.shape[1] // 8 != 0):\n            # Either the weight/scale is transposed or its shape is not divisible by 8. 
Both cases, dequantizing is the preferred way.\n            # The transpose case is generally noticed in backward pass when we do dY@W instead of @W.T as we do for forward.\n            # The shape case, I noticed to happen in MLP of Qwen 2.5 VL 7B where the gate proj is of shape (3420, 1280) and 3420/8=427.5\n\n            W_deq = weight_dequant(weight, weight_scale).T\n            output = torch_matmul(x, W_deq)\n            del W_deq\n        else:\n            raise ValueError(\n                f\"Shapes are incompatible {weight.shape = }, {weight_scale.shape = }, {x.shape = }\"\n            )\n\n        ctx.weight = weight\n        ctx.weight_scale = weight_scale\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        W_deq = weight_dequant(ctx.weight, ctx.weight_scale)\n        grad_X = torch_matmul(grad_output, W_deq)\n        del W_deq\n        return grad_X, None, None, None, None\n\n\n@torch_compile\ndef fbgemm_fp8_linear(X, weight, weight_scale, bias = None):\n    return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)\n\n\nclass FP8_fbgemm_block_linear(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, weight, weight_scale, bias = None):\n        orig_shape = X.shape\n        X = X.view(-1, X.shape[-1])\n\n        bs_n, bs_k = getattr(weight, \"block_size\", None) or getattr(\n            weight_scale, \"block_size\", [128, 128]\n        )\n        bs_m = bs_n\n\n        m, n = weight.shape\n        p, q = weight_scale.shape\n\n        if triton.cdiv(m, bs_n) != p or triton.cdiv(n, bs_k) != q:\n            if triton.cdiv(m, bs_n) == q and triton.cdiv(n, bs_k) == p:\n                # weights are transposed during backward pass for training :)\n                # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X\n                weight_scale = weight_scale.T\n            else:\n                raise ValueError(\n                    f\"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {bs_n, bs_k}\"\n                )\n\n        xq, xs = triton_quantize_fp8_block(X, bs_m, bs_n, None)\n        ## TODO: Investigate and resolve the high divergence of this output from baseline\n        # WARNING: This causes the outputs to diverge from expected when X has high values in it.\n        # That results in the model producing gibberish, especially on longer sequences and training loss starting at high values like 8 instead of <1 ideally\n        # Please refrain from using this till this issue is resolved. 
This exists here just for a future headstart.\n        output = torch.ops.fbgemm.f8f8bf16_blockwise(\n            xq, weight.contiguous(), xs, weight_scale.contiguous(), bs_m, bs_n, bs_k\n        )\n        output = output + bias if bias is not None else output\n\n        output = output.view(*orig_shape[:-1], -1)\n\n        del xq\n        del xs\n\n        ctx.weight = weight\n        ctx.weight_scale = weight_scale\n        ctx.block_size = [bs_m, bs_n, bs_k]\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        W_deq = weight_dequant(ctx.weight, ctx.weight_scale)\n        grad_X = torch_matmul(grad_output, W_deq)\n        del W_deq\n        return grad_X, None, None, None, None\n\n\n@torch_compile\ndef fp8_fbgemm_block_linear(X, weight, weight_scale, bias = None):\n    return FP8_fbgemm_block_linear.apply(X, weight, weight_scale, bias)\n\n\ndef test_has_fbgemm():\n    # We must manually check if the faster FBGEMM works on the specific GPU\n    # For example RTX 5090 and RTX 4090 does not work\n    # Also SM100 (Blackwell B200/B100) GPUs fail with CUTLASS SM90 kernels\n    # [TODO] Investigate with TorchAO why FBGEMM fails on consumer GPUs\n    M, N, K = 128, 128, 128\n    xq = torch.ones(M, K, dtype = torch.float8_e4m3fn, device = \"cuda\")\n    wq = xq\n    M, K = xq.shape\n    N, _ = wq.shape\n    block_scale = torch.ones(M // 128, K // 128, dtype = torch.float32, device = \"cuda\")\n    has_fbgemm = False\n    try:\n        out = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, block_scale, block_scale)\n        assert torch.unique(out).item() == 128\n        has_fbgemm = True\n        del out\n    except Exception as e:\n        error_str = str(e).lower()\n        # Catch any CUTLASS/CUDA errors and disable FBGEMM\n        # This includes MMA instruction errors, architecture mismatches, kernel launch failures, etc.\n        cutlass_cuda_errors = (\n            \"cutlass\",\n            \"cuda error\",\n            \"cuda runtime error\",\n            \"no kernel image\",\n            \"arch conditional\",\n            \"mma instruction\",\n            \"compute capability\",\n            \"cute_invalid_control_path\",\n            \"tma\",\n        )\n        is_cutlass_cuda_error = any(err in error_str for err in cutlass_cuda_errors)\n\n        if is_cutlass_cuda_error:\n            print(\n                \"Unsloth: FBGEMM on the current GPU cannot load - will switch to Triton kernels\"\n            )\n        else:\n            print(\n                f\"Unsloth: FBGEMM on the current GPU cannot load with error = {e} - will switch to Triton kernels\"\n            )\n        has_fbgemm = False\n    del block_scale, xq\n    torch.cuda.empty_cache()\n    return has_fbgemm\n\n\nfp8_block_quant_linear = fp8_torch_block_quant_forward\nif \"UNSLOTH_HAS_FBGEMM\" not in os.environ:\n    os.environ[\"UNSLOTH_HAS_FBGEMM\"] = \"0\"\ntry:\n    import fbgemm_gpu\n\n    # Older versions cause numerical imprecisions resulting in NaNs especially when X has high values in it.\n    # This is both fast and accurate hence preferred.\n    # This makes it 15% faster than the torchao implementation.\n    if Version(fbgemm_gpu.__version__) >= Version(\"1.4.0\"):\n        # We must manually confirm if blockwise FBGEMM works!\n        # This check is a must for consumer grade GPUs which fail\n        # Suppress CUDA device printf during probe -- on Blackwell (SM100) GPUs,\n        # FBGEMM's CUTLASS blockwise kernel (hardcoded SM90) fires thousands of\n        # \"Arch 
conditional MMA\" lines to stdout fd 1 before aborting.\n        from unsloth.import_fixes import suppress_cuda_printf\n\n        with suppress_cuda_printf():\n            _has_fbgemm = test_has_fbgemm()\n        if _has_fbgemm:\n            os.environ[\"UNSLOTH_HAS_FBGEMM\"] = \"1\"\n            logger.info(f\"Using fbgemm_gpu block quantized FP8 matmul\")\n            fp8_block_quant_linear = fp8_fbgemm_block_linear\n        else:\n            os.environ[\"UNSLOTH_HAS_FBGEMM\"] = \"0\"\nexcept:\n    pass\n\n\n@torch_compile\ndef fp8_linear(X, weight, weight_scale, bias = None):\n    # Per-tensor quantization: single scalar scale for entire weight\n    # Block quantized FP8: 2D scale tensor with multiple columns\n    if weight_scale.numel() == 1 or (\n        weight_scale.ndim == 2 and weight_scale.shape[1] > 1\n    ):\n        out = fp8_block_quant_linear(X, weight, weight_scale)\n    # Row/channel quantized FP8: 2D scale with shape (n, 1)\n    else:\n        out = fbgemm_fp8_linear(X, weight, weight_scale, bias)\n    return out\n\n\ndef module_forward_patch(forward_function, scale_attr = \"weight_scale\"):\n    def patched_forward(self, X):\n        return forward_function(X, self.weight, getattr(self, scale_attr))\n\n    return patched_forward\n\n\n# Patch the forward functions of the layers (for compiled models)\nif FbgemmFp8Linear is not None:\n    FbgemmFp8Linear.forward = module_forward_patch(fbgemm_fp8_linear, \"weight_scale\")\nif FP8Linear is not None:\n    FP8Linear.forward = module_forward_patch(fp8_block_quant_linear, \"weight_scale_inv\")\n"
  },
  {
    "path": "unsloth/kernels/geglu.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom .utils import (\n    calculate_settings,\n    triton_tanh,\n    torch_gpu_device,\n)\n\n# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31\nNUM_INT32_ELEMENTS = 2**31\nSAFE_INT32_BUFFER_MULTIPLIER = 4\nBLOCK_SIZE = 1024\nINT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER\n\n\n@triton.jit\ndef _exact_forward_kernel(\n    e,\n    g,\n    h,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))\n    # h = f * up\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n    f_row = f_row.to(g_row.dtype)  # Exact copy from HF\n    h_row = f_row * g_row\n\n    # Store h\n    tl.store(h + offsets, h_row, mask = mask)\n\n\ndef geglu_exact_forward_kernel(gate, up):\n    batch, seq_len, hd = gate.shape\n    n_elements = gate.numel()\n    device = gate.device\n    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(device):\n        _exact_forward_kernel[grid](\n            gate,\n            up,\n            out,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return out\n\n\n@triton.jit\ndef _exact_backward_kernel(\n    DW,\n    e,\n    g,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    \"\"\"\n    f = 1/2 * e * (1 + erf(1/sqrt(2) * e))\n    h = f * up\n\n    df/de (with help of Wolfram :)\n    df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)\n\n    Reuse via\n    f =        1/2 * (1 + erf(1/sqrt(2) * e)) * e\n    \"\"\"\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 
0)  # .to(tl.float32)\n\n    # Break e_row away for re-use\n    # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))\n    f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n    f_row = f_partial_row * e_row\n\n    f_row = f_row.to(DW_row.dtype)\n    # h = f * g\n    h_row = f_row * g_row\n    # df = DW * f\n    df_row = DW_row * f_row\n    # dg = DW * g\n    dg_row = DW_row * g_row\n\n    # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)\n    t = 0.3989422804014327  # 1/sqrt(2*pi)\n    df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)\n\n    de_row = dg_row.to(tl.float32) * df_de\n    de_row = de_row.to(DW_row.dtype)\n\n    # Store derivatives in buffers\n    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g\n    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f\n    tl.store(g + offsets, de_row, mask = mask)  # de\n\n\ndef geglu_exact_backward_kernel(DW, e, g):\n    batch_seq_len, hd = e.shape\n    n_elements = e.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(e.device):\n        _exact_backward_kernel[grid](\n            DW,\n            e,\n            g,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return DW, e, g\n\n\n@triton.jit\ndef _approx_forward_kernel(\n    e,\n    g,\n    h,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))\n    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))\n    # h = f * up\n    s = 0.7978845608028654  # math.sqrt(2 / math.pi)\n\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    f_row = (\n        0.5 * e_row * (triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0)\n    )\n    f_row = f_row.to(g_row.dtype)  # Exact copy from HF\n    h_row = f_row * g_row\n\n    # Store h\n    tl.store(h + offsets, h_row, mask = mask)\n\n\ndef geglu_approx_forward_kernel(gate, up):\n    batch, seq_len, hd = gate.shape\n    n_elements = gate.numel()\n    device = gate.device\n    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(device):\n        _approx_forward_kernel[grid](\n            gate,\n            up,\n            out,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return out\n\n\n@triton.jit\ndef _approx_backward_kernel(\n    DW,\n    e,\n    g,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    \"\"\"\n    f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))\n    h = f * up\n\n    df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))\n    df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +\n            1/2 * sech^2 [   sqrt(2/pi) * x * (1 + 
0.044715 * x^2 )  ] * \\\n                           ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )\n\n    Notice sech^2(x) = 1 - tanh^2(x)\n    So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )\n\n    See https://www.desmos.com/calculator/nqprfoni6x\n    \"\"\"\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    # See https://www.desmos.com/calculator/nqprfoni6x\n    s = 0.7978845608028654  # math.sqrt(2 / math.pi)\n    a = s * e_row  # a = sqrt(2 / pi) * x\n    b = a * 0.044715 * e_row * e_row  # b = a * 0.044715 * x^2\n    T = 1.0 + triton_tanh(a + b)\n    T2 = 0.5 * T\n    # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)\n    Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)\n    df_de = T2 + Q2  # 1/2 * (T + Q)\n\n    # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))\n    f_row = T2 * e_row\n    f_row = f_row.to(DW_row.dtype)\n    # h = f * g\n    h_row = f_row * g_row\n    # df = DW * f\n    df_row = DW_row * f_row\n    # dg = DW * g\n    dg_row = DW_row * g_row\n\n    de_row = dg_row.to(tl.float32) * df_de\n    de_row = de_row.to(DW_row.dtype)\n\n    # Store derivatives in buffers\n    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g\n    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f\n    tl.store(g + offsets, de_row, mask = mask)  # de\n\n\ndef geglu_approx_backward_kernel(DW, e, g):\n    batch_seq_len, hd = e.shape\n    n_elements = e.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(e.device):\n        _approx_backward_kernel[grid](\n            DW,\n            e,\n            g,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return DW, e, g\n"
  },
  {
    "path": "unsloth/kernels/layernorm.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom .utils import calculate_settings, torch_gpu_device\nfrom unsloth_zoo.patching_utils import (\n    patch_layernorm,\n)\n\n\n@triton.jit\ndef layernorm_forward(\n    Y,\n    Y_row_stride,\n    X,\n    X_row_stride,\n    W,\n    b,\n    r,\n    mu,\n    n_cols: tl.constexpr,\n    eps: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    row_idx = tl.program_id(0)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n\n    Y += row_idx * Y_row_stride\n    X += row_idx * X_row_stride\n    r += row_idx\n    mu += row_idx\n\n    # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules\n    # are in float32!\n    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n    b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n    mean_X = tl.sum(X_row, axis = 0) / n_cols\n    # (X[0] - mean) == -mean so we need to mask it out\n    XX = tl.where(mask, X_row - mean_X, 0)\n    row_var = tl.sum(XX * XX, axis = 0) / n_cols\n    # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm\n    eps_f32 = tl.full((), eps, tl.float32)\n    inv_var = tl.math.rsqrt(row_var + eps_f32)\n    tl.store(r, inv_var)\n    tl.store(mu, mean_X)\n    output = (XX * inv_var) * W_row + b_row\n    tl.store(Y + col_offsets, output, mask = mask)\n\n\n@triton.jit\ndef layernorm_backward(\n    dY,\n    dY_row_stride,\n    X,\n    X_row_stride,\n    W,\n    b,\n    r,\n    mu,\n    n_cols: tl.constexpr,\n    eps: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md\n    row_idx = tl.program_id(0)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n\n    dY += row_idx * dY_row_stride\n    X += row_idx * X_row_stride\n    r += row_idx\n    mu += row_idx\n\n    # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules\n    # are in float32!\n    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n    b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n    inv_var = tl.load(r).to(tl.float32)\n    mean = tl.load(mu).to(tl.float32)\n    normed = (X_row - mean) * inv_var\n    dY_W = dY_row * W_row\n    dX_row = (\n        dY_W\n        - tl.sum(dY_W, axis = 0) / n_cols\n        - normed * tl.sum(dY_W * normed, axis = 0) / n_cols\n    )\n    dX_row = 
dX_row * inv_var\n    tl.store(dY + col_offsets, dX_row, mask = mask)\n\n\nclass Fast_Layernorm(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X, W, b, eps):\n        shape = X.shape\n        dim = shape[-1]\n        X = X.view(-1, dim)\n        n_rows, n_cols = X.shape\n        BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n        device = X.device\n        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)\n        r = torch.empty(n_rows, dtype = torch.float32, device = device)\n        mu = torch.empty(n_rows, dtype = torch.float32, device = device)\n\n        with torch_gpu_device(device):\n            layernorm_forward[(n_rows,)](\n                Y,\n                Y.stride(0),\n                X,\n                X.stride(0),\n                W,\n                b,\n                r,\n                mu,\n                n_cols,\n                eps,\n                BLOCK_SIZE = BLOCK_SIZE,\n                num_warps = num_warps,\n            )\n        ctx.eps = eps\n        ctx.BLOCK_SIZE = BLOCK_SIZE\n        ctx.num_warps = num_warps\n        ctx.save_for_backward(X, W, b, r, mu)\n        return Y.view(*shape)\n\n    @staticmethod\n    def backward(ctx, dY):\n        shape = dY.shape\n        dim = shape[-1]\n        dY = dY.view(-1, dim)\n        X, W, b, r, mu = ctx.saved_tensors\n        n_rows, n_cols = dY.shape\n\n        with torch_gpu_device(dY.device):\n            layernorm_backward[(n_rows,)](\n                dY,\n                dY.stride(0),\n                X,\n                X.stride(0),\n                W,\n                b,\n                r,\n                mu,\n                n_cols,\n                ctx.eps,\n                BLOCK_SIZE = ctx.BLOCK_SIZE,\n                num_warps = ctx.num_warps,\n            )\n        dX = dY.view(*shape)\n        return dX, None, None, None, None\n\n\ndef fast_layernorm(layernorm, X):\n    assert layernorm.elementwise_affine is True\n    W = layernorm.weight\n    bias = layernorm.bias\n    eps = (\n        layernorm.variance_epsilon\n        if hasattr(layernorm, \"variance_epsilon\")\n        else layernorm.eps\n    )\n    out = Fast_Layernorm.apply(X, W, bias, eps)\n    return out\n\n\ndef test_layernorm(\n    dim = 1024,\n    eps = 1e-5,\n    dtype = torch.float16,\n    bsz = 21,\n    random_state = 3407,\n    seqlen = 3341,\n):\n    from torch.nn import LayerNorm\n\n    layernorm = LayerNorm((dim,), eps = eps, device = \"cuda\", dtype = dtype)\n    torch.cuda.manual_seed(random_state)\n    torch.manual_seed(random_state)\n    torch.nn.init.uniform_(layernorm.weight)\n    torch.nn.init.uniform_(layernorm.bias)\n    X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = \"cuda\")\n    XX = X.clone()\n    X.requires_grad_(True)\n    XX.requires_grad_(True)\n    Y = layernorm(X)\n    YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = \"cuda\", requires_grad = True)\n    Y.backward(YY)\n    correct_grad = X.grad.clone()\n    # from unsloth.kernels import fast_layernorm\n    Y = fast_layernorm(layernorm, XX)\n    Y.backward(YY)\n    assert torch.dist(correct_grad, XX.grad).item() <= 0.1\n\n\ndef testing_suite_layernorm():\n    for dim in [512, 1024, 2048]:\n        for dtype in [torch.float16, torch.bfloat16]:\n            with torch.autocast(device_type = \"cuda\", dtype = dtype):\n                for seqlen in [3341, 2048, 349]:\n                    for random_state in [3407, 42]:\n                        test_layernorm(\n                            dim = 
dim,\n                            eps = 1e-5,\n                            dtype = dtype,\n                            bsz = 21,\n                            random_state = random_state,\n                            seqlen = seqlen,\n                        )\n"
  },
  {
    "path": "unsloth/kernels/moe/LICENSE",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  
\"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. 
Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  
This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  
But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  
If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  
If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  
\"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  
This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published\n    by the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>."
  },
  {
    "path": "unsloth/kernels/moe/README.md",
    "content": "## MoE Grouped GEMM\n\nOptimized implementation of `MoE MLP Block`.\nLicensed under AGPLv3.\n\n### Background\n\n`MoE MLP` requires the following steps:\n- Calculate `topk_weights` and `topk_indices`\n- If using a grouped gemm implementation, calculate permutation indices needed to rearrange tokens grouped by expert\n- For each expert:\n    - `expert_tokens`: gather the tokens assigned to the expert\n    - `first_gemm`: `gate / up proj` @ `expert_tokens`\n    - `silu_and_mul`: `silu` and `mul` of `first_gemm`\n    - `second_gemm`: `silu_and_mul` @ `down proj`\n    - `scatter_second_gemm`: scatter the `second_gemm` to the original token order\n    - `topk_weight_mul`: `second_gemm` @ `topk_weights`\n    - `final_output`: if `topk > 1`, `topk_weight_mul.view(num_tokens, topk, -1).sum(dim=1)` else `topk_weight_mul`\n\nOne way to eliminate the loop is to use a grouped GEMM, where all expert GEMMs are computed within a single kernel, which iterates over tiles of the expert GEMMs as individual GEMMs, where each GEMM, the `A` matrix is `M' x K` and the `B` matrix is `K x N`, where `M'` is the number of tokens assigned to the expert and `B` is the weight matrix for that expert.\n\nThis requires an additional permute (and subsequent copy) of the hidden states such that the tokens assigned to each expert are contiguous in memory before running the first grouped GEMM within the Expert MLP.\nAdditionally, after the second grouped GEMM, the hidden states must be permuted back to the original token order and multiplied by `topk_weights` to get the final output.\n\n### Optimizations\nThis repo implements a grouped GEMM-based MoE MLP with the following optimizations:\n- Eliminates the loop over experts by performing gemms as a grouped GEMM, computing the expert gemms within a single fused triton kernel\n- Fuses the permutation of hidden states from token order (original input order) to expert order (tokens grouped by expert) within the prologue of first the first grouped GEMM\n- Fuses the (un)permutation of hidden states from expert order back to token order in second GEMM\n- Fuses the mul of hidden states by expert weights within epilogue of second GEMM (only implemented for inference, not for training)\n\n### Structure\n- `grouped_gemm/interface.py`: wrappers for the individual forward / backward kernels as well as the `torch.autograd.Function`\n- `grouped_gemm/kernels/forward.py`: forward kernel\n- `grouped_gemm/kernels/backward.py`: backward dX and dW kernels\n- `grouped_gemm/kernels/tuning.py`: manual tuning utils\n- `grouped_gemm/kernels/autotuning.py`: autotuning utils\n- `grouped_gemm/reference/moe_block.py`: contains `Qwen3MoeFusedGroupedGEMMBlock`, a reference implementation of Huggingface `Qwen3SparseMOEBlock` with fused triton kernel in-place of original HF expert computation\n- `grouped_gemm/reference/moe_ops.py`: supporting ops (routing, token sorting, etc.) and reference MoE block using a torch-native grouped gemm approach.\n\n### Tests\n- `grouped_gemm/tests/test_grouped_gemm.py`: unit tests for forward, backward grouped gemm kernels as well as the wrapped grouped gemm autograd.Function.  Best not to run this entire test suite at once due to the large number of parametrized unit tests.  Rather, use filters to run specific\nsets of tests.  E.g., to run forward tests with autotune turned on: `pytest -sv -k \"forward and autotune\" --tb=short tests/test_grouped_gemm.py`.  
Use the test function names and parameter ids for words to filter on.\n- `grouped_gemm/tests/test_qwen3_moe.py`: end-to-end test for the Qwen3 MoE block.  IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as notes in the test itself for complications when running parametrized pytest test suites and triton / autotune.  TLDR: use the test script and NOT pytest to run the tests.\n\n### Benchmarks\n- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3MoeSparseMoeBlock` or `Llama4TextMoe` against the fused implementation\n\n\nRunning with these flags on an `H100` to bench the forward pass (run with `--help` to see all available flags):\n\nFor `Qwen3-30B-A3B`:\n```\npython benchmark/benchmark_fused_moe.py --model qwen3 --mode forward --seqlen 1024 --permute_x --permute_y --autotune\n```\n\nFor the backward bench:\n```\npython benchmark/benchmark_fused_moe.py --model qwen3 --mode backward --seqlen 1024 --permute_x --permute_y --autotune\n```\n\nFor `Llama-4-Scout-17B-16E`:\n```\npython benchmark/benchmark_fused_moe.py --model llama4 --autotune --mode=forward --permute_y\n```\nDitto for the backward pass.\n\n### Notes\n- Tested and benched on `H100`; it should also run on Ampere and possibly earlier GPU generations, though the autotuning configs will need to be adjusted.\n- The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`.\n- The kernels can be run either as autotuned (see `autotuning.py`) or with a manually specified config (see `tuning.py`).  Running with the autotuner is recommended, since the MoE block requires 2 configs for the forward pass (2 grouped gemms) and 4 for the backward pass (dX and dW per grouped gemm, 2 grouped gemms).\n- Running with autotuning turned off and the default manual kernel config will result in **highly** sub-optimal performance, as that config is only meant for testing / debugging purposes.\n- I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads.\n- The Llama4 reference layer is still highly under-optimized as there are many low-hanging opportunities for further speedups around routing and shared expert calculation.\n\nTODO:\n- TMA store: implemented but not enabled currently due to non-determinism arising from a triton pipelining bug.\n- Warp specialization: Hopper support for WS not yet enabled on the triton 3.3.x branch which ships with the latest pytorch 2.7.\n- Additional optimizations:\n    - Fused / optimized implementations of routing, token sorting, etc.\n    - Better software pipelining within the grouped gemm\n    - Threadblock swizzling for better L2 caching\n    - Llama4\n        - Fused gather / topk weight merging\n        - Custom topk, gather indices kernel\n        - Shared expert fusion with experts calculation
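\n\n### Reference loop sketch\n\nFor orientation, the following is a minimal, torch-native sketch of the per-expert reference loop described in the Background section. It is illustrative only: the function and argument names are placeholders, it assumes gate/up weights of shape `(E, 2 * intermediate, hidden)` and down weights of shape `(E, hidden, intermediate)`, and the fused kernels above replace this loop entirely.\n\n```python\nimport torch\nimport torch.nn.functional as F\n\ndef moe_mlp_reference(X, W_gate_up, W_down, router_logits, topk):\n    # X: (num_tokens, hidden); W_gate_up: (E, 2 * intermediate, hidden); W_down: (E, hidden, intermediate)\n    topk_weights, topk_indices = torch.topk(router_logits.softmax(dim = -1), topk, dim = -1)\n    num_tokens, hidden = X.shape\n    out = torch.zeros(num_tokens, topk, hidden, dtype = X.dtype, device = X.device)\n    for e in range(W_gate_up.shape[0]):\n        token_idx, slot_idx = torch.where(topk_indices == e)\n        if token_idx.numel() == 0:\n            continue\n        expert_tokens = X[token_idx]                     # gather tokens routed to expert e\n        first_gemm = expert_tokens @ W_gate_up[e].T      # gate / up proj\n        gate, up = first_gemm.chunk(2, dim = -1)\n        second_gemm = (F.silu(gate) * up) @ W_down[e].T  # silu_and_mul, then down proj\n        out[token_idx, slot_idx] = second_gemm           # scatter back to original token order\n    return (out * topk_weights.unsqueeze(-1)).sum(dim = 1)  # topk_weight_mul and sum over top-k\n```"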
  },
  {
    "path": "unsloth/kernels/moe/__init__.py",
    "content": ""
  },
  {
    "path": "unsloth/kernels/moe/autotune_cache.py",
    "content": "# Unsloth\n# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Affero General Public License as published\n# by the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU Affero General Public License for more details.\n#\n# You should have received a copy of the GNU Affero General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"\nAuto-tuning cache system for MoE kernels to ensure tuning runs only once at training start.\n\"\"\"\n\nimport hashlib\nimport json\nimport logging\nimport os\nimport time\nfrom typing import Dict, List, Optional, Tuple, Any\nimport torch\nimport triton\n\nlogger = logging.getLogger(__name__)\n\n# Global cache for kernel configurations\n_kernel_config_cache: Dict[str, Any] = {}\n_autotune_completed: Dict[str, bool] = {}\n\n\ndef _get_cache_key(\n    num_experts: int,\n    hidden_dim: int,\n    intermediate_dim: int,\n    top_k: int,\n    dtype: torch.dtype,\n    device_capability: Tuple[int, int],\n    seq_len: int = 8192,  # Default sequence length for tuning\n) -> str:\n    \"\"\"Generate a unique cache key based on model configuration.\"\"\"\n    key_data = {\n        \"num_experts\": num_experts,\n        \"hidden_dim\": hidden_dim,\n        \"intermediate_dim\": intermediate_dim,\n        \"top_k\": top_k,\n        \"dtype\": str(dtype),\n        \"device_capability\": device_capability,\n        \"seq_len\": seq_len,\n    }\n    key_str = json.dumps(key_data, sort_keys = True)\n    return hashlib.md5(key_str.encode()).hexdigest()\n\n\ndef _get_cache_file_path(cache_key: str) -> str:\n    \"\"\"Get the file path for the cache file.\"\"\"\n    cache_dir = os.path.expanduser(\"~/.cache/unsloth/moe_autotune\")\n    os.makedirs(cache_dir, exist_ok = True)\n    return os.path.join(cache_dir, f\"{cache_key}.json\")\n\n\ndef load_cached_config(cache_key: str) -> Optional[Dict[str, Any]]:\n    \"\"\"Load cached kernel configuration from disk.\"\"\"\n    cache_file = _get_cache_file_path(cache_key)\n    if not os.path.exists(cache_file):\n        return None\n\n    try:\n        with open(cache_file, \"r\") as f:\n            cached_data = json.load(f)\n\n        # Verify cache is still valid (same device, etc.)\n        current_device_capability = torch.cuda.get_device_capability()\n        if cached_data.get(\"device_capability\") != current_device_capability:\n            logger.info(\"Device capability changed, invalidating cache\")\n            os.remove(cache_file)\n            return None\n\n        logger.info(f\"Loaded cached MoE kernel config: {cache_key}\")\n        return cached_data\n    except Exception as e:\n        logger.warning(f\"Failed to load cache file {cache_file}: {e}\")\n        try:\n            os.remove(cache_file)\n        except:\n            pass\n        return None\n\n\ndef save_cached_config(\n    cache_key: str,\n    config_fwd: Any,\n    config_bwd_dx: Any,\n    config_bwd_dw: Any,\n    metadata: Dict[str, Any] = None,\n) -> None:\n    \"\"\"Save kernel configuration to disk cache.\"\"\"\n    cache_file = _get_cache_file_path(cache_key)\n\n    cache_data = {\n      
  \"timestamp\": time.time(),\n        \"device_capability\": torch.cuda.get_device_capability(),\n        \"config_fwd\": config_fwd.__dict__\n        if hasattr(config_fwd, \"__dict__\")\n        else str(config_fwd),\n        \"config_bwd_dx\": config_bwd_dx.__dict__\n        if hasattr(config_bwd_dx, \"__dict__\")\n        else str(config_bwd_dx),\n        \"config_bwd_dw\": config_bwd_dw.__dict__\n        if hasattr(config_bwd_dw, \"__dict__\")\n        else str(config_bwd_dw),\n        \"metadata\": metadata or {},\n    }\n\n    try:\n        with open(cache_file, \"w\") as f:\n            json.dump(cache_data, f, indent = 2)\n        logger.info(f\"Saved MoE kernel config cache: {cache_key}\")\n    except Exception as e:\n        logger.warning(f\"Failed to save cache file {cache_file}: {e}\")\n\n\ndef get_or_autotune_moe_kernels(\n    num_experts: int,\n    hidden_dim: int,\n    intermediate_dim: int,\n    top_k: int,\n    dtype: torch.dtype,\n    force_autotune: bool = False,\n    seq_len: int = 8192,\n) -> Tuple[Any, Any, Any]:\n    \"\"\"\n    Get cached kernel configurations or run auto-tuning.\n\n    Args:\n        num_experts: Number of experts in the MoE layer\n        hidden_dim: Hidden dimension of the model\n        intermediate_dim: Intermediate dimension for MoE MLP\n        top_k: Number of experts to route to\n        dtype: Data type for computation\n        force_autotune: Force re-running autotuning even if cache exists\n        seq_len: Sequence length to use for tuning benchmarks\n\n    Returns:\n        Tuple of (config_fwd, config_bwd_dx, config_bwd_dw)\n    \"\"\"\n    device_capability = torch.cuda.get_device_capability()\n    cache_key = _get_cache_key(\n        num_experts,\n        hidden_dim,\n        intermediate_dim,\n        top_k,\n        dtype,\n        device_capability,\n        seq_len,\n    )\n\n    # 0. 
Check for environment variable override to DISABLE autotuning\n    if os.environ.get(\"UNSLOTH_MOE_DISABLE_AUTOTUNE\", \"0\") == \"1\":\n        logger.info(\n            f\"UNSLOTH_MOE_DISABLE_AUTOTUNE=1: Using Heuristic (Safe) MoE kernel configs for SM{device_capability[0]}{device_capability[1]}\"\n        )\n        return _get_heuristic_configs()\n    if not force_autotune and cache_key in _kernel_config_cache:\n        logger.info(f\"Using in-memory cached MoE kernel configs: {cache_key}\")\n        return _kernel_config_cache[cache_key]\n\n    # Try to load from disk\n    if not force_autotune:\n        cached_data = load_cached_config(cache_key)\n        if cached_data is not None:\n            # Reconstruct config objects from cached data\n            try:\n                from .grouped_gemm.kernels.tuning import (\n                    KernelConfigForward,\n                    KernelConfigBackward_dX,\n                    KernelConfigBackward_dW,\n                )\n\n                config_fwd = KernelConfigForward(**cached_data[\"config_fwd\"])\n                config_bwd_dx = KernelConfigBackward_dX(**cached_data[\"config_bwd_dx\"])\n                config_bwd_dw = KernelConfigBackward_dW(**cached_data[\"config_bwd_dw\"])\n\n                configs = (config_fwd, config_bwd_dx, config_bwd_dw)\n                _kernel_config_cache[cache_key] = configs\n                return configs\n            except Exception as e:\n                logger.warning(f\"Failed to reconstruct cached configs: {e}\")\n\n    # Run autotuning\n    if cache_key in _autotune_completed and not force_autotune:\n        logger.info(f\"Autotuning already completed for: {cache_key}\")\n        return _kernel_config_cache[cache_key]\n\n    logger.info(f\"Running MoE kernel auto-tuning for: {cache_key}\")\n    logger.info(\n        f\"Configuration: {num_experts} experts, {hidden_dim} hidden, {intermediate_dim} intermediate, top_k={top_k}\"\n    )\n\n    try:\n        configs = _run_moe_autotuning(\n            num_experts, hidden_dim, intermediate_dim, top_k, dtype, seq_len\n        )\n\n        # Cache the results\n        _kernel_config_cache[cache_key] = configs\n        _autotune_completed[cache_key] = True\n\n        # Save to disk\n        config_fwd, config_bwd_dx, config_bwd_dw = configs\n        save_cached_config(\n            cache_key,\n            config_fwd,\n            config_bwd_dx,\n            config_bwd_dw,\n            {\n                \"num_experts\": num_experts,\n                \"hidden_dim\": hidden_dim,\n                \"intermediate_dim\": intermediate_dim,\n            },\n        )\n\n        logger.info(f\"MoE kernel auto-tuning completed: {cache_key}\")\n        return configs\n\n    except Exception as e:\n        logger.error(f\"MoE kernel auto-tuning failed: {e}\")\n        if \"AttributeError\" in str(e) and \"_experimental_make_tensor_descriptor\" in str(\n            e\n        ):\n            logger.warning(\n                \"Unsloth: Your Triton version might be incompatible with TMA features. 
Falling back to default configs.\"\n            )\n        logger.info(\"Falling back to default kernel configurations\")\n        return _get_default_configs()\n\n\ndef _run_moe_autotuning(\n    num_experts: int,\n    hidden_dim: int,\n    intermediate_dim: int,\n    top_k: int,\n    dtype: torch.dtype,\n    seq_len: int,\n) -> Tuple[Any, Any, Any]:\n    \"\"\"Run the actual auto-tuning for MoE kernels.\"\"\"\n\n    # Create dummy inputs for tuning\n    device = \"cuda\"\n    # Use a fixed, safe number of tokens for autotuning to avoid OOMs and dependency on seq_len\n    # 4096 is standard for finding good kernels without consuming 10GB+ VRAM\n    # We ignore the passed seq_len for the actual allocation to satisfy user request\n    num_tokens = 4096\n    total_tokens = num_tokens * top_k\n\n    # Create dummy tensors\n    hidden_states = torch.randn(num_tokens, hidden_dim, device = device, dtype = dtype)\n\n    # Create dummy weights\n    gate_up_weights = torch.randn(\n        num_experts, 2 * intermediate_dim, hidden_dim, device = device, dtype = dtype\n    )\n    down_weights = torch.randn(\n        num_experts, hidden_dim, intermediate_dim, device = device, dtype = dtype\n    )\n\n    # Create dummy routing data\n    m_sizes = torch.randint(\n        1, total_tokens // num_experts + 1, (num_experts,), device = device\n    )\n    m_sizes = m_sizes * (total_tokens // m_sizes.sum().item())\n    # Adjust to ensure exact total\n    diff = total_tokens - m_sizes.sum().item()\n    if diff != 0:\n        m_sizes[0] += diff\n\n    gather_indices = torch.arange(total_tokens, device = device)\n    torch.randperm(total_tokens, out = gather_indices)\n\n    # Autotune forward kernel - use the interface function with autotune=True\n    # This properly invokes the kernel and lets triton handle the autotuning\n    from .grouped_gemm.interface import (\n        grouped_gemm_forward,\n        grouped_gemm_dX,\n        grouped_gemm_dW,\n    )\n    from .grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n    from .grouped_gemm.kernels.backward import (\n        _autotuned_grouped_gemm_dX_kernel,\n        _autotuned_grouped_gemm_dW_kernel,\n    )\n    from .grouped_gemm.kernels.tuning import (\n        KernelConfigForward,\n        KernelConfigBackward_dX,\n        KernelConfigBackward_dW,\n    )\n\n    logger.info(\"Autotuning forward kernel (first GEMM)...\")\n    # Run with autotune=True to trigger autotuning\n    _ = grouped_gemm_forward(\n        X = hidden_states,\n        W = gate_up_weights,\n        topk = top_k,\n        m_sizes = m_sizes,\n        gather_indices = gather_indices,\n        permute_x = True,\n        permute_y = False,\n        autotune = True,\n    )\n    triton_config_fwd = _autotuned_grouped_gemm_forward_kernel.best_config\n\n    # Convert triton.Config to KernelConfigForward\n    config_fwd = KernelConfigForward(\n        BLOCK_SIZE_M = triton_config_fwd.kwargs[\"BLOCK_SIZE_M\"],\n        BLOCK_SIZE_N = triton_config_fwd.kwargs[\"BLOCK_SIZE_N\"],\n        BLOCK_SIZE_K = triton_config_fwd.kwargs[\"BLOCK_SIZE_K\"],\n        num_warps = triton_config_fwd.num_warps,\n        num_stages = triton_config_fwd.num_stages,\n        use_tma_load_x = triton_config_fwd.kwargs.get(\"USE_TMA_LOAD_X\", False),\n        use_tma_load_w = triton_config_fwd.kwargs.get(\"USE_TMA_LOAD_W\", False),\n        use_tma_store = triton_config_fwd.kwargs.get(\"USE_TMA_STORE\", False),\n    )\n\n    # Autotune backward dX kernel\n    logger.info(\"Autotuning backward dX 
kernel...\")\n    dummy_grad = torch.randn(\n        total_tokens, 2 * intermediate_dim, device = device, dtype = dtype\n    )\n    _ = grouped_gemm_dX(\n        dY = dummy_grad,\n        W = gate_up_weights,\n        gather_indices = gather_indices,\n        m_sizes = m_sizes,\n        topk = top_k,\n        permute_x = True,\n        permute_y = False,\n        autotune = True,\n    )\n    triton_config_bwd_dx = _autotuned_grouped_gemm_dX_kernel.best_config\n\n    # Convert triton.Config to KernelConfigBackward_dX\n    config_bwd_dx = KernelConfigBackward_dX(\n        BLOCK_SIZE_M = triton_config_bwd_dx.kwargs[\"BLOCK_SIZE_M\"],\n        BLOCK_SIZE_N = triton_config_bwd_dx.kwargs[\"BLOCK_SIZE_N\"],\n        BLOCK_SIZE_K = triton_config_bwd_dx.kwargs[\"BLOCK_SIZE_K\"],\n        num_warps = triton_config_bwd_dx.num_warps,\n        num_stages = triton_config_bwd_dx.num_stages,\n        use_tma_load_dy = triton_config_bwd_dx.kwargs.get(\"USE_TMA_LOAD_dY\", False),\n        use_tma_load_w = triton_config_bwd_dx.kwargs.get(\"USE_TMA_LOAD_W\", False),\n        use_tma_store = triton_config_bwd_dx.kwargs.get(\"USE_TMA_STORE\", False),\n    )\n\n    # Autotune backward dW kernel\n    logger.info(\"Autotuning backward dW kernel...\")\n    _ = grouped_gemm_dW(\n        X = hidden_states,\n        dY = dummy_grad,\n        m_sizes = m_sizes,\n        gather_indices = gather_indices,\n        topk = top_k,\n        permute_x = True,\n        permute_y = False,\n        autotune = True,\n    )\n    triton_config_bwd_dw = _autotuned_grouped_gemm_dW_kernel.best_config\n\n    # Convert triton.Config to KernelConfigBackward_dW\n    config_bwd_dw = KernelConfigBackward_dW(\n        BLOCK_SIZE_M = triton_config_bwd_dw.kwargs[\"BLOCK_SIZE_M\"],\n        BLOCK_SIZE_N = triton_config_bwd_dw.kwargs[\"BLOCK_SIZE_N\"],\n        BLOCK_SIZE_K = triton_config_bwd_dw.kwargs[\"BLOCK_SIZE_K\"],\n        num_warps = triton_config_bwd_dw.num_warps,\n        num_stages = triton_config_bwd_dw.num_stages,\n        use_tma_load_dy = triton_config_bwd_dw.kwargs.get(\"USE_TMA_LOAD_dY\", False),\n        use_tma_load_x = triton_config_bwd_dw.kwargs.get(\"USE_TMA_LOAD_X\", False),\n        use_tma_store = triton_config_bwd_dw.kwargs.get(\"USE_TMA_STORE\", False),\n    )\n\n    return config_fwd, config_bwd_dx, config_bwd_dw\n\n    return config_fwd, config_bwd_dx, config_bwd_dw\n\n\ndef _get_heuristic_configs() -> Tuple[Any, Any, Any]:\n    \"\"\"\n    Get 'Safe Heuristic' kernel configurations.\n    These are verified to be safe on A100 (SM80) and provide ~9x speedup on H100/B200.\n    \"\"\"\n    from .grouped_gemm.kernels.tuning import (\n        KernelConfigForward,\n        KernelConfigBackward_dX,\n        KernelConfigBackward_dW,\n    )\n\n    # Safe Forward Config: 64x128x128 (Fits A100 SMEM)\n    config_fwd = KernelConfigForward(\n        BLOCK_SIZE_M = 64,\n        BLOCK_SIZE_N = 128,\n        BLOCK_SIZE_K = 128,\n        num_warps = 8,\n        num_stages = 3,\n        permute_x = True,\n        permute_y = True,\n        use_tma_load_x = False,\n        use_tma_load_w = False,  # TMA loads might need alignment checks, safer to disable for heuristic\n        use_tma_store = False,\n    )\n\n    # Safe Backward Configs: 64x64x256\n    config_bwd_dx = KernelConfigBackward_dX(\n        BLOCK_SIZE_M = 64,\n        BLOCK_SIZE_N = 64,\n        BLOCK_SIZE_K = 256,\n        num_warps = 8,\n        num_stages = 4,\n        permute_x = True,\n        permute_y = True,\n        use_tma_load_dy = False,\n        use_tma_load_w = 
False,\n        use_tma_store = False,\n    )\n\n    config_bwd_dw = KernelConfigBackward_dW(\n        BLOCK_SIZE_M = 64,\n        BLOCK_SIZE_N = 64,\n        BLOCK_SIZE_K = 256,\n        num_warps = 8,\n        num_stages = 4,\n        permute_x = True,\n        permute_y = True,\n        use_tma_load_dy = False,\n        use_tma_load_x = False,\n        use_tma_store = False,\n    )\n\n    return config_fwd, config_bwd_dx, config_bwd_dw\n\n\ndef _get_default_configs() -> Tuple[Any, Any, Any]:\n    \"\"\"Get default kernel configurations as fallback.\"\"\"\n    from .grouped_gemm.kernels.tuning import (\n        KernelConfigForward,\n        KernelConfigBackward_dX,\n        KernelConfigBackward_dW,\n    )\n\n    logger.warning(\"Using default MoE kernel configurations (not optimal)\")\n\n    config_fwd = KernelConfigForward(\n        BLOCK_SIZE_M = 128,\n        BLOCK_SIZE_N = 128,\n        BLOCK_SIZE_K = 64,\n        num_warps = 8,\n        num_stages = 3,\n        use_tma_load_x = False,\n        use_tma_load_w = False,\n        use_tma_store = False,\n    )\n\n    config_bwd_dx = KernelConfigBackward_dX(\n        BLOCK_SIZE_M = 128,\n        BLOCK_SIZE_N = 128,\n        BLOCK_SIZE_K = 64,\n        num_warps = 8,\n        num_stages = 3,\n        use_tma_load_dy = False,\n        use_tma_load_w = False,\n        use_tma_store = False,\n    )\n\n    config_bwd_dw = KernelConfigBackward_dW(\n        BLOCK_SIZE_M = 128,\n        BLOCK_SIZE_N = 128,\n        BLOCK_SIZE_K = 64,\n        num_warps = 8,\n        num_stages = 3,\n        use_tma_load_dy = False,\n        use_tma_load_x = False,\n        use_tma_store = False,\n    )\n\n    return config_fwd, config_bwd_dx, config_bwd_dw\n\n\ndef clear_cache() -> None:\n    \"\"\"Clear all cached kernel configurations.\"\"\"\n    global _kernel_config_cache, _autotune_completed\n    _kernel_config_cache.clear()\n    _autotune_completed.clear()\n    logger.info(\"Cleared MoE kernel cache\")\n\n\ndef is_autotuning_completed(cache_key: str) -> bool:\n    \"\"\"Check if autotuning has been completed for a given cache key.\"\"\"\n    return cache_key in _autotune_completed\n"
  },
  {
    "path": "unsloth/kernels/moe/benchmark/benchmark_fused_moe.py",
    "content": "import argparse\nimport time\nfrom contextlib import nullcontext\n\nimport torch\nfrom transformers import AutoConfig\nfrom transformers.models.llama4 import Llama4TextConfig\nfrom transformers.models.llama4.modeling_llama4 import Llama4TextMoe\nfrom transformers.models.qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\nfrom triton.testing import do_bench\nfrom utils import (\n    create_kernel_configs,\n    get_autotuner,\n    post_process_results,\n    postprocess_autotune_results,\n    save_results,\n)\n\nfrom grouped_gemm.kernels.autotuning import (\n    DEFAULT_K_BLOCK_SIZES,\n    DEFAULT_M_BLOCK_SIZES,\n    DEFAULT_N_BLOCK_SIZES,\n    DEFAULT_NUM_STAGES,\n    DEFAULT_NUM_WARPS,\n)\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n    KernelResult,\n    TritonTuningContext,\n)\nfrom grouped_gemm.reference.layers.llama4_moe import Llama4TritonTextMoe\nfrom grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeFusedGroupedGEMMBlock\n\nSEED = 42\nLLAMA4_ID = \"meta-llama/Llama-4-Scout-17B-16E\"\nQWEN3_MODEL_ID = \"Qwen/Qwen3-30B-A3B\"\n\n\ndef run_benchmark_forward(\n    ref_model: torch.nn.Module,\n    tt_model: torch.nn.Module,\n    config: AutoConfig,\n    seqlen: int,\n    dtype: torch.dtype,\n    autotune: bool,\n    kernel_config_fwd: KernelConfigForward = None,\n    bs: int = 1,\n):\n    torch.manual_seed(\n        SEED\n    )  # Should not be needed when running using pytest -- autouse fixture in conftest.py\n    device = \"cuda\"\n    hidden_size = config.hidden_size\n\n    X = torch.randn(\n        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True\n    )\n\n    # Forward\n    bench_forward_ref = lambda: ref_model(X)  # noqa: E731\n    bench_forward_fused = lambda: tt_model(X)  # noqa: E731\n\n    ref_forward_time = do_bench(bench_forward_ref)\n\n    if not autotune:\n        assert kernel_config_fwd is not None\n        tuning_context = TritonTuningContext(kernel_config_fwd)\n    else:\n        tuning_context = nullcontext()\n\n    with tuning_context:\n        fused_forward_time = do_bench(bench_forward_fused)\n\n    if (not autotune) and (not tuning_context.success):\n        return 0, 1\n\n    print(\n        f\"Forward: ref {ref_forward_time:.4f}, fused {fused_forward_time:.4f}, speedup {ref_forward_time / fused_forward_time:.1f}x\"\n    )\n    return ref_forward_time, fused_forward_time\n\n\ndef run_benchmark_backward(\n    ref_model: torch.nn.Module,\n    tt_model: torch.nn.Module,\n    config: AutoConfig,\n    seqlen: int,\n    dtype: torch.dtype,\n    bs = 1,\n):\n    torch.manual_seed(\n        SEED\n    )  # Should not be needed when running using pytest -- autouse fixture in conftest.py\n    device = \"cuda\"\n    hidden_size = config.hidden_size\n\n    X = torch.randn(\n        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True\n    )\n    X_test = X.detach().clone().requires_grad_(True)\n\n    output, _ = ref_model(X)\n\n    # Prevent autotuning forward pass\n    from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n\n    _autotuned_grouped_gemm_forward_kernel.configs = (\n        _autotuned_grouped_gemm_forward_kernel.configs[:20]\n    )\n    test_output, _ = tt_model(X_test)\n\n    # Bench\n    grad_output = torch.randn_like(output)\n    bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True)  # 
noqa: E731\n    bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True)  # noqa: E731\n\n    ref_backward_time = do_bench(\n        bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]\n    )\n    fused_backward_time = do_bench(\n        bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]\n    )\n    print(\n        f\"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x\"\n    )\n    return ref_backward_time, fused_backward_time\n\n\ndef setup_model(\n    config: Qwen3MoeConfig | Llama4TextConfig,\n    dtype,\n    permute_x,\n    permute_y,\n    autotune,\n    kernel_config_fwd,\n    kernel_config_bwd_dW,\n    kernel_config_bwd_dX,\n    dX_only = False,\n    dW_only = False,\n    overlap_router_shared = False,\n    device = \"cuda\",\n):\n    if isinstance(config, Qwen3MoeConfig):\n        ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)\n\n        # Triton kernel grouped gemm version of MoE Block -- this is what we're testing\n        tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(\n            ref_model,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n            dX_only = dX_only,\n            dW_only = dW_only,\n        ).to(device, dtype)\n\n    elif isinstance(config, Llama4TextConfig):\n        ref_model = Llama4TextMoe(config).to(device, dtype)\n        tt_model = Llama4TritonTextMoe(\n            config,\n            overlap_router_shared = overlap_router_shared,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n            dX_only = dX_only,\n            dW_only = dW_only,\n        ).to(device, dtype)\n\n    else:\n        raise ValueError(f\"Unrecognized config {type(config).__name__}\")\n\n    return ref_model, tt_model\n\n\ndef run_benchmark(\n    mode: str,\n    model_config: Qwen3MoeConfig | Llama4TextConfig,\n    seqlen: int,\n    dtype: torch.dtype,\n    permute_x: bool,\n    permute_y: bool,\n    autotune: bool,\n    kernel_config_fwd: KernelConfigForward = None,\n    kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n    kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n    overlap_router_shared: bool = False,\n    results_dir: str = None,\n):\n    if autotune:\n        autotuner = get_autotuner(mode)\n    if mode == \"dW\":\n        dW_only, dX_only = True, False\n    elif mode == \"dX\":\n        dW_only, dX_only = False, True\n    else:\n        dW_only = dX_only = False\n\n    ref_model, tt_model = setup_model(\n        model_config,\n        dtype = dtype,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        autotune = autotune,\n        kernel_config_fwd = kernel_config_fwd,\n        kernel_config_bwd_dW = kernel_config_bwd_dW,\n        kernel_config_bwd_dX = kernel_config_bwd_dX,\n        dX_only = dX_only,\n        dW_only = dW_only,\n        overlap_router_shared = overlap_router_shared,\n    )\n\n    if mode == \"forward\":\n        ref_time, fused_time = run_benchmark_forward(\n            ref_model,\n            tt_model,\n            config = model_config,\n            
seqlen = seqlen,\n            dtype = dtype,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n        )\n    else:\n        ref_time, fused_time = run_benchmark_backward(\n            ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype\n        )\n\n    if autotune:\n        if mode == \"backward\":\n            autotuner_dW, autotuner_dX = autotuner\n            postprocess_autotune_results(\n                autotuner_dW, \"dW\", ref_time, fused_time, results_dir\n            )\n            postprocess_autotune_results(\n                autotuner_dX, \"dX\", ref_time, fused_time, results_dir\n            )\n        else:\n            postprocess_autotune_results(\n                autotuner, mode, ref_time, fused_time, results_dir\n            )\n\n    return ref_time, fused_time\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--results_dir\", type = str, default = \"benchmark_results\")\n    parser.add_argument(\"--model\", type = str, choices = [\"llama4\", \"qwen3\"], required = True)\n    parser.add_argument(\"--seqlen\", type = int, default = 1024)\n    parser.add_argument(\n        \"--dtype\", type = str, choices = [\"bfloat16\", \"float16\"], default = \"bfloat16\"\n    )\n    parser.add_argument(\"--permute_x\", action = \"store_true\")\n    parser.add_argument(\"--permute_y\", action = \"store_true\")\n    parser.add_argument(\"--autotune\", action = \"store_true\")\n    parser.add_argument(\"--overlap_router_shared\", action = \"store_true\")\n    parser.add_argument(\n        \"--BLOCK_SIZE_M\",\n        nargs = 2,\n        type = int,\n        default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],\n    )\n    parser.add_argument(\n        \"--BLOCK_SIZE_N\",\n        nargs = 2,\n        type = int,\n        default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],\n    )\n    parser.add_argument(\n        \"--BLOCK_SIZE_K\",\n        nargs = 2,\n        type = int,\n        default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],\n    )\n    parser.add_argument(\n        \"--num_warps\",\n        nargs = 2,\n        type = int,\n        default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],\n    )\n    parser.add_argument(\n        \"--num_stages\",\n        nargs = 2,\n        type = int,\n        default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],\n    )\n    parser.add_argument(\n        \"--use_tma_load_w\", action = \"store_true\"\n    )  # No need to specify, will automatically parametrize these for each kernel config\n    parser.add_argument(\n        \"--use_tma_load_x\", action = \"store_true\"\n    )  # No need to specify, will automatically parametrize these for each kernel config\n    parser.add_argument(\n        \"--use_tma_load_dy\", action = \"store_true\"\n    )  # No need to specify, will automatically parametrize these for each kernel config\n    parser.add_argument(\n        \"--mode\",\n        type = str,\n        choices = [\"forward\", \"backward\", \"dW\", \"dX\"],\n        default = \"forward\",\n    )\n    args = parser.parse_args()\n    args.dtype = getattr(torch, args.dtype)\n\n    model_id = QWEN3_MODEL_ID if args.model == \"qwen3\" else LLAMA4_ID\n    model_config = AutoConfig.from_pretrained(model_id)\n    model_config = model_config.text_config if args.model == \"llama4\" else model_config\n\n    mode = args.mode\n\n    if args.autotune:\n        # logging.basicConfig(level=logging.INFO)\n        
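# Autotune path: triton searches the kernel config space for each grouped gemm; the best configs and timings are saved under --results_dir by postprocess_autotune_results.\n        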
print(\n            f\"Benchmarking {model_id} {mode}: seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, autotune\"\n        )\n        start_time = time.time()\n        ref_time, fused_time = run_benchmark(\n            args.mode,\n            model_config,\n            seqlen = args.seqlen,\n            dtype = args.dtype,\n            permute_x = args.permute_x,\n            permute_y = args.permute_y,\n            autotune = args.autotune,\n            overlap_router_shared = args.overlap_router_shared,\n            results_dir = args.results_dir,\n        )\n        end_time = time.time()\n        print(f\"Total time: {end_time - start_time:.4f} seconds\")\n\n    # NOTE: better to use autotuner for now, since the MoE block needs 2 different kernel configs for forward (2 grouped gemms, gate_up_proj and down_proj)\n    # and the backward pass needs 4 different kernel configs (2 grouped gemms each for dW and dX)\n    # The benchmark only supports 1 kernel config at a time so the same config will be used for both grouped gemms, which is suboptimal.\n    else:\n        assert False, \"Use autotune for now\"\n        kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)\n        print(f\"Running {len(kernel_configs)} kernel configs\")\n        default_kernel_config_fwd = KernelConfigForward(\n            permute_x = args.permute_x, permute_y = args.permute_y\n        )\n        default_kernel_config_bwd_dW = KernelConfigBackward_dW(\n            permute_x = args.permute_x, permute_y = args.permute_y\n        )\n        default_kernel_config_bwd_dX = KernelConfigBackward_dX(\n            permute_x = args.permute_x, permute_y = args.permute_y\n        )\n        results = []\n        for kernel_config in kernel_configs:\n            if args.mode == \"forward\":\n                kernel_config_fwd = kernel_config\n                kernel_config_bwd_dW = default_kernel_config_bwd_dW\n                kernel_config_bwd_dX = default_kernel_config_bwd_dX\n            elif args.mode == \"dW\":\n                kernel_config_fwd = default_kernel_config_fwd\n                kernel_config_bwd_dW = kernel_config\n                kernel_config_bwd_dX = default_kernel_config_bwd_dX\n            elif args.mode == \"dX\":\n                kernel_config_fwd = default_kernel_config_fwd\n                kernel_config_bwd_dW = default_kernel_config_bwd_dW\n                kernel_config_bwd_dX = kernel_config\n            else:\n                raise ValueError(f\"Invalid mode: {args.mode}\")\n            print(\n                f\"Benchmarking {model_id} {args.mode} with seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, kernel_config_fwd={kernel_config_fwd}, kernel_config_bwd_dW={kernel_config_bwd_dW}, kernel_config_bwd_dX={kernel_config_bwd_dX}\"\n            )\n\n            ref_time, fused_time = run_benchmark(\n                args.mode,\n                model_config,\n                seqlen = args.seqlen,\n                dtype = args.dtype,\n                permute_x = kernel_config.permute_x,\n                permute_y = kernel_config.permute_y,\n                autotune = False,\n                kernel_config_fwd = kernel_config_fwd,\n                kernel_config_bwd_dW = kernel_config_bwd_dW,\n                kernel_config_bwd_dX = kernel_config_bwd_dX,\n            )\n            results.append(\n                KernelResult(\n                    torch_time = ref_time,\n                
    triton_time = fused_time,\n                    speedup = ref_time / fused_time,\n                    kernel_config = kernel_config,\n                )\n            )\n        df = post_process_results(\n            results, args.mode, args.seqlen, args.dtype, args.autotune\n        )\n        save_results(\n            df, args.results_dir, args.mode, args.seqlen, args.dtype, args.autotune\n        )\n"
  },
  {
    "path": "unsloth/kernels/moe/benchmark/utils.py",
    "content": "import argparse\nimport datetime\nimport json\nimport logging\nimport math\nimport os\nfrom itertools import product\n\nimport pandas as pd\nimport torch\n\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n    KernelResult,\n)\n\nSEED = 42\n\n\ndef create_merged_results(\n    df: pd.DataFrame, mode: str, seqlen: int, dtype: torch.dtype, autotune: bool\n):\n    kernel_result_cols = df.columns.to_list()\n    test_config_dict = {\n        \"mode\": mode,\n        \"seqlen\": seqlen,\n        \"dtype\": dtype,\n        \"autotune\": autotune,\n    }\n    test_config_cols = list(test_config_dict.keys())\n    for col in test_config_cols:\n        df[col] = test_config_dict[col]\n    # Reorder columns so that test config cols are first\n    df = df[test_config_cols + kernel_result_cols]\n    return df\n\n\ndef post_process_results(\n    results: list[KernelResult],\n    mode: str,\n    seqlen: int,\n    dtype: torch.dtype,\n    autotune: bool,\n):\n    df = KernelResult.to_dataframe(results, sort_by = \"speedup\")\n    df = create_merged_results(df, mode, seqlen, dtype, autotune)\n    return df\n\n\ndef save_results(\n    df: pd.DataFrame,\n    results_dir: str,\n    mode: str,\n    seqlen: int,\n    dtype: torch.dtype,\n    autotune: bool,\n):\n    dt = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n    save_dir = f\"{results_dir}/{mode}\"\n    save_path = f\"{save_dir}/{dt}_{seqlen}_{str(dtype).split('.')[-1]}.csv\"\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n    print(f\"Saving results to {save_path}\")\n    df.to_csv(save_path, index = False)\n\n\ndef create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool):\n    block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1])\n    block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1])\n    block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1])\n    num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step = 2)\n    num_stages_range = multiples_of_range(\n        args.num_stages[0], args.num_stages[1], step = 1\n    )\n\n    mode = args.mode\n    kernel_configs = []\n    for (\n        block_m,\n        block_n,\n        block_k,\n        num_warps,\n        num_stages,\n        tma_load_a,\n        tma_load_b,\n    ) in product(\n        block_m_range,\n        block_n_range,\n        block_k_range,\n        num_warps_range,\n        num_stages_range,\n        [True, False],\n        [True, False],\n    ):\n        if mode == \"forward\":\n            kernel_config = KernelConfigForward(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = num_warps,\n                num_stages = num_stages,\n                use_tma_load_w = tma_load_a,\n                use_tma_load_x = tma_load_b,\n                permute_x = permute_x,\n                permute_y = permute_y,\n            )\n        elif mode == \"dW\":\n            kernel_config = KernelConfigBackward_dW(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = num_warps,\n                num_stages = num_stages,\n                use_tma_load_dy = tma_load_a,\n                use_tma_load_x = tma_load_b,\n                permute_x = permute_x,\n                
permute_y = permute_y,\n            )\n        elif mode == \"dX\":\n            kernel_config = KernelConfigBackward_dX(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = num_warps,\n                num_stages = num_stages,\n                use_tma_load_dy = tma_load_a,\n                use_tma_load_w = tma_load_b,\n                permute_x = permute_x,\n                permute_y = permute_y,\n            )\n        else:\n            raise ValueError(f\"Invalid mode: {mode}\")\n        kernel_configs.append(kernel_config)\n\n    logging.info(f\"Pruning {len(kernel_configs)} kernel configs\")\n\n    pruned_configs = []\n    for config in kernel_configs:\n        if mode == \"forward\":\n            if permute_x and config.use_tma_load_x:\n                continue\n        elif mode == \"dW\":\n            if permute_x and config.use_tma_load_x:\n                continue\n            if permute_y and config.use_tma_load_dy:\n                continue\n        elif mode == \"dX\":\n            if permute_y and config.use_tma_load_dy:\n                continue\n        pruned_configs.append(config)\n    logging.info(f\"After pruning, {len(pruned_configs)} kernel configs\")\n\n    return pruned_configs\n\n\ndef power_of_two_range(start, end):\n    start = math.log2(start)\n    end = math.log2(end)\n    return [2**i for i in range(int(start), int(end) + 1)]\n\n\ndef multiples_of_range(start, end, step = 1):\n    return list(range(start, end + step, step))\n\n\ndef map_key_to_args(key, mode):\n    pass\n\n\ndef save_autotune_results(autotune_cache, mode, ref_time, fused_time, results_dir):\n    device_name = torch.cuda.get_device_name().replace(\" \", \"_\")\n    dt = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n    save_dir = f\"{results_dir}/{mode}/autotune/{dt}/{device_name}\"\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n\n    for key, config in autotune_cache.items():\n        key = [\n            str(k) if not \"torch\" in str(k) else str(k.split(\"torch.\")[-1]) for k in key\n        ]\n        filename = \"_\".join(key)\n        save_path = f\"{save_dir}/{filename}.json\"\n        print(f\"Saving autotune results to {save_path}\")\n        with open(save_path, \"w\") as f:\n            result = {\n                **config.all_kwargs(),\n                \"ref_time\": ref_time,\n                \"fused_time\": fused_time,\n            }\n            json.dump(result, f)\n\n\ndef get_autotuner(mode):\n    if mode == \"forward\":\n        from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n\n        return _autotuned_grouped_gemm_forward_kernel\n    elif mode == \"dW\":\n        from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel\n\n        return _autotuned_grouped_gemm_dW_kernel\n    elif mode == \"dX\":\n        from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dX_kernel\n\n        return _autotuned_grouped_gemm_dX_kernel\n    elif mode == \"backward\":\n        from grouped_gemm.kernels.backward import (\n            _autotuned_grouped_gemm_dW_kernel,\n            _autotuned_grouped_gemm_dX_kernel,\n        )\n\n        return _autotuned_grouped_gemm_dW_kernel, _autotuned_grouped_gemm_dX_kernel\n    else:\n        raise ValueError(f\"Invalid mode: {mode}\")\n\n\ndef postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_dir):\n    for key, value in 
autotuner.cache.items():\n        print(f\"{mode} {key}: {value.all_kwargs()}\")\n    save_autotune_results(\n        autotuner.cache,\n        mode = mode,\n        ref_time = ref_time,\n        fused_time = fused_time,\n        results_dir = results_dir,\n    )\n"
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/LICENSE",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  
\"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. 
Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  
This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  
But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  
If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  
If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  
\"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  
This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published\n    by the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>."
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/__init__.py",
    "content": ""
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/interface.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport logging\nimport warnings\nfrom dataclasses import asdict\nfrom unsloth import DEVICE_TYPE\n\nimport torch\nimport triton\n\nfrom .kernels.backward import (\n    _autotuned_grouped_gemm_dW_kernel,\n    _autotuned_grouped_gemm_dX_kernel,\n    _grouped_gemm_dW_kernel,\n    _grouped_gemm_dX_kernel,\n)\nfrom .kernels.forward import (\n    _autotuned_grouped_gemm_forward_kernel,\n    _grouped_gemm_forward_kernel,\n)\nfrom .kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\n\nlogger = logging.getLogger(__name__)\n# Set formatter to include timestamp, pathname and lineno\nformatter = logging.Formatter(\n    \"%(asctime)s::%(levelname)s,%(pathname)s:%(lineno)d:: %(message)s\"\n)\n\n# Add console handler\nch = logging.StreamHandler()\nch.setFormatter(formatter)\nlogger.addHandler(ch)\n\n\n# Precompute TMA support to avoid graph breaks\n# TMA requires both:\n# 1. NVIDIA GPU with capability >= 9 (Hopper+)\n# 2. Triton version with TMA API (make_tensor_descriptor or _experimental_make_tensor_descriptor)\ndef _check_tma_support():\n    if DEVICE_TYPE in (\"xpu\", \"hip\"):\n        return False\n    import triton.language as tl\n\n    gpu_supports_tma = torch.cuda.get_device_capability()[0] >= 9\n    # Check for both old experimental and new stable API names\n    triton_has_tma_api = hasattr(tl, \"make_tensor_descriptor\") or hasattr(\n        tl, \"_experimental_make_tensor_descriptor\"\n    )\n    return gpu_supports_tma and triton_has_tma_api\n\n\n_SUPPORTS_TMA = _check_tma_support()\n\n# Check if triton.set_allocator is available (Triton 3.0+)\n_HAS_SET_ALLOCATOR = hasattr(triton, \"set_allocator\")\n\n\ndef supports_tma():\n    return _SUPPORTS_TMA\n\n\n# Helper to support allow_in_graph\ntry:\n    from torch.compiler import allow_in_graph\nexcept ImportError:\n    from torch._dynamo import allow_in_graph\n\n\n# Helper to detect if we're in tracing/compilation mode\ndef _is_tracing(*tensors):\n    \"\"\"\n    Check if tensors are fake tensors used during torch.compile tracing.\n    During tracing, tensors are FakeTensor/FunctionalTensor and we can't run Triton kernels.\n    During execution, tensors are real Tensors and we MUST run the kernels.\n\n    NOTE: We do NOT use torch.compiler.is_compiling() because it returns True\n    during both tracing AND execution. 
We only want to skip kernels during tracing\n    when tensors are actually fake.\n    \"\"\"\n    for t in tensors:\n        name = type(t).__name__\n        if name in (\"FakeTensor\", \"FunctionalTensor\", \"FunctionalTensorWrapper\"):\n            return True\n    return False\n\n\n_per_device_alloc_fns = {}\n\n# Warn only once when fuse_mul_post is requested (inference-only fusion).\n_FUSED_MUL_WARN = False\n\n\ndef get_per_device_per_stream_alloc_fn(device):\n    if device not in _per_device_alloc_fns:\n        _per_stream_tensors = {}\n\n        def alloc_fn(size: int, alignment: int, stream):\n            assert alignment == 128\n            if (\n                stream not in _per_stream_tensors\n                or _per_stream_tensors[stream].numel() < size\n            ):\n                _per_stream_tensors[stream] = torch.empty(\n                    size, device = device, dtype = torch.int8\n                )\n                _per_stream_tensors[stream].__hibernate__ = {\"type\": \"ignore\"}\n            return _per_stream_tensors[stream]\n\n        _per_device_alloc_fns[device] = alloc_fn\n    return _per_device_alloc_fns[device]\n\n\ndef log_kernel_info(\n    compiled_kernel: triton.compiler.CompiledKernel, best_config: triton.Config = None\n):\n    kernel_name = compiled_kernel.name\n    nregs = compiled_kernel.n_regs\n    nspills = compiled_kernel.n_spills\n    metadata = compiled_kernel.metadata\n    logger.debug(\n        f\"{kernel_name}: n_regs={nregs} n_spills={nspills} metadata={metadata}\"\n    )\n    if best_config is not None:\n        logger.debug(f\"{kernel_name} autotuned best_config: {best_config}\")\n\n\n@allow_in_graph\ndef grouped_gemm_forward(\n    X: torch.Tensor,\n    W: torch.Tensor,\n    topk: int,\n    m_sizes: torch.Tensor,\n    gather_indices: torch.Tensor = None,\n    topk_weights: torch.Tensor = None,\n    # Fusions\n    permute_x: bool = False,\n    permute_y: bool = False,\n    fuse_mul_post: bool = False,\n    # Autotuning - manual kernel params will be ignored if autotune is True\n    autotune: bool = False,\n    # Kernel tuning params if not autotuning -- NOTE: these params need to be tuned, otherwise performance will be poor\n    BLOCK_SIZE_M: int = 32,\n    BLOCK_SIZE_N: int = 32,\n    BLOCK_SIZE_K: int = 32,\n    num_warps: int = 4,\n    num_stages: int = 2,\n    use_tma_load_w: bool = False,\n    use_tma_load_x: bool = False,\n    use_tma_store: bool = False,\n    # software pipelining -- set to True for now, won't impact until loop is re-written\n    flatten: bool = True,\n    # debugging\n    debug: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Grouped GEMM forward pass for MoE MLPs.\n\n    The implementation offers a number of fusions specific to MoE:\n    - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.\n        - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).\n        - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.\n    - `permute_y`: fuse the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.\n    - `fuse_mul_pre`: fuse the multiplication of the routed input with topk_weights, only done in the first grouped GEMM in an MoE MLP as for Llama4.  
Do not use, since it results in a performance regression as it interrupts the GEMM mainloop.\n    - `fuse_mul_post`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.\n\n    X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.\n    W: (E, N, K) expert weights, where E is the number of experts, N is the intermediate (output) dim, and K is the reduction dim\n    m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.\n    gather_indices: (total_tokens,) indices of tokens assigned to each expert.  E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.\n    topk_weights: (total_tokens,) weights to multiply routed output by in expert MLP calculation, used only when `fuse_mul_post` is True (see note on `fuse_mul_post`).\n    use_fast_accum: currently unused; trade off faster accumulation dtype in GEMM for less precision.\n    use_tma_load_x: use TMA for loading activations, incompatible with permute_x.  TODO: add TMA gather / scatter support for Blackwell+.\n    use_tma_load_w: use TMA for loading weights.  If TMA supported, this should always be enabled as it is faster than global memory load.\n    use_tma_store: use TMA for storing output, incompatible with permute_y.  TODO: add TMA scatter support for Blackwell+.\n\n    Returns:\n        y: (total_tokens, N) output of grouped GEMM\n    \"\"\"\n\n    assert X.device.type == \"cuda\", \"X and W must be on CUDA\"\n    assert m_sizes.device.type == \"cuda\", \"m_sizes must be on CUDA\"\n\n    X = X.contiguous()\n    W = W.contiguous()\n    m_sizes = m_sizes.contiguous()\n\n    # Preconditions\n    assert not (permute_x and permute_y), \"Cannot permute both X and Y\"\n    assert not (permute_y and use_tma_store), \"Cannot use both TMA store and permute_y\"\n\n    if use_tma_load_x:\n        # TMA load for activations, TMA gather only supported on Blackwell+\n        assert not permute_x, \"Cannot use both use_tma_load_x and permute_x\"\n\n    use_tma = use_tma_load_w or use_tma_load_x or use_tma_store\n    if not supports_tma() and use_tma:\n        warnings.warn(\"TMA not supported, tma_load will be set to False\")\n        use_tma_load_w = False\n        use_tma_load_x = False\n        use_tma_store = False\n\n    if use_tma or autotune:\n        # Respect global persistent allocator if set\n        if _HAS_SET_ALLOCATOR and not getattr(triton, \"_unsloth_allocator_set\", False):\n\n            def alloc_fn(size: int, alignment: int, stream: int):\n                return torch.empty(size, device = \"cuda\", dtype = torch.int8)\n\n            triton.set_allocator(alloc_fn)\n\n    if W.ndim == 3:\n        num_experts = W.shape[0]\n        N = W.shape[1]\n        # K = W.shape[2]\n    else:\n        num_experts = m_sizes.shape[0]\n        N = W.shape[0] // num_experts\n\n    X = X.view(-1, X.shape[-1])\n    W = W.view(-1, W.shape[-1])\n\n    if permute_x or permute_y:\n        assert (\n            gather_indices is not None\n        ), \"gather_indices must be provided when permute_x or permute_y is True\"\n        assert gather_indices.is_contiguous()\n        assert gather_indices.device.type == \"cuda\"\n        assert gather_indices.ndim == 1\n        total_tokens = gather_indices.shape[0]\n        num_tokens = 
total_tokens // topk\n        if permute_x:\n            assert (\n                X.shape[0] == num_tokens\n            ), f\"X.shape[0] ({X.shape[0]}) must match num_tokens ({num_tokens})\"\n        else:\n            assert (\n                X.shape[0] == total_tokens\n            ), f\"X.shape[0] ({X.shape[0]}) must match total_tokens ({total_tokens})\"\n    else:\n        total_tokens = X.shape[0]\n        num_tokens = total_tokens // topk\n\n    _, K = X.shape\n    assert K == W.shape[1], f\"K ({K}) must match W.shape[1] ({W.shape[1]})\"\n\n    if fuse_mul_post:\n        global _FUSED_MUL_WARN\n        if not _FUSED_MUL_WARN:\n            warnings.warn(\n                \"fused_mul should only be used for inference, not for training\"\n            )\n            _FUSED_MUL_WARN = True\n        assert permute_y, \"FUSE_MUL requires PERMUTE_Y\"\n        assert topk_weights is not None\n        assert topk_weights.numel() == total_tokens\n        assert topk_weights.device.type == \"cuda\"\n        assert topk_weights.is_contiguous()\n        topk_weights = topk_weights.view(-1)\n        if debug:\n            print(\n                f\"DEBUG::GROUPED_GEMM {topk_weights.tolist()} {gather_indices.tolist()}\"\n            )\n\n    y = torch.empty((total_tokens, N), device = X.device, dtype = X.dtype)\n    # if total_tokens == 0 or N == 0:\n    #     return y\n\n    NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n\n    def grid(META):\n        return (NUM_SMS,)\n\n    if not autotune:\n        # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)\n        # BLOCK_SIZE_N = min(N, BLOCK_SIZE_N)\n        pass\n\n    if debug:\n        print(\n            f\"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {permute_x = }\"\n        )\n        print(\n            f\"DEBUG::GROUPED_GEMM {m_sizes.tolist()} {(gather_indices // topk).tolist()}\"\n        )\n\n    kernel_args = {\n        # Inputs\n        \"x_ptr\": X,\n        \"w_ptr\": W,\n        \"m_sizes_ptr\": m_sizes,\n        \"gather_indices_ptr\": gather_indices,\n        \"topk_weights_ptr\": topk_weights,\n        # Output\n        \"y_ptr\": y,\n        # Problem shapes\n        \"NUM_TOKENS\": num_tokens,\n        \"NUM_EXPERTS\": num_experts,\n        \"TOPK\": topk,\n        \"N\": N,\n        \"K\": K,\n        \"NUM_SMS\": NUM_SMS,\n        # Gather / Scatter\n        \"PERMUTE_X\": permute_x,\n        \"PERMUTE_Y\": permute_y,\n        # TopK weight merging\n        \"FUSE_MUL_POST\": fuse_mul_post,\n        # Loop pipelining\n        \"FLATTEN\": flatten,\n    }\n    if not autotune:\n        kernel_args.update(\n            {\n                \"USE_TMA_LOAD_W\": use_tma_load_w,\n                \"USE_TMA_LOAD_X\": use_tma_load_x,\n                \"USE_TMA_STORE\": use_tma_store,\n                \"BLOCK_SIZE_M\": BLOCK_SIZE_M,\n                \"BLOCK_SIZE_N\": BLOCK_SIZE_N,\n                \"BLOCK_SIZE_K\": BLOCK_SIZE_K,\n                \"num_warps\": num_warps,\n                \"num_stages\": num_stages,\n            }\n        )\n\n    kernel = (\n        _autotuned_grouped_gemm_forward_kernel\n        if autotune\n        else _grouped_gemm_forward_kernel\n    )\n\n    is_fake = _is_tracing(X, W)\n    if not is_fake:\n        compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)\n        if autotune:\n            log_kernel_info(compiled_kernel, kernel.best_config)\n        else:\n            
log_kernel_info(compiled_kernel)\n\n    return y\n\n\n@allow_in_graph\ndef grouped_gemm_dX(\n    dY: torch.Tensor,\n    W: torch.Tensor,\n    gather_indices: torch.Tensor,\n    m_sizes: torch.Tensor,\n    topk: int,\n    BLOCK_SIZE_M: int = 32,\n    BLOCK_SIZE_N: int = 32,\n    BLOCK_SIZE_K: int = 32,\n    debug: bool = False,\n    permute_x: bool = False,\n    permute_y: bool = False,\n    use_tma_load_w: bool = False,\n    use_tma_load_dy: bool = False,\n    use_tma_store: bool = False,\n    num_warps: int = 4,\n    num_stages: int = 2,\n    flatten: bool = True,\n    fuse_mul_pre: bool = False,\n    fuse_mul_post: bool = False,\n    autotune: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    dX backward kernel\n    grad_output: (M, N)\n    gather_indices: (total_tokens,), indices of tokens assigned to each expert.  E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.\n    m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.\n    topk: number of experts chosen per token.\n    `permute_x`: whether X was permuted on load in the forward pass, typically only used for the first grouped GEMM in an MoE MLP to group tokens by expert.\n    - In the forward pass, if we permuted X on load, we need to permute on store in the backward pass\n    - Shapes\n        - the forward pass input X shape is [NUM_TOKENS, K], reduce across K, output y is [NUM_TOKENS * TOPK, N]\n        - the backward pass input dy shape is [NUM_TOKENS * TOPK, N], reduce across N, output dX is [NUM_TOKENS * TOPK, K]\n    - Note that in the backward pass, the output size is still [NUM_TOKENS * TOPK, K] since we still need to accumulate gradients for each expert chosen by the token in a post-processing step.\n    `permute_y`: whether the output was permuted on store in the forward pass, typically only used for the second grouped GEMM in an MoE MLP to restore to the original token order.\n    - In the forward pass, if we permuted output on store (e.g., in the second grouped GEMM in fused MoE MLP), we need to permute on load to get from token order to expert grouped order\n    - We still store in contiguous order since we are writing out dX which will be the input to the backwards pass of the first grouped GEMM\n    `fuse_mul_{pre,post}`: always set to False since this should only be used for inference.\n    use_tma_load_dy: use TMA for loading dy. use_tma_load_dy is incompatible with permute_y.  TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_y and use_tma_load_dy.\n    use_tma_load_w: use TMA for loading weights.  If TMA supported, this should always be enabled as it is faster than global memory load.\n    use_tma_store: use TMA for storing dX.  Incompatible with permute_x.  
TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_x and use_tma_store.\n    \"\"\"\n    assert (\n        not fuse_mul_pre\n    ), \"fuse_mul_pre should only be used for inference, not for training\"\n    assert (\n        not fuse_mul_post\n    ), \"fuse_mul_post should only be used for inference, not for training\"\n    assert dY.is_contiguous()\n    assert W.is_contiguous()\n    assert m_sizes.is_contiguous()\n    assert m_sizes.ndim == 1\n\n    # Preconditions\n    assert not (permute_x and permute_y), \"Cannot permute both X and Y\"\n    # Note that this is flipped from the forward pass\n    # If we permuted y in the forward, we need to permute on load in the backward\n    assert not (permute_y and use_tma_load_dy), \"Cannot use both TMA load and permute_y\"\n    assert not (permute_x and use_tma_store), \"Cannot use both TMA store and permute_x\"\n\n    use_tma = use_tma_load_dy or use_tma_load_w or use_tma_store\n    if not supports_tma() and use_tma:\n        warnings.warn(\"TMA not supported, tma_load will be set to False\")\n        use_tma_load_w = False\n        use_tma_load_dy = False\n        use_tma_store = False\n\n    if use_tma or autotune:\n        # Respect global persistent allocator if set\n        if _HAS_SET_ALLOCATOR and not getattr(triton, \"_unsloth_allocator_set\", False):\n\n            def alloc_fn(size: int, alignment: int, stream: int):\n                # print(f\"DEBUG::GROUPED_GEMM alloc_fn {size=} {alignment=} {stream=}\")\n                return torch.empty(size, device = \"cuda\", dtype = torch.int8)\n\n            triton.set_allocator(alloc_fn)\n\n    if W.ndim == 3:\n        num_experts = W.shape[0]\n        N = W.shape[1]\n    else:\n        num_experts = m_sizes.shape[0]\n        N = W.shape[0] // num_experts\n\n    dY = dY.view(-1, dY.shape[-1])\n    W = W.view(-1, W.shape[-1])\n\n    M_total, N_grad = dY.shape\n    N_total, K = W.shape\n    # N = N_total // num_experts\n    assert N_grad == N, f\"Grad_output N ({N_grad}) must match weight N ({N})\"\n\n    assert (\n        M_total % topk == 0\n    ), f\"M_total ({M_total}) must be divisible by topk ({topk})\"\n    num_tokens = M_total // topk\n\n    total_tokens = gather_indices.shape[0]\n    assert (\n        total_tokens == M_total\n    ), f\"Total tokens ({total_tokens}) must match M_total ({M_total})\"\n\n    # Note that the output shape is [NUM_TOKENS * TOPK, K] even when `permute_x` is True since we need to accumulate gradients across all experts chosen by the token.\n    # This will be done in a post-processing step reduction step.\n    output_shape = (total_tokens, K)\n    dX = torch.zeros(output_shape, device = dY.device, dtype = dY.dtype)\n\n    NUM_SMS = torch.cuda.get_device_properties(\n        \"cuda\"\n    ).multi_processor_count  # if not debug else 1\n\n    def grid(META):\n        return (NUM_SMS,)\n\n    if not autotune:\n        # BLOCK_SIZE_N = min(N_grad, BLOCK_SIZE_N)\n        # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)\n        pass\n\n    if debug:\n        print(\n            f\"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {output_shape = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }\"\n        )\n        print(f\"DEBUG::GROUPED_GEMM {m_sizes.tolist()}\")\n\n    kernel_args = {\n        # Inputs\n        \"dY_ptr\": dY,\n        \"w_ptr\": W,\n        \"gather_indices_ptr\": gather_indices,\n        \"m_sizes_ptr\": m_sizes,\n        # Output\n        \"dX_ptr\": dX,\n        # Problem sizes\n  
      \"NUM_EXPERTS\": num_experts,\n        \"NUM_TOKENS\": num_tokens,\n        \"TOPK\": topk,\n        \"N\": N,\n        \"K\": K,\n        \"NUM_SMS\": NUM_SMS,\n        # Gather / Scatter\n        \"PERMUTE_X\": permute_x,\n        \"PERMUTE_Y\": permute_y,\n        \"FLATTEN\": flatten,\n    }\n    if not autotune:\n        kernel_args.update(\n            {\n                \"BLOCK_SIZE_M\": BLOCK_SIZE_M,\n                \"BLOCK_SIZE_N\": BLOCK_SIZE_N,\n                \"BLOCK_SIZE_K\": BLOCK_SIZE_K,\n                \"num_warps\": num_warps,\n                \"num_stages\": num_stages,\n                \"USE_TMA_LOAD_dY\": use_tma_load_dy,\n                \"USE_TMA_LOAD_W\": use_tma_load_w,\n                \"USE_TMA_STORE\": use_tma_store,\n            }\n        )\n    kernel = _autotuned_grouped_gemm_dX_kernel if autotune else _grouped_gemm_dX_kernel\n\n    is_fake = _is_tracing(dY, W)\n    if not is_fake:\n        compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)\n\n        if autotune:\n            log_kernel_info(compiled_kernel, kernel.best_config)\n        else:\n            log_kernel_info(compiled_kernel)\n    return dX\n\n\n@allow_in_graph\ndef grouped_gemm_dW(\n    X: torch.Tensor,\n    dY: torch.Tensor,\n    m_sizes: torch.Tensor,\n    gather_indices: torch.Tensor,\n    topk: int,\n    BLOCK_SIZE_M: int = 32,\n    BLOCK_SIZE_N: int = 32,\n    BLOCK_SIZE_K: int = 32,\n    permute_x: bool = False,\n    permute_y: bool = False,\n    use_tma_load_dy: bool = False,\n    use_tma_load_x: bool = False,\n    use_tma_store: bool = False,\n    fuse_mul_pre: bool = False,\n    fuse_mul_post: bool = False,\n    num_warps: int = 4,\n    num_stages: int = 2,\n    flatten: bool = True,\n    autotune: bool = False,\n    debug: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.\n    dY: (M, N)\n    topk: number of experts to choose per token.\n    m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.\n    gather_indices: (total_tokens,) indices of tokens assigned to each expert.  E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.\n    permute_x: whether X was permuted on load in the forward pass, typically only used for the first grouped GEMM in an MoE MLP to group tokens by expert.\n    - for the first grouped GEMM, we permuted on load -> X was [num_tokens, K] and stored y in expert grouped order [num_tokens * topk, K]\n    - in the backwards pass, we need to permute on load of X while loading dy in contiguous (expert grouped) order\n    - since we are writing out dW, there is no need to permute on store\n    permute_y: whether the output was permuted on store in the forward pass, typically only used for the second grouped GEMM in an MoE MLP to restore to the original token order.\n    - for the second grouped GEMM, we permuted on store -> y was permuted from expert grouped order to token order while X was loaded in expert grouped order since it was the output of the first grouped GEMM\n    - in the backwards pass, we need to permute on load of dy to get from token order to expert grouped order to match the order of X\n    - since we are writing out dW, there is no need to permute on store\n    use_tma_load_dy: use TMA for loading dy. use_tma_load_dy is incompatible with permute_y.  
TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_y and use_tma_load_dy.\n    use_tma_load_x: use TMA for loading x. use_tma_load_x is incompatible with permute_x.  TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_x and use_tma_load_x.\n    use_tma_store: use TMA for storing dW.  If TMA supported, this should always be enabled as it is faster than global memory store.\n    \"\"\"\n    assert not fuse_mul_pre, \"fuse_mul_pre not supported\"\n    assert not fuse_mul_post, \"fuse_mul_post not supported\"\n    NUM_SMS = (\n        torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n        if not debug\n        else 1\n    )\n    X = X.view(-1, X.shape[-1]).contiguous()\n    dY = dY.contiguous()\n    m_sizes = m_sizes.contiguous()\n\n    # Preconditions\n    assert not (permute_x and permute_y), \"Cannot permute both X and Y\"\n    assert not (permute_y and use_tma_load_dy), \"Cannot use both TMA load and permute_y\"\n    assert not (permute_x and use_tma_load_x), \"Cannot use both TMA load and permute_x\"\n\n    use_tma = use_tma_load_dy or use_tma_load_x or use_tma_store\n    if not supports_tma() and use_tma:\n        warnings.warn(\"TMA not supported, tma_load will be set to False\")\n        use_tma_load_x = False\n        use_tma_load_dy = False\n        use_tma_store = False\n\n    if use_tma or autotune:\n        # Respect global persistent allocator if set\n        if _HAS_SET_ALLOCATOR and not getattr(triton, \"_unsloth_allocator_set\", False):\n\n            def alloc_fn(size: int, alignment: int, stream: int):\n                return torch.empty(size, device = \"cuda\", dtype = torch.int8)\n\n            triton.set_allocator(alloc_fn)\n\n    if permute_x or permute_y:\n        assert gather_indices is not None\n        assert gather_indices.is_contiguous()\n        assert gather_indices.device.type == \"cuda\"\n        assert gather_indices.ndim == 1\n        total_tokens = gather_indices.shape[0]\n        num_tokens = total_tokens // topk\n        if permute_x:\n            assert X.shape[0] == num_tokens\n        else:\n            assert X.shape[0] == total_tokens\n    else:\n        total_tokens = X.shape[0]\n        num_tokens = total_tokens // topk\n\n    num_experts = m_sizes.shape[0]\n    # Get dimensions\n    _, K = X.shape\n    M_grad, N = dY.shape\n\n    assert M_grad == total_tokens, f\"dY M ({M_grad}) != total_tokens ({total_tokens})\"\n\n    dW = torch.zeros((num_experts, N, K), device = X.device, dtype = X.dtype)\n\n    if not autotune:\n        # BLOCK_SIZE_N = min(N, BLOCK_SIZE_N)\n        # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)\n        pass\n\n    def grid(META):\n        return (NUM_SMS,)\n\n    if debug:\n        print(\n            f\"DEBUG::GROUPED_GEMM_DW_TMA {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }\"\n        )\n\n        print(f\"DEBUG::GROUPED_GEMM_DW_TMA {m_sizes.tolist() = }\")\n        print(f\"DEBUG::GROUPED_GEMM_DW_TMA {gather_indices.tolist() = }\")\n        m_start = 0\n        for i in range(num_experts):\n            expert_token_idx = gather_indices[m_start : m_start + m_sizes[i]]\n            t_start = 0\n            while t_start < m_sizes[i]:\n                token_idx = expert_token_idx[t_start : t_start + BLOCK_SIZE_M]\n                if permute_x:\n                    token_idx = token_idx // topk\n                print(\n                    f\"DEBUG::GROUPED_GEMM_DW_TMA Token expert {i} indices: 
{token_idx.tolist()}\"\n                )\n                t_start += BLOCK_SIZE_M\n\n            m_start += m_sizes[i]\n\n    kernel_args = {\n        # Inputs\n        \"x_ptr\": X,\n        \"dY_ptr\": dY,\n        \"m_sizes_ptr\": m_sizes,\n        \"gather_indices_ptr\": gather_indices,\n        # Output\n        \"dW_ptr\": dW,\n        # Problem sizes\n        \"NUM_TOKENS\": num_tokens,\n        \"TOPK\": topk,\n        \"NUM_EXPERTS\": num_experts,\n        \"N\": N,\n        \"K\": K,\n        \"NUM_SMS\": NUM_SMS,\n        # Gather / Scatter\n        \"PERMUTE_X\": permute_x,\n        \"PERMUTE_Y\": permute_y,\n        # Loop pipelining\n        \"FLATTEN\": flatten,\n    }\n\n    if not autotune:\n        kernel_args.update(\n            {\n                \"BLOCK_SIZE_M\": BLOCK_SIZE_M,\n                \"BLOCK_SIZE_N\": BLOCK_SIZE_N,\n                \"BLOCK_SIZE_K\": BLOCK_SIZE_K,\n                \"USE_TMA_LOAD_dY\": use_tma_load_dy,\n                \"USE_TMA_LOAD_X\": use_tma_load_x,\n                \"USE_TMA_STORE\": use_tma_store,\n                \"num_warps\": num_warps,\n                \"num_stages\": num_stages,\n            }\n        )\n\n    kernel = _autotuned_grouped_gemm_dW_kernel if autotune else _grouped_gemm_dW_kernel\n\n    is_fake = _is_tracing(X, dY)\n    if not is_fake:\n        compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)\n\n        if autotune:\n            log_kernel_info(compiled_kernel, kernel.best_config)\n        else:\n            log_kernel_info(compiled_kernel)\n\n    return dW\n\n\nclass GroupedGemm(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        X,\n        W,\n        m_sizes,\n        topk,\n        gather_indices,\n        permute_x,\n        permute_y,\n        topk_weights,\n        fuse_mul_post,\n        kernel_config_fwd,\n        kernel_config_bwd_dX,\n        kernel_config_bwd_dW,\n        autotune,\n        dX_only,\n        dW_only,\n    ):\n        ctx.topk = topk\n        ctx.permute_x = permute_x\n        ctx.permute_y = permute_y\n        ctx.fuse_mul_post = fuse_mul_post\n        ctx.kernel_config_fwd = kernel_config_fwd\n        ctx.kernel_config_bwd_dX = kernel_config_bwd_dX\n        ctx.kernel_config_bwd_dW = kernel_config_bwd_dW\n        ctx.autotune = autotune\n        ctx.dX_only = dX_only\n        ctx.dW_only = dW_only\n\n        # NOTE: we don't save topk_weights for backward since we do not support training with fused_mul\n        ctx.save_for_backward(X, W, m_sizes, gather_indices)\n\n        fwd_config = {}\n        if kernel_config_fwd is not None:\n            fwd_config[\"BLOCK_SIZE_M\"] = kernel_config_fwd.BLOCK_SIZE_M\n            fwd_config[\"BLOCK_SIZE_N\"] = kernel_config_fwd.BLOCK_SIZE_N\n            fwd_config[\"BLOCK_SIZE_K\"] = kernel_config_fwd.BLOCK_SIZE_K\n            fwd_config[\"num_warps\"] = kernel_config_fwd.num_warps\n            fwd_config[\"num_stages\"] = kernel_config_fwd.num_stages\n            fwd_config[\"use_tma_load_x\"] = kernel_config_fwd.use_tma_load_x\n            fwd_config[\"use_tma_load_w\"] = kernel_config_fwd.use_tma_load_w\n            fwd_config[\"use_tma_store\"] = kernel_config_fwd.use_tma_store\n\n        return grouped_gemm_forward(\n            X = X,\n            W = W,\n            topk = topk,\n            m_sizes = m_sizes,\n            gather_indices = gather_indices,\n            topk_weights = topk_weights,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            
fuse_mul_post = fuse_mul_post,\n            # Autotune -- this will override the manual kernel config if true\n            autotune = autotune,\n            # Manual kernel config\n            **fwd_config,\n        )\n\n    @staticmethod\n    def backward(ctx, dY):\n        dY = dY.contiguous()\n        X, W, m_sizes, gather_indices = ctx.saved_tensors\n        topk = ctx.topk\n        permute_x = ctx.permute_x\n        permute_y = ctx.permute_y\n        fuse_mul_post = ctx.fuse_mul_post\n        kernel_config_bwd_dX = ctx.kernel_config_bwd_dX\n        kernel_config_bwd_dW = ctx.kernel_config_bwd_dW\n        autotune = ctx.autotune\n        dX_only = ctx.dX_only\n        dW_only = ctx.dW_only\n\n        if not autotune:\n            if not dW_only:\n                assert (\n                    kernel_config_bwd_dX is not None\n                ), \"kernel_config_bwd_dX must be provided if autotune is False\"\n            if not dX_only:\n                assert (\n                    kernel_config_bwd_dW is not None\n                ), \"kernel_config_bwd_dW must be provided if autotune is False\"\n\n        assert (\n            not fuse_mul_post\n        ), \"fused_mul should only be used for inference, not for training\"\n\n        if not dX_only:\n            bwd_dW_config = {}\n\n            if kernel_config_bwd_dW is not None:\n                bwd_dW_config[\"use_tma_load_dy\"] = kernel_config_bwd_dW.use_tma_load_dy\n                bwd_dW_config[\"use_tma_load_x\"] = kernel_config_bwd_dW.use_tma_load_x\n                bwd_dW_config[\"use_tma_store\"] = kernel_config_bwd_dW.use_tma_store\n                bwd_dW_config[\"BLOCK_SIZE_M\"] = kernel_config_bwd_dW.BLOCK_SIZE_M\n                bwd_dW_config[\"BLOCK_SIZE_N\"] = kernel_config_bwd_dW.BLOCK_SIZE_N\n                bwd_dW_config[\"BLOCK_SIZE_K\"] = kernel_config_bwd_dW.BLOCK_SIZE_K\n                bwd_dW_config[\"num_warps\"] = kernel_config_bwd_dW.num_warps\n                bwd_dW_config[\"num_stages\"] = kernel_config_bwd_dW.num_stages\n\n            dW = grouped_gemm_dW(\n                X = X,\n                dY = dY,\n                m_sizes = m_sizes,\n                gather_indices = gather_indices,\n                topk = topk,\n                permute_x = permute_x,\n                permute_y = permute_y,\n                # Autotune -- this will override the manual kernel config if true\n                autotune = autotune,\n                # Manual kernel config\n                **bwd_dW_config,\n            )\n        else:\n            dW = None\n\n        if not dW_only:\n            bwd_dX_config = {}\n            if kernel_config_bwd_dX is not None:\n                bwd_dX_config[\"use_tma_load_dy\"] = kernel_config_bwd_dX.use_tma_load_dy\n                bwd_dX_config[\"use_tma_load_w\"] = kernel_config_bwd_dX.use_tma_load_w\n                bwd_dX_config[\"use_tma_store\"] = kernel_config_bwd_dX.use_tma_store\n                bwd_dX_config[\"BLOCK_SIZE_M\"] = kernel_config_bwd_dX.BLOCK_SIZE_M\n                bwd_dX_config[\"BLOCK_SIZE_N\"] = kernel_config_bwd_dX.BLOCK_SIZE_N\n                bwd_dX_config[\"BLOCK_SIZE_K\"] = kernel_config_bwd_dX.BLOCK_SIZE_K\n                bwd_dX_config[\"num_warps\"] = kernel_config_bwd_dX.num_warps\n                bwd_dX_config[\"num_stages\"] = kernel_config_bwd_dX.num_stages\n\n            dX = grouped_gemm_dX(\n                dY = dY,\n                W = W,\n                m_sizes = m_sizes,\n                gather_indices = gather_indices,\n                topk 
= topk,\n                permute_x = permute_x,\n                permute_y = permute_y,\n                # Autotune -- this will override the manual kernel config if true\n                autotune = autotune,\n                # Manual kernel config\n                **bwd_dX_config,\n            )\n\n            if topk > 1 and permute_x:\n                dX = dX.view(X.shape[0], topk, -1).sum(dim = 1)\n        else:\n            dX = None\n\n        return (\n            dX,\n            dW,\n            None,  # m_sizes\n            None,  # topk\n            None,  # gather_indices\n            None,  # permute_x\n            None,  # permute_y\n            None,  # topk_weights\n            None,  # fuse_mul_post\n            None,  # kernel_config_fwd\n            None,  # kernel_config_bwd_dX\n            None,  # kernel_config_bwd_dW\n            None,  # autotune\n            None,  # dX_only\n            None,  # dW_only\n        )\n\n\ndef check_valid_config_fwd(\n    permute_x,\n    permute_y,\n    use_tma_load_x,\n    use_tma_load_w,\n    use_tma_store,\n    fuse_mul_post,\n    is_first_gemm,\n):\n    \"\"\"\n    Check if the configuration is valid for the forward pass.\n    \"\"\"\n    is_second_gemm = not is_first_gemm\n\n    assert not (permute_x and permute_y), \"Cannot permute both X and Y\"\n    assert not (\n        is_second_gemm and permute_x\n    ), \"Cannot permute X for the second grouped GEMM\"\n    assert not (\n        is_first_gemm and permute_y\n    ), \"Cannot permute Y for the first grouped GEMM\"\n    assert not (\n        fuse_mul_post and is_first_gemm\n    ), \"Cannot fuse mul for the first grouped GEMM\"\n    assert not (\n        use_tma_load_x and permute_x\n    ), \"Cannot use TMA load and permute X unless on sm100+ (Blackwell+)\"\n    assert not (\n        use_tma_store and permute_y and is_second_gemm\n    ), \"Cannot use TMA store and permute Y for the second grouped GEMM unless on sm100+ (Blackwell+)\"\n\n\ndef check_valid_config_bwd_dW(\n    permute_x,\n    permute_y,\n    use_tma_load_dY,\n    use_tma_load_x,\n    use_tma_store,\n    fuse_mul_post,\n    is_first_gemm,\n):\n    \"\"\"\n    Check if the configuration is valid for the backward pass of dW.\n    \"\"\"\n    is_second_gemm = not is_first_gemm\n    if fuse_mul_post:\n        assert False, \"fuse_mul is not supported for the backward pass\"\n    if is_second_gemm and permute_y and use_tma_load_dY:\n        assert False, \"Cannot use TMA load and permute Y for the second grouped GEMM\"\n    if is_first_gemm and permute_x and use_tma_load_x:\n        assert False, \"Cannot use TMA load and permute X for the first grouped GEMM\"\n\n\ndef check_valid_config_bwd_dX(\n    permute_x,\n    permute_y,\n    use_tma_load_dY,\n    use_tma_load_w,\n    use_tma_store,\n    fuse_mul_post,\n    is_first_gemm,\n):\n    \"\"\"\n    Check if the configuration is valid for the backward pass of dX.\n    \"\"\"\n    is_second_gemm = not is_first_gemm\n    if fuse_mul_post:\n        assert False, \"fuse_mul is not supported for the backward pass\"\n    if is_second_gemm and permute_y and use_tma_load_dY:\n        assert False, \"Cannot use TMA load and permute Y for the second grouped GEMM\"\n    if use_tma_store and permute_x and is_first_gemm:\n        assert False, \"Cannot use TMA store and permute X for the first grouped GEMM\"\n\n\ndef grouped_gemm(\n    X: torch.Tensor,\n    W: torch.Tensor,\n    m_sizes: torch.Tensor,\n    topk: int,\n    gather_indices: torch.Tensor = None,\n    permute_x: bool 
= False,\n    permute_y: bool = False,\n    topk_weights = None,\n    fuse_mul_post = False,\n    kernel_config_fwd: KernelConfigForward = None,\n    kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n    kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n    autotune: bool = False,\n    is_first_gemm: bool = True,\n    # Only for debugging\n    dX_only: bool = False,\n    dW_only: bool = False,\n):\n    \"\"\"\n    Grouped GEMM for MoE MLPs.\n\n    The implementation offers a number of fusions specific to MoE:\n    - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.\n        - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).\n        - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.\n    - `permute_y`: fuse the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.\n    - `fuse_mul_post`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.\n\n    X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.\n    W: (E, N, K) expert weights, where E is the number of experts, N is the intermediate (output) dim, and K is the reduction dim\n    m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.\n    gather_indices: (total_tokens,) indices of tokens assigned to each expert.  E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert. Needed when either `permute_x` or `permute_y` is True.\n    topk_weights: (total_tokens,) weights to multiply routed output by in expert MLP calculation, used only when `fuse_mul_post` is True (see note on `fuse_mul_post`).\n    kernel_config_fwd: KernelConfigForward for forward pass.\n    kernel_config_bwd_dX: KernelConfigBackward_dX for backward pass of dX.\n    kernel_config_bwd_dW: KernelConfigBackward_dW for backward pass of dW.\n    autotune: whether to autotune the kernel; if yes, kernel_config_fwd, kernel_config_bwd_dX, and kernel_config_bwd_dW will be ignored.\n    is_first_gemm: whether this is the first grouped GEMM in an MoE MLP.  This is needed to check whether kernel configs are valid.  
`permute_x` should only be used for first gemm; `permute_y` should only be used for second gemm.\n    This will impact whether TMA can be used for loading and storing.\n\n    \"\"\"\n    if not autotune:\n        assert (\n            kernel_config_fwd is not None\n        ), \"kernel_config_fwd must be provided if autotune is False\"\n\n        check_valid_config_fwd(\n            permute_x,\n            permute_y,\n            use_tma_load_x = kernel_config_fwd.use_tma_load_x,\n            use_tma_load_w = kernel_config_fwd.use_tma_load_w,\n            use_tma_store = kernel_config_fwd.use_tma_store,\n            fuse_mul_post = fuse_mul_post,\n            is_first_gemm = is_first_gemm,\n        )\n        if kernel_config_bwd_dW is not None and not dX_only:\n            check_valid_config_bwd_dW(\n                permute_x,\n                permute_y,\n                use_tma_load_dY = kernel_config_bwd_dW.use_tma_load_dy,\n                use_tma_load_x = kernel_config_bwd_dW.use_tma_load_x,\n                use_tma_store = kernel_config_bwd_dW.use_tma_store,\n                fuse_mul_post = fuse_mul_post,\n                is_first_gemm = is_first_gemm,\n            )\n        if kernel_config_bwd_dX is not None and not dW_only:\n            check_valid_config_bwd_dX(\n                permute_x,\n                permute_y,\n                use_tma_load_dY = kernel_config_bwd_dX.use_tma_load_dy,\n                use_tma_load_w = kernel_config_bwd_dX.use_tma_load_w,\n                use_tma_store = kernel_config_bwd_dX.use_tma_store,\n                fuse_mul_post = fuse_mul_post,\n                is_first_gemm = is_first_gemm,\n            )\n\n    if permute_x or permute_y:\n        assert (\n            gather_indices is not None\n        ), \"gather_indices is required when either permute_x or permute_y is True\"\n\n    if fuse_mul_post:\n        assert (\n            topk_weights is not None\n        ), \"topk_weights is required when fuse_mul_post is True\"\n\n    X = X.view(-1, X.shape[-1])\n    m_sizes = m_sizes.view(-1)\n    gather_indices = gather_indices.view(-1)\n\n    return GroupedGemm.apply(\n        X,\n        W,\n        m_sizes,\n        topk,\n        gather_indices,\n        permute_x,\n        permute_y,\n        topk_weights,\n        fuse_mul_post,\n        kernel_config_fwd,\n        kernel_config_bwd_dX,\n        kernel_config_bwd_dW,\n        autotune,\n        dX_only,\n        dW_only,\n    )\n"
  },
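  {
    "path": "unsloth/kernels/moe/grouped_gemm/examples/grouped_gemm_usage_sketch.py",
    "content": "# NOTE: illustrative usage sketch added for documentation; not part of the original source tree.\n# A minimal, hedged example of calling `grouped_gemm` for the first grouped GEMM of an MoE MLP.\n# It assumes a CUDA device, bf16 inputs, and the int32 index convention described in the\n# `grouped_gemm` docstring; the `examples/` path and all shapes below are illustrative assumptions.\n\nimport torch\n\nfrom unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm\nfrom unsloth.kernels.moe.grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\n\n\ndef build_routing_tensors(topk_ids: torch.Tensor, num_experts: int):\n    \"\"\"Build `m_sizes` and `gather_indices` from router top-k expert ids.\n\n    topk_ids: (num_tokens, topk) expert id per token slot, flattened token-major.\n    Returns m_sizes (num_experts,) token counts per expert, and gather_indices\n    (num_tokens * topk,) which sorts flattened token slots into expert-grouped order,\n    so slicing gather_indices by the cumsum of m_sizes recovers each expert's tokens.\n    \"\"\"\n    flat_expert_ids = topk_ids.reshape(-1)\n    m_sizes = torch.bincount(flat_expert_ids, minlength = num_experts).to(torch.int32)\n    # Stable sort keeps token order within each expert group\n    gather_indices = torch.argsort(flat_expert_ids, stable = True).to(torch.int32)\n    return m_sizes, gather_indices\n\n\nif __name__ == \"__main__\":\n    assert torch.cuda.is_available(), \"This sketch requires a CUDA device\"\n    device, dtype = \"cuda\", torch.bfloat16\n    num_tokens, K, N, E, topk = 256, 512, 1024, 8, 2\n\n    X = torch.randn(num_tokens, K, device = device, dtype = dtype)\n    W = torch.randn(E, N, K, device = device, dtype = dtype)\n    router_logits = torch.randn(num_tokens, E, device = device)\n    _, topk_ids = router_logits.topk(topk, dim = -1)\n    m_sizes, gather_indices = build_routing_tensors(topk_ids, E)\n\n    # First grouped GEMM: X is in token order, so fuse the token -> expert permutation on load\n    y = grouped_gemm(\n        X,\n        W,\n        m_sizes,\n        topk,\n        gather_indices = gather_indices,\n        permute_x = True,\n        kernel_config_fwd = KernelConfigForward(permute_x = True),\n        kernel_config_bwd_dX = KernelConfigBackward_dX(permute_x = True),\n        kernel_config_bwd_dW = KernelConfigBackward_dW(permute_x = True),\n        is_first_gemm = True,\n    )\n    print(y.shape)  # expected: (num_tokens * topk, N)\n"
  },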
  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/__init__.py",
    "content": ""
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py",
    "content": "# Unsloth\n# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Affero General Public License as published\n# by the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU Affero General Public License for more details.\n#\n# You should have received a copy of the GNU Affero General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"\nAutotuning utils\n\"\"\"\n\nimport logging\nfrom itertools import product\nfrom typing import List\n\nimport torch\nimport triton\n\nlogger = logging.getLogger(__name__)\n\nDEFAULT_M_BLOCK_SIZES = [64, 128]\nDEFAULT_N_BLOCK_SIZES = [64, 128, 256]\nDEFAULT_K_BLOCK_SIZES = [64, 128, 256]\nDEFAULT_NUM_CTAS = 1\nDEFAULT_NUM_WARPS = [4, 8]\nDEFAULT_NUM_STAGES = [3, 4, 5]\nBOOLS = [True, False]\n\n\ndef val_to_list(val):\n    if val is None:\n        return None\n    elif isinstance(val, list):\n        return val\n    else:\n        return [val]\n\n\ndef convert_args_to_list(args):\n    return [val_to_list(arg) for arg in args]\n\n\ndef _triton_supports_tma():\n    \"\"\"Check if current Triton version supports TMA API.\"\"\"\n    import triton.language as tl\n\n    # Check for both old experimental and new stable API names\n    return hasattr(tl, \"make_tensor_descriptor\") or hasattr(\n        tl, \"_experimental_make_tensor_descriptor\"\n    )\n\n\n# Precompute at module import\n# NOTE: TMA is disabled for now due to compatibility issues with permute_x/permute_y settings\n# in the MoE grouped GEMM forward/backward passes. 
Re-enable once these are resolved.\n_TRITON_HAS_TMA = False  # _triton_supports_tma()\n\n\ndef get_forward_configs(\n    BLOCK_M = DEFAULT_M_BLOCK_SIZES,\n    BLOCK_N = DEFAULT_N_BLOCK_SIZES,\n    BLOCK_K = DEFAULT_K_BLOCK_SIZES,\n    TMA_LOAD_X = None,  # Auto-detect if not specified\n    TMA_LOAD_W = None,  # Auto-detect if not specified\n    TMA_STORE = False,  # NOTE: TMA_STORE is disabled for now\n    num_warps = DEFAULT_NUM_WARPS,\n    num_stages = DEFAULT_NUM_STAGES,\n    num_ctas = DEFAULT_NUM_CTAS,\n):\n    # Auto-detect TMA support\n    if TMA_LOAD_X is None:\n        TMA_LOAD_X = _TRITON_HAS_TMA\n    if TMA_LOAD_W is None:\n        TMA_LOAD_W = _TRITON_HAS_TMA\n\n    (\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        TMA_LOAD_X,\n        TMA_LOAD_W,\n        TMA_STORE,\n        num_warps,\n        num_stages,\n        num_ctas,\n    ) = convert_args_to_list(\n        [\n            BLOCK_M,\n            BLOCK_N,\n            BLOCK_K,\n            TMA_LOAD_X,\n            TMA_LOAD_W,\n            TMA_STORE,\n            num_warps,\n            num_stages,\n            num_ctas,\n        ]\n    )\n    kernel_configs = []\n    for (\n        block_m,\n        block_n,\n        block_k,\n        w,\n        s,\n        tma_load_x,\n        tma_load_w,\n        tma_store,\n        num_ctas,\n    ) in product(\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        num_warps,\n        num_stages,\n        TMA_LOAD_X,\n        TMA_LOAD_W,\n        TMA_STORE,\n        num_ctas,\n    ):\n        kernel_configs.append(\n            triton.Config(\n                dict(\n                    BLOCK_SIZE_M = block_m,\n                    BLOCK_SIZE_N = block_n,\n                    BLOCK_SIZE_K = block_k,\n                    USE_TMA_LOAD_X = tma_load_x,\n                    USE_TMA_LOAD_W = tma_load_w,\n                    USE_TMA_STORE = tma_store,\n                ),\n                num_warps = w,\n                num_stages = s,\n                num_ctas = num_ctas,\n            )\n        )\n\n    return kernel_configs\n\n\ndef get_dX_kernel_configs(\n    BLOCK_M = DEFAULT_M_BLOCK_SIZES,\n    BLOCK_N = DEFAULT_N_BLOCK_SIZES,\n    BLOCK_K = DEFAULT_K_BLOCK_SIZES,\n    TMA_LOAD_dY = None,  # Auto-detect if not specified\n    TMA_LOAD_W = None,  # Auto-detect if not specified\n    TMA_STORE = False,  # NOTE: TMA_STORE is disabled for now\n    num_warps = DEFAULT_NUM_WARPS,\n    num_stages = DEFAULT_NUM_STAGES,\n    num_ctas = DEFAULT_NUM_CTAS,\n):\n    # Auto-detect TMA support\n    if TMA_LOAD_dY is None:\n        TMA_LOAD_dY = _TRITON_HAS_TMA\n    if TMA_LOAD_W is None:\n        TMA_LOAD_W = _TRITON_HAS_TMA\n    (\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        TMA_LOAD_dY,\n        TMA_LOAD_W,\n        TMA_STORE,\n        num_warps,\n        num_stages,\n        num_ctas,\n    ) = convert_args_to_list(\n        [\n            BLOCK_M,\n            BLOCK_N,\n            BLOCK_K,\n            TMA_LOAD_dY,\n            TMA_LOAD_W,\n            TMA_STORE,\n            num_warps,\n            num_stages,\n            num_ctas,\n        ]\n    )\n    kernel_configs = []\n    for (\n        block_m,\n        block_n,\n        block_k,\n        w,\n        s,\n        tma_load_dy,\n        tma_load_w,\n        tma_store,\n        num_ctas,\n    ) in product(\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        num_warps,\n        num_stages,\n        TMA_LOAD_dY,\n        TMA_LOAD_W,\n        TMA_STORE,\n        num_ctas,\n    ):\n        
kernel_configs.append(\n            triton.Config(\n                dict(\n                    BLOCK_SIZE_M = block_m,\n                    BLOCK_SIZE_N = block_n,\n                    BLOCK_SIZE_K = block_k,\n                    USE_TMA_LOAD_dY = tma_load_dy,\n                    USE_TMA_LOAD_W = tma_load_w,\n                    USE_TMA_STORE = tma_store,\n                ),\n                num_warps = w,\n                num_stages = s,\n                num_ctas = num_ctas,\n            )\n        )\n\n    return kernel_configs\n\n\ndef get_dW_kernel_configs(\n    BLOCK_M = DEFAULT_M_BLOCK_SIZES,\n    BLOCK_N = DEFAULT_N_BLOCK_SIZES,\n    BLOCK_K = DEFAULT_K_BLOCK_SIZES,\n    num_warps = DEFAULT_NUM_WARPS,\n    num_stages = DEFAULT_NUM_STAGES,\n    num_ctas = DEFAULT_NUM_CTAS,\n    TMA_LOAD_dY = None,  # Auto-detect if not specified\n    TMA_LOAD_X = None,  # Auto-detect if not specified\n    TMA_STORE = False,\n):\n    # Auto-detect TMA support\n    if TMA_LOAD_dY is None:\n        TMA_LOAD_dY = _TRITON_HAS_TMA\n    if TMA_LOAD_X is None:\n        TMA_LOAD_X = _TRITON_HAS_TMA\n    (\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        num_warps,\n        num_stages,\n        num_ctas,\n        TMA_LOAD_dY,\n        TMA_LOAD_X,\n        TMA_STORE,\n    ) = convert_args_to_list(\n        [\n            BLOCK_M,\n            BLOCK_N,\n            BLOCK_K,\n            num_warps,\n            num_stages,\n            num_ctas,\n            TMA_LOAD_dY,\n            TMA_LOAD_X,\n            TMA_STORE,\n        ]\n    )\n    kernel_configs = []\n    for (\n        block_m,\n        block_n,\n        block_k,\n        w,\n        s,\n        tma_load_dy,\n        tma_load_x,\n        tma_store,\n        num_ctas,\n    ) in product(\n        BLOCK_M,\n        BLOCK_N,\n        BLOCK_K,\n        num_warps,\n        num_stages,\n        TMA_LOAD_dY,\n        TMA_LOAD_X,\n        TMA_STORE,\n        num_ctas,\n    ):\n        kernel_configs.append(\n            triton.Config(\n                dict(\n                    BLOCK_SIZE_M = block_m,\n                    BLOCK_SIZE_N = block_n,\n                    BLOCK_SIZE_K = block_k,\n                    USE_TMA_LOAD_dY = tma_load_dy,\n                    USE_TMA_LOAD_X = tma_load_x,\n                    USE_TMA_STORE = tma_store,\n                ),\n                num_warps = w,\n                num_stages = s,\n                num_ctas = num_ctas,\n            )\n        )\n\n    return kernel_configs\n\n\ndef estimate_smem_reqs(\n    num_stages: int,\n    BLOCK_SIZE_M: int,\n    BLOCK_SIZE_N: int,\n    BLOCK_SIZE_K: int,\n    dtype: torch.dtype,\n):\n    num_bytes = dtype.itemsize\n    return (\n        num_stages * BLOCK_SIZE_K * (BLOCK_SIZE_M + BLOCK_SIZE_N)\n        + BLOCK_SIZE_M * BLOCK_SIZE_N\n    ) * num_bytes\n\n\ndef exceeds_smem_capacity(\n    num_stages: int,\n    BLOCK_SIZE_M: int,\n    BLOCK_SIZE_N: int,\n    BLOCK_SIZE_K: int,\n    dtype: torch.dtype,\n    smem_size: int,\n    slack: float = 50000,\n):\n    smem_reqs = estimate_smem_reqs(\n        num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype\n    )\n    return smem_reqs > smem_size + slack\n\n\ndef common_prune_criteria(config: triton.Config, kwargs: dict, dtype):\n    from ..interface import supports_tma\n    from .tuning import get_device_properties\n\n    smem_size = get_device_properties().SIZE_SMEM\n\n    num_stages = config.num_stages\n    BLOCK_SIZE_M = config.kwargs[\"BLOCK_SIZE_M\"]\n    BLOCK_SIZE_N = config.kwargs[\"BLOCK_SIZE_N\"]\n    
BLOCK_SIZE_K = config.kwargs[\"BLOCK_SIZE_K\"]\n\n    num_tokens = kwargs[\"NUM_TOKENS\"]\n    num_experts = kwargs[\"NUM_EXPERTS\"]\n    permute_x = kwargs[\"PERMUTE_X\"]\n    permute_y = kwargs[\"PERMUTE_Y\"]\n    tokens_per_expert = num_tokens // num_experts\n\n    # use_tma = [k for k in config.kwargs.keys() if k.startswith(\"USE_TMA_\")]\n    MIN_BLOCK_SIZE_M = DEFAULT_M_BLOCK_SIZES[0]\n    if exceeds_smem_capacity(\n        num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype, smem_size\n    ):\n        return True\n    if BLOCK_SIZE_M > tokens_per_expert * 2 and tokens_per_expert > MIN_BLOCK_SIZE_M:\n        return True\n    if permute_x and permute_y:\n        return True\n    # if not supports_tma() and any(use_tma):\n    #     return True\n    return False\n\n\ndef maybe_disable_tma(config: triton.Config):\n    from ..interface import supports_tma\n\n    tma_keys = [k for k in config.kwargs.keys() if k.startswith(\"USE_TMA_\")]\n    if not supports_tma():\n        logger.info(\"Disabling TMA\")\n        for k in tma_keys:\n            config.kwargs[k] = False\n\n\ndef prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs):\n    x = kwargs[\"x_ptr\"]\n    dtype = x.dtype\n\n    logger.debug(f\"Pruning configs: {len(configs)}\")\n\n    pruned_configs = []\n    for config in configs:\n        # disable TMA if gpu does not support it\n        maybe_disable_tma(config)\n\n        if common_prune_criteria(config, kwargs, dtype):\n            continue\n        if config.kwargs[\"USE_TMA_LOAD_X\"] and kwargs[\"PERMUTE_X\"]:\n            # Dynamically disable TMA_LOAD_X for permuted X\n            config.kwargs[\"USE_TMA_LOAD_X\"] = False\n        if config.kwargs[\"USE_TMA_STORE\"] and kwargs[\"PERMUTE_Y\"]:\n            continue\n\n        pruned_configs.append(config)\n\n    logger.debug(f\"Pruned configs: {len(pruned_configs)}\")\n    return pruned_configs\n\n\ndef prune_dX_configs(configs: List[triton.Config], args, **kwargs):\n    dtype = kwargs[\"w_ptr\"].dtype\n\n    logger.debug(f\"Pruning configs: {len(configs)}\")\n    pruned_configs = []\n\n    for config in configs:\n        if common_prune_criteria(config, kwargs, dtype):\n            continue\n        if config.kwargs[\"USE_TMA_LOAD_dY\"] and kwargs[\"PERMUTE_Y\"]:\n            # dynamically disable TMA_LOAD_dY for permuted Y\n            config.kwargs[\"USE_TMA_LOAD_dY\"] = False\n        if config.kwargs[\"USE_TMA_STORE\"] and kwargs[\"PERMUTE_X\"]:\n            continue\n        pruned_configs.append(config)\n\n    logger.debug(f\"Pruned configs: {len(pruned_configs)}\")\n    return pruned_configs\n\n\ndef prune_kernel_configs_backward_dW(configs: list[triton.Config], args, **kwargs):\n    dtype = kwargs[\"x_ptr\"].dtype\n\n    pruned_configs = []\n    logger.debug(f\"Pruning configs: {len(configs)}\")\n\n    for config in configs:\n        if common_prune_criteria(config, kwargs, dtype):\n            continue\n        if config.kwargs[\"USE_TMA_LOAD_dY\"] and kwargs[\"PERMUTE_Y\"]:\n            config.kwargs[\"USE_TMA_LOAD_dY\"] = False\n        if config.kwargs[\"USE_TMA_LOAD_X\"] and kwargs[\"PERMUTE_X\"]:\n            config.kwargs[\"USE_TMA_LOAD_X\"] = False\n        pruned_configs.append(config)\n\n    logger.debug(f\"Pruned configs: {len(pruned_configs)}\")\n    return pruned_configs\n"
  },
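  {
    "path": "unsloth/kernels/moe/grouped_gemm/examples/autotuning_smem_sketch.py",
    "content": "# NOTE: illustrative sketch added for documentation; not part of the original source tree.\n# A worked example of the shared-memory estimate that drives config pruning in\n# kernels/autotuning.py. The ~228 KB SMEM budget below is an assumed Hopper-class figure\n# used only for illustration; at runtime the limit comes from get_device_properties().SIZE_SMEM.\n\nimport torch\n\nfrom unsloth.kernels.moe.grouped_gemm.kernels.autotuning import (\n    estimate_smem_reqs,\n    exceeds_smem_capacity,\n    get_forward_configs,\n)\n\nif __name__ == \"__main__\":\n    # estimate = (num_stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N) * itemsize\n    # For num_stages = 4 and 128 x 128 x 128 tiles in bf16:\n    #   (4 * 128 * (128 + 128) + 128 * 128) * 2 = 294,912 bytes (~288 KiB)\n    est = estimate_smem_reqs(\n        num_stages = 4,\n        BLOCK_SIZE_M = 128,\n        BLOCK_SIZE_N = 128,\n        BLOCK_SIZE_K = 128,\n        dtype = torch.bfloat16,\n    )\n    print(\"estimated SMEM bytes:\", est)\n\n    # With an assumed ~228 KB budget this config is over the limit and would be pruned\n    # (exceeds_smem_capacity also allows a fixed slack term on top of the budget).\n    assumed_smem_size = 228 * 1024\n    print(exceeds_smem_capacity(4, 128, 128, 128, torch.bfloat16, assumed_smem_size))\n\n    # The autotuner enumerates the cartesian product of block sizes / warps / stages\n    # and prunes invalid or oversubscribed candidates at launch time.\n    print(\"forward candidate configs before pruning:\", len(get_forward_configs()))\n"
  },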
  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/backward.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .autotuning import (\n    get_dW_kernel_configs,\n    get_dX_kernel_configs,\n    prune_dX_configs,\n    prune_kernel_configs_backward_dW,\n)\n\n\"\"\"\ndX backward kernel\n\n- Shapes\n    - the forward pass input X shape is [NUM_TOKENS, K] if permute_x else [NUM_TOKENS * TOPK, K]; output y is [NUM_TOKENS * TOPK, N]\n    - the backward pass input dy shape is [NUM_TOKENS * TOPK, N], reduce across N, output dX is [NUM_TOKENS * TOPK, K]\n- Note that in the backward pass, the output size is still [NUM_TOKENS * TOPK, K] since we still need to accumulate gradients for each expert chosen by the token in a post-processing step.\n\n`permute_x` notes:\n- In the forward pass, if we permute X on load, we need to permute on store in the backward pass to restore to original token order\n- the output dX with have shape [NUM_TOKENS * TOPK, K] and we need to perform an additional reduction across topk to accumulate gradients\n- This is done as a post-processing step in autograd.Function.\n- If not `permute_x`, this postprocessing step should take place outside autograd.Function such that the gradient shape matches the input X shape.\n\n`permute_y` notes:\n- In the forward pass, if we permuted output on store (e.g., in the second grouped GEMM in fused MoE MLP), we need to permute on load to get from token order to expert grouped order\n- We still store in contiguous order since we are writing out dX which will be the input to the backwards pass of the first grouped GEMM\n\n`fused_mul` notes:\n- In the forward pass, if we used the multiplication of topk weights (e.g., in the second grouped GEMM in fused MoE MLP), we need to make a few additional changes:\n    1) We load topk_weights in natural (token) order.  Since we only enable `fuse_mul` when permuting on store (`permute_y`), we multiply grad_output by topk_weights before backpropagating\n    2) We need to calculate the gradient of the topk_weights.  This gets messy since we need do an additional elementwise multiplication in the GEMM main loop and then write out in unpermuted order.  For now, we do not fuse this step but calculate as a simple\n\nInvalid combinations:\n- permute_y and use_tma_load: permuting y on store in forward -> load in permuted order in backward, therefore can't use TMA load (unless Blackwell which supports gather / scatter TMA)\n- permute_x and use_tma_store: permuting x on load in forward -> store in permuted order in backward, therefore can't use TMA store (unless Blackwell which supports gather / scatter TMA)\n\nTODO:\n- We define indices for all conditions and expect that unused indices will be DCE'd during compilation.  
Check that this is the case otherwise will result in unnecessary register usage.\n\"\"\"\n\n\n@triton.jit\ndef _grouped_gemm_dX_kernel(\n    dY_ptr,  # [M_total, N]\n    w_ptr,  # [E, N, K]\n    dX_ptr,  # [M_total, K]\n    gather_indices_ptr,\n    m_sizes_ptr,\n    # problem sizes\n    NUM_EXPERTS: tl.constexpr,\n    NUM_TOKENS,\n    TOPK: tl.constexpr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    NUM_SMS,\n    # Tuning parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    PERMUTE_X: tl.constexpr = False,\n    PERMUTE_Y: tl.constexpr = False,\n    USE_TMA_LOAD_W: tl.constexpr = False,\n    USE_TMA_LOAD_dY: tl.constexpr = False,\n    USE_TMA_STORE: tl.constexpr = False,\n    FLATTEN: tl.constexpr = True,\n) -> None:\n    TOTAL_TOKENS = NUM_TOKENS * TOPK\n    output_dtype = dX_ptr.dtype.element_ty\n\n    tidx = tl.program_id(0)\n    # This removes the need for predication along N in the GEMM main loop\n    tl.static_assert(N % BLOCK_SIZE_N == 0, \"N must be divisible by BLOCK_SIZE_N\")\n    tl.static_assert(K % BLOCK_SIZE_K == 0, \"K must be divisible by BLOCK_SIZE_K\")\n\n    # Create TMA descriptors for loading sorted tokens\n    # When using TMA load, we don't permute_x, so shape should be [TOTAL_TOKENS, K]\n    # Also, we are defining a single global descriptor with single block shape\n    # Need to check that this does not result in errors when crossing expert boundaries\n    if USE_TMA_LOAD_dY:\n        dY_desc = tl.make_tensor_descriptor(\n            dY_ptr,\n            shape = [TOTAL_TOKENS, N],\n            strides = [N, 1],\n            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],\n        )\n\n    if USE_TMA_LOAD_W:\n        expert_stride = N * K\n        w_desc = tl.make_tensor_descriptor(\n            w_ptr,\n            shape = [NUM_EXPERTS, N, K],\n            strides = [expert_stride, K, 1],\n            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],\n        )\n\n    m_end = 0\n    processed_tiles = 0\n    m_block_range = tl.arange(0, BLOCK_SIZE_M)\n    n_block_range = tl.arange(0, BLOCK_SIZE_N)\n    k_block_range = tl.arange(0, BLOCK_SIZE_K)\n\n    for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):\n        m_start = m_end\n        m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)\n        m_end = m_start + m_size\n\n        if m_size > 0:\n            # Advance n offset to the weights for that respective expert\n            n_start = expert_idx * N\n            # N_start_offset = g.to(tl.int64) * N\n            # tiles for this group's GEMM\n            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)\n            num_tiles_per_expert = num_m_tiles * num_k_tiles\n\n            if USE_TMA_STORE:\n                # Need to define descript within loop to predicate store along M\n                tl.static_assert(\n                    K % BLOCK_SIZE_K == 0, \"K must be divisible by BLOCK_SIZE_K\"\n                )\n                dX_desc = tl.make_tensor_descriptor(\n                    dX_ptr,\n                    shape = [m_end, K],\n                    strides = [K, 1],\n                    block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                )\n\n            # Lower bound and upper bound are defined relative to the total tiles processed so far\n            # This ensures that we are only processing tiles for the current expert group AND\n            # we never exceed the total number of tiles for all expert groups\n            while tidx >= 
processed_tiles and tidx < (\n                processed_tiles + num_tiles_per_expert\n            ):\n                group_index = tidx - processed_tiles\n\n                # Output tile for this thread block for this expert group\n                tile_m_idx = group_index % num_m_tiles\n                tile_k_idx = group_index // num_m_tiles\n\n                if PERMUTE_X or PERMUTE_Y:\n                    # These will be used for loading and storing in permuted order\n                    gather_offsets = tile_m_idx * BLOCK_SIZE_M + m_block_range\n                    # indices_to_gather = m_start + gather_offsets\n                    indices_to_gather = m_start + tl.max_contiguous(\n                        tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),\n                        BLOCK_SIZE_M,\n                    )\n                    expert_token_idx = tl.load(\n                        gather_indices_ptr + indices_to_gather,\n                        mask = indices_to_gather < TOTAL_TOKENS,\n                    )\n                    expert_token_offsets = expert_token_idx[:, None]\n\n                    # Masks for permuted load and store\n                    row_mask = gather_offsets < m_size\n                    row_mask = row_mask[:, None]\n\n                    # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)\n                    # Hence, we can make the following simplifying assumptions when loading and storing\n                    # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted\n\n                    if PERMUTE_X:\n                        # Case where we permuted on load in the forward pass (typically first grouped GEMM in MoE MLP)\n                        load_a_idx = (\n                            indices_to_gather[:, None] * N\n                        )  # Load in contiguous (expert grouped) order\n                        store_idx = (\n                            expert_token_offsets * K\n                        )  # Permute on store from expert -> token order\n                    else:\n                        # Case where we permuted on store in the forward pass (typically second grouped GEMM in MoE MLP)\n                        load_a_idx = (\n                            expert_token_offsets * N\n                        )  # Permute on load from token -> expert order\n                        store_idx = (\n                            indices_to_gather[:, None] * K\n                        )  # Store in contiguous order\n                else:\n                    # # Position in full matrix - needed for TMA\n                    # m_offset = (M_start + (tile_m_idx * BLOCK_SIZE_M)).to(tl.int32)\n                    # k_offset = (tile_k_idx * BLOCK_SIZE_K).to(tl.int32)\n                    # Offsets *relative* to the *current* expert -- m_start will then advance to this expert's start token\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + m_block_range\n\n                    # [M, N] @ [N, K] -> [M, K] => Stride for A is N, stride for B is K\n                    # We need two additional offsets:\n                    # 1. For A, m_start to advance to this expert's start token\n                    # 2. 
For B, n_start to advance to this expert's weights since we are passing in an [E, N, K] weight matrix\n                    row_offsets_a = m_start + offs_am[:, None]\n                    load_a_idx = row_offsets_a * N\n                    store_idx = row_offsets_a * K\n                    row_mask = offs_am[:, None] < m_size\n\n                if not USE_TMA_LOAD_dY:\n                    dY_ptrs = dY_ptr + load_a_idx + n_block_range[None, :]\n\n                offs_bk = tile_k_idx * BLOCK_SIZE_K + k_block_range\n                if not USE_TMA_LOAD_W:\n                    row_offsets_b = n_start + n_block_range\n                    # offs_bn = n_start + n_block_range\n                    # row_offsets_b = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n                    w_ptrs = w_ptr + row_offsets_b[:, None] * K + offs_bk[None, :]\n\n                # TODO: check whether predication along K is needed since we checked that K is divisible by BLOCK_SIZE_K in the forward kernel\n                # col_mask = offs_bk[None, :] < K\n                store_mask = row_mask  # & col_mask\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)\n\n                # GEMM main loop\n                for n_offset in range(0, N, BLOCK_SIZE_N):\n                    # dY block [M, N]\n                    if not USE_TMA_LOAD_dY:\n                        dY = tl.load(dY_ptrs, mask = row_mask)\n                    else:\n                        dY = dY_desc.load(\n                            [m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]\n                        )\n\n                    if not USE_TMA_LOAD_W:\n                        w = tl.load(w_ptrs)  # , mask=col_mask)\n                    else:\n                        w = w_desc.load(\n                            [expert_idx, n_offset, tile_k_idx * BLOCK_SIZE_K]\n                        )\n                        w = tl.reshape(w, (BLOCK_SIZE_N, BLOCK_SIZE_K))\n                    # TODO: check if predication along K is needed since we checked that K is divisible by BLOCK_SIZE_K in the forward kernel\n\n                    # [M, N] @ [N, K] -> [M, K]\n                    dY = dY.to(w.dtype)\n                    accumulator += tl.dot(dY, w)  # NOTE: no transpose of b\n\n                    # Advance A along contiguous dimension\n                    if not USE_TMA_LOAD_dY:\n                        dY_ptrs += BLOCK_SIZE_N\n                    # Note we are no longer advancing B along contiguous dimension since weights are arranged as [N, K]\n                    # Instead, we need to stride by K to advance to the [N_BLOCK_SIZE, K_BLOCK_SIZE] tile\n                    if not USE_TMA_LOAD_W:\n                        w_ptrs += BLOCK_SIZE_N * K\n\n                dX = accumulator.to(output_dtype)\n\n                # Writing out a BLOCK_M x BLOCK_K tile, so we need to stride by K\n                if USE_TMA_STORE:\n                    offset_m = tile_m_idx * BLOCK_SIZE_M  # .to(tl.int32)\n                    offset_k = tile_k_idx * BLOCK_SIZE_K  # .to(tl.int32)\n                    dX_desc.store([m_start + offset_m, offset_k], dX)\n                else:\n                    tl.store(\n                        dX_ptr + store_idx + offs_bk[None, :],\n                        dX,\n                        mask = store_mask,\n                    )\n\n                # Move to the next tile within this expert group\n                tidx += NUM_SMS\n\n            # Update the total tiles count for the next 
expert group\n            processed_tiles += num_tiles_per_expert\n\n\n_autotuned_grouped_gemm_dX_kernel = triton.autotune(\n    configs = get_dX_kernel_configs(),\n    prune_configs_by = {\"early_config_prune\": prune_dX_configs},\n    # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length\n    key = [\"NUM_EXPERTS\", \"N\", \"K\", \"PERMUTE_X\", \"PERMUTE_Y\"],\n)(_grouped_gemm_dX_kernel)\n\n\"\"\"\nnotes on permute_x:\n- for the first grouped GEMM, we permuted on load -> X was [num_tokens, K] and stored y in expert grouped order [num_tokens * topk, K]\n- in the backwards pass, we need to permute on load of X while loading dy in contiguous (expert grouped) order\n- since we are writing out dW, there is no need to permute on store\n\nnotes on permute_y:\n- for the second grouped GEMM, we permuted on store -> y was permuted from expert grouped order to token order, x was loaded in expert grouped order since it was the output of the first grouped GEMM\n- in the backwards pass, we need to permute on load of dy to get from token order to expert grouped order to match the order of X\n- since we are writing out dW, there is no need to permute on store\n\nnotes on TMA loading:\n- if we're TMA loading both X and dY, then we need to mask along the M dimension\nto account for expert boundaries\n- we can either\n    - define TMA descriptors within the outer for loop to predicate loads\n    or\n    - mask along M after loading\n\"\"\"\n\n\n@triton.jit\ndef _grouped_gemm_dW_kernel(\n    x_ptr,\n    dY_ptr,\n    dW_ptr,\n    m_sizes_ptr,\n    gather_indices_ptr,\n    # problem sizes\n    NUM_TOKENS,\n    TOPK: tl.constexpr,\n    NUM_EXPERTS: tl.constexpr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    NUM_SMS,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    PERMUTE_X: tl.constexpr = False,\n    PERMUTE_Y: tl.constexpr = False,\n    USE_TMA_LOAD_dY: tl.constexpr = False,\n    USE_TMA_LOAD_X: tl.constexpr = False,\n    USE_TMA_STORE: tl.constexpr = False,\n    FLATTEN: tl.constexpr = True,\n    acc_dtype: tl.constexpr = tl.float32,\n) -> None:\n    TOTAL_TOKENS = NUM_TOKENS * TOPK\n    TMA_LOAD_BOTH: tl.constexpr = USE_TMA_LOAD_X and USE_TMA_LOAD_dY\n\n    tidx = tl.program_id(0)\n    output_dtype = dW_ptr.dtype.element_ty\n\n    if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:\n        dY_desc = tl.make_tensor_descriptor(\n            dY_ptr,\n            shape = [TOTAL_TOKENS, N],\n            strides = [N, 1],\n            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],\n        )\n\n    if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:\n        x_desc = tl.make_tensor_descriptor(\n            x_ptr,\n            shape = [TOTAL_TOKENS, K],\n            strides = [K, 1],\n            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],\n        )\n    # Output tiles per expert, since each expert weight matrix is [N, K]\n    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)\n    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)\n    output_tiles_per_expert = num_n_tiles * num_k_tiles\n\n    block_range_m = tl.arange(0, BLOCK_SIZE_M)\n    block_range_n = tl.arange(0, BLOCK_SIZE_N)\n    block_range_k = tl.arange(0, BLOCK_SIZE_K)\n\n    # NOTE: Important that N % BLOCK_SIZE_N == 0 and K % BLOCK_SIZE_K == 0 when using TMA store\n    if USE_TMA_STORE:\n        tl.static_assert(N % BLOCK_SIZE_N == 0, \"N must be divisible by BLOCK_SIZE_N\")\n        tl.static_assert(K % BLOCK_SIZE_K == 0, \"K must be divisible by BLOCK_SIZE_K\")\n        dW_desc = 
tl.make_tensor_descriptor(\n            dW_ptr,\n            shape = [NUM_EXPERTS, N, K],\n            strides = [N * K, K, 1],\n            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],\n        )\n\n    for tile_idx in range(\n        tidx, output_tiles_per_expert, NUM_SMS\n    ):  # , flatten=FLATTEN):\n        # Output tile index\n        tile_n_idx = tile_idx % num_n_tiles\n        tile_k_idx = tile_idx // num_n_tiles\n\n        # Output tile offsets\n        n_offset = tile_n_idx * BLOCK_SIZE_N\n        k_offset = tile_k_idx * BLOCK_SIZE_K\n\n        # For storing\n        # TODO: Check whether the k mask is needed since we statically check that K is divisible by BLOCK_SIZE_K in the forward kernel\n        # ditto for n_mask\n        n_mask = block_range_n + n_offset < N\n        k_mask = block_range_k + k_offset < K\n        nk_mask = n_mask[:, None] & k_mask[None, :]\n\n        m_end = 0\n        for expert_idx in range(NUM_EXPERTS):\n            # We need to instantiate a fresh accumulator for each expert\n            accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)\n\n            m_start = m_end\n            # Need to figure out why this cast is needed, otherwise compiler complains about mismatching types\n            m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)\n            m_end = m_start + m_size\n\n            # NOTE: when storing the result, we need to offset by n_start since we are storing the result for this expert to the global [E, N, K] weight matrix\n            n_start = expert_idx * N\n            store_row_offs = n_start + n_offset + block_range_n\n\n            if m_size > 0:\n                if TMA_LOAD_BOTH:\n                    dY_desc = tl.make_tensor_descriptor(\n                        dY_ptr,\n                        shape = [m_end, N],\n                        strides = [N, 1],\n                        block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],\n                    )\n\n                    x_desc = tl.make_tensor_descriptor(\n                        x_ptr,\n                        shape = [m_end, K],\n                        strides = [K, 1],\n                        block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                    )\n\n                for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):\n                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - tile_m_idx)\n\n                    if m_block_size > 0:\n                        # Global offset for this chunk\n                        m_global_offset = m_start + tile_m_idx\n                        m_offsets = m_global_offset + block_range_m\n\n                        if PERMUTE_X or PERMUTE_Y:\n                            # These will be used for loading and storing in permuted order\n                            gather_offsets = (\n                                tile_m_idx + block_range_m\n                            )  # NOTE: tile_m_idx is already strided by BLOCK_SIZE_M\n\n                            indices_to_gather = m_start + tl.max_contiguous(\n                                tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),\n                                BLOCK_SIZE_M,\n                            )\n                            # indices_to_gather = m_start + gather_offsets\n                            expert_token_idx = tl.load(\n                                gather_indices_ptr + indices_to_gather,\n                                mask = indices_to_gather < TOTAL_TOKENS,\n                            )\n                            
expert_token_offsets = expert_token_idx[:, None]\n\n                            # Masks for permuted load and store\n                            row_load_mask = gather_offsets < m_size\n\n                            # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)\n                            # Hence, we can make the following simplifying assumptions when loading and storing\n                            # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted\n                            if PERMUTE_X:\n                                x_row_load_idx = (\n                                    (expert_token_offsets // TOPK) * K\n                                )  # Permute on load from token -> expert order, divide by TOPK to index from original number of tokens\n                                dY_row_load_idx = m_offsets[:, None] * N\n                            else:\n                                x_row_load_idx = (\n                                    indices_to_gather[:, None] * K\n                                )  # Load in contiguous order (no permutation on load)\n                                dY_row_load_idx = expert_token_offsets * N\n\n                        else:\n                            x_row_load_idx = m_offsets[:, None] * K\n                            dY_row_load_idx = m_offsets[:, None] * N\n                            row_load_mask = block_range_m < m_block_size\n\n                        mk_mask = row_load_mask[:, None] & k_mask[None, :]\n                        mn_mask = row_load_mask[:, None] & n_mask[None, :]\n\n                        if USE_TMA_LOAD_X:\n                            x = x_desc.load([m_global_offset, k_offset])\n                        else:\n                            x = tl.load(\n                                x_ptr\n                                + x_row_load_idx\n                                + (k_offset + block_range_k)[None, :],\n                                mask = mk_mask,\n                            )\n\n                        if USE_TMA_LOAD_dY:\n                            dY = dY_desc.load([m_global_offset, n_offset])\n                        else:\n                            dY = tl.load(\n                                dY_ptr\n                                + dY_row_load_idx\n                                + (n_offset + block_range_n)[None, :],\n                                mask = mn_mask,\n                            )\n\n                        accumulator += tl.dot(\n                            dY.T.to(x.dtype),  # [BLOCK_N, BLOCK_M]\n                            x,  # [BLOCK_M, BLOCK_K]\n                        )\n\n                y = accumulator.to(output_dtype)\n                if USE_TMA_STORE:\n                    # Need to expand dims to match [E, N, K] shape\n                    y = tl.expand_dims(y, 0)\n                    dW_desc.store([expert_idx, n_offset, k_offset], y)\n                else:\n                    tl.store(\n                        dW_ptr\n                        # + (n_offset + offs_n)[:, None] * K\n                        + store_row_offs[:, None] * K\n                        + (k_offset + block_range_k)[None, :],\n                        y,\n                        mask = nk_mask,\n                    )\n\n\n_autotuned_grouped_gemm_dW_kernel = triton.autotune(\n    configs = get_dW_kernel_configs(),\n    prune_configs_by = 
{\"early_config_prune\": prune_kernel_configs_backward_dW},\n    # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length\n    key = [\"NUM_EXPERTS\", \"N\", \"K\", \"PERMUTE_X\", \"PERMUTE_Y\"],\n)(_grouped_gemm_dW_kernel)\n"
  },
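  {
    "path": "unsloth/kernels/moe/grouped_gemm/examples/backward_topk_reduction_sketch.py",
    "content": "# NOTE: illustrative sketch added for documentation; not part of the original source tree.\n# A pure-PyTorch picture of the post-processing step described in the `permute_x` notes of\n# kernels/backward.py: when dX is produced per (token, expert) slot with shape\n# [num_tokens * topk, K], the gradient w.r.t. the original X is recovered by summing over the\n# topk axis, exactly like `dX.view(X.shape[0], topk, -1).sum(dim = 1)` in the autograd.Function.\n# The toy per-slot computation below is an assumption chosen only to make the reduction concrete.\n\nimport torch\n\nnum_tokens, topk, K = 4, 2, 8\nX = torch.randn(num_tokens, K, dtype = torch.float64, requires_grad = True)\n\n# Each token is dispatched to `topk` experts; slot i corresponds to token i // topk,\n# mirroring the token-major flattening used by the grouped GEMM interface.\nslot_to_token = torch.arange(num_tokens * topk) // topk\nper_slot_scale = torch.randn(num_tokens * topk, 1, dtype = torch.float64)\n\n# Toy per-slot computation standing in for \"each expert consumes its copy of the token\"\ny = X[slot_to_token] * per_slot_scale  # [num_tokens * topk, K]\ny.sum().backward()\n\n# Kernel-side picture: a per-slot gradient of shape [num_tokens * topk, K] ...\ndX_per_slot = per_slot_scale.expand(-1, K).contiguous()\n# ... reduced across topk to accumulate the gradient for each original token\ndX = dX_per_slot.view(num_tokens, topk, K).sum(dim = 1)\n\ntorch.testing.assert_close(dX, X.grad)\nprint(\"topk reduction matches autograd:\", torch.allclose(dX, X.grad))\n"
  },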
  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/forward.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .autotuning import (\n    get_forward_configs,\n    prune_kernel_configs_fwd,\n)\n\n\n#\n# PERMUTE_X -> permute tokens so that they are ordered by expert\n# PERMUTE_Y -> permute output so that they are ordered by token\n# These are effectively the same thing: the former loads in permuted order, the latter stores in permuted order => we only need to define the permutation indices once\n# In the former, we use these row indices when loading X\n# For the latter, we use these row indices when storing Y\n# FUSE_MUL -> multiply routed outputs by their respective weights\n# topk_weights are in token order\n# Only account for the case when X is in expert order and we are permuting Y when fusing mul -- this precondition is checked in the interface\n@triton.jit\ndef _grouped_gemm_forward_kernel(\n    x_ptr,\n    w_ptr,\n    y_ptr,\n    # Variable depending on routed probs\n    m_sizes_ptr,\n    gather_indices_ptr,\n    topk_weights_ptr,\n    # Constant problem shapes\n    NUM_EXPERTS: tl.constexpr,\n    NUM_TOKENS,\n    TOPK: tl.constexpr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    NUM_SMS,\n    # Tuning params\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    PERMUTE_X: tl.constexpr = False,\n    PERMUTE_Y: tl.constexpr = False,\n    FUSE_MUL_PRE: tl.constexpr = False,\n    FUSE_MUL_POST: tl.constexpr = False,\n    USE_FAST_ACCUM: tl.constexpr = False,\n    USE_TMA_LOAD_W: tl.constexpr = False,\n    USE_TMA_LOAD_X: tl.constexpr = False,\n    USE_TMA_STORE: tl.constexpr = False,\n    acc_dtype: tl.constexpr = tl.float32,\n    FLATTEN: tl.constexpr = True,\n) -> None:\n    tl.static_assert(K % BLOCK_SIZE_K == 0)\n\n    TOTAL_TOKENS = NUM_TOKENS * TOPK\n    SHOULD_PERMUTE: tl.constexpr = PERMUTE_X or PERMUTE_Y\n    SHOULD_FUSE_MUL: tl.constexpr = FUSE_MUL_PRE or FUSE_MUL_POST\n    SHOULD_PERMUTE_OR_FUSE: tl.constexpr = SHOULD_PERMUTE or SHOULD_FUSE_MUL\n    # tl.static_print(\"SHOULD_PERMUTE\", PERMUTE_X, PERMUTE_Y, FUSE_MUL_PRE, FUSE_MUL_POST, SHOULD_PERMUTE, SHOULD_FUSE, SHOULD_PERMUTE_OR_FUSE)\n    tidx = tl.program_id(0)\n    output_dtype: tl.dtype = y_ptr.dtype.element_ty\n\n    # Create TMA descriptors for loading sorted tokens\n    # When using TMA load, we don't permute_x, so shape should be [TOTAL_TOKENS, K]\n    # Also, we are defining a single global descriptor with single block shape\n    # Need to check that this does not result in errors when crossing expert boundaries\n    if USE_TMA_LOAD_X:\n        x_desc = tl.make_tensor_descriptor(\n            x_ptr,\n            shape = [TOTAL_TOKENS, K],\n            strides = [K, 1],\n            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],\n        )\n\n    if USE_TMA_LOAD_W:\n        expert_stride = N * K\n        w_desc = tl.make_tensor_descriptor(\n            w_ptr,\n            shape = [NUM_EXPERTS, N, K],\n            strides = [expert_stride, K, 1],\n            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],\n        )\n\n    m_end = 0\n    processed_tiles = 0\n    m_block_range = tl.arange(0, BLOCK_SIZE_M)\n\n    for expert_idx in tl.range(NUM_EXPERTS, flatten = FLATTEN):\n        m_start = m_end\n        m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)\n        m_end = m_start + m_size\n\n        if m_size > 0:\n            n_start = expert_idx * N\n\n        
    num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)\n            num_tiles_per_expert = num_m_tiles * num_n_tiles\n\n            # Need to create tma_store within loop since we need to predicate stores based on m_size\n            if USE_TMA_STORE:\n                y_desc = tl.make_tensor_descriptor(\n                    y_ptr,  # + m_start * N,\n                    shape = [m_end, N],\n                    strides = [N, 1],\n                    block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],\n                )\n\n            # Process tiles for this expert\n            while (\n                tidx >= processed_tiles\n                and tidx < processed_tiles + num_tiles_per_expert\n            ):\n                tile_idx = tidx - processed_tiles\n\n                # Check if L2 cache re-use for this order is optimal\n                tile_m_idx = tile_idx % num_m_tiles\n                tile_n_idx = tile_idx // num_m_tiles\n\n                if SHOULD_PERMUTE_OR_FUSE:\n                    # These will be used for loading and storing in permuted order\n                    gather_offsets = tile_m_idx * BLOCK_SIZE_M + m_block_range\n                    indices_to_gather = m_start + tl.max_contiguous(\n                        tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),\n                        BLOCK_SIZE_M,\n                    )\n                    expert_token_idx = tl.load(\n                        gather_indices_ptr + indices_to_gather,\n                        mask = indices_to_gather < TOTAL_TOKENS,\n                    )\n                    expert_token_offsets = expert_token_idx[:, None]\n\n                    # Masks for permuted load and store\n\n                    row_mask = gather_offsets < m_size\n                    row_mask = row_mask[:, None]\n\n                    # row_mask = indices_to_gather < m_end\n                    # row_mask = row_mask[:, None]\n\n                # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)\n                # Hence, we can make the following simplifying assumptions when loading and storing\n                # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted\n                if PERMUTE_X:\n                    load_idx = (\n                        (expert_token_offsets // TOPK) * K\n                    )  # Permute on load from token -> expert order, divide by TOPK to index from original number of tokens\n                    store_idx = (\n                        indices_to_gather[:, None] * N\n                    )  # Store in contiguous order\n                else:\n                    off_am = tile_m_idx * BLOCK_SIZE_M\n                    if not PERMUTE_Y:\n                        # These will already be computed if permuting y\n                        offs_am = off_am + m_block_range\n                        row_mask = offs_am[:, None] < m_size\n                        row_idx = m_start + offs_am[:, None]\n                        store_idx = row_idx * N\n                        if not USE_TMA_LOAD_X:\n                            load_idx = row_idx * K\n\n                if PERMUTE_Y:\n                    if not USE_TMA_LOAD_X:\n                        load_idx = (\n                            indices_to_gather[:, None] * K\n                        )  # Load in contiguous order (no permutation on load)\n                    # offs_am = 
off_am + m_block_range\n                    # row_mask = offs_am[:, None] < m_size\n                    store_idx = (\n                        expert_token_offsets * N\n                    )  # Permute on store from expert -> token order\n\n                # We always load topk weights in expert order\n                # In the pre-multiplication case, we multiply permuted hidden states by weights before the first gemm\n                # In the post-multiplication case, we multiply permuted hidden states by weights after the second gemm\n                # In either case, the hidden states are grouped by expert, so we always permute on load of topk weights\n                if SHOULD_FUSE_MUL:\n                    topk_load_idx = expert_token_offsets\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = acc_dtype)\n\n                offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n                if not USE_TMA_LOAD_X:\n                    x_ptrs = x_ptr + load_idx + offs_k[None, :]\n\n                if not USE_TMA_LOAD_W:\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    offs_bn = tl.max_contiguous(\n                        tl.multiple_of(offs_bn % N, BLOCK_SIZE_N), BLOCK_SIZE_N\n                    )\n                    w_ptrs = w_ptr + (n_start + offs_bn[:, None]) * K + offs_k[None, :]\n\n                for k_offset in range(0, K, BLOCK_SIZE_K):\n                    if not USE_TMA_LOAD_X:\n                        x = tl.load(x_ptrs, mask = row_mask)\n                    else:\n                        x = x_desc.load([m_start + off_am, k_offset])\n\n                    if FUSE_MUL_PRE:\n                        # Check for correct broadcasting\n                        topk_weights = tl.load(\n                            topk_weights_ptr + topk_load_idx, mask = row_mask\n                        )\n                        x *= topk_weights.to(x.dtype)\n\n                    if not USE_TMA_LOAD_W:\n                        w = tl.load(w_ptrs, mask = offs_bn[:, None] < N)\n                    else:\n                        w = w_desc.load(\n                            [expert_idx, tile_n_idx * BLOCK_SIZE_N, k_offset]\n                        )\n                        w = tl.reshape(w, (BLOCK_SIZE_N, BLOCK_SIZE_K))\n\n                    x = x.to(w.dtype)\n                    accumulator += tl.dot(x, w.T)\n\n                    if not USE_TMA_LOAD_X:\n                        x_ptrs += BLOCK_SIZE_K\n\n                    if not USE_TMA_LOAD_W:\n                        w_ptrs += BLOCK_SIZE_K\n\n                y = accumulator.to(output_dtype)\n\n                # NOTE: order of fusing multiplication is important\n                # Fusing before accumulator dtype conversion results in numerical diffs\n                if FUSE_MUL_POST:\n                    # Check for correct broadcasting\n                    topk_weights = tl.load(\n                        topk_weights_ptr + topk_load_idx, mask = row_mask\n                    )\n                    y *= topk_weights.to(output_dtype)\n\n                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                store_mask = row_mask & (offs_bn[None, :] < N)\n\n                if USE_TMA_STORE:\n                    offset_m = tile_m_idx * BLOCK_SIZE_M  # .to(tl.int32)\n                    offset_n = tile_n_idx * BLOCK_SIZE_N  # .to(tl.int32)\n                    y_desc.store([m_start + offset_m, offset_n], y)\n                else:\n                    
tl.store(\n                        y_ptr + store_idx + offs_bn[None, :],\n                        y,\n                        mask = store_mask,\n                    )\n                tidx += NUM_SMS\n\n            processed_tiles += num_tiles_per_expert\n\n\n_autotuned_grouped_gemm_forward_kernel = triton.autotune(\n    configs = get_forward_configs(),\n    prune_configs_by = {\"early_config_prune\": prune_kernel_configs_fwd},\n    # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length\n    # The kernel handles variable token counts via m_sizes and tile-based processing\n    key = [\n        \"NUM_EXPERTS\",\n        \"N\",\n        \"K\",\n        \"PERMUTE_X\",\n        \"PERMUTE_Y\",\n        \"FUSE_MUL_POST\",\n    ],\n)(_grouped_gemm_forward_kernel)\n"
  },
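  {
    "path": "unsloth/kernels/moe/grouped_gemm/examples/forward_permute_x_reference_sketch.py",
    "content": "# NOTE: illustrative sketch added for documentation; not part of the original source tree.\n# A pure-PyTorch reference of what the forward kernel in kernels/forward.py computes for the\n# PERMUTE_X case (first grouped GEMM of an MoE MLP): gather token rows on load, run one GEMM per\n# expert, and store rows in contiguous expert-grouped order. Shapes and the int32 index convention\n# are assumptions taken from the interface docstring; this is a readability aid, not a benchmark.\n\nimport torch\n\n\ndef grouped_gemm_forward_reference(\n    X: torch.Tensor,  # [num_tokens, K], token order\n    W: torch.Tensor,  # [num_experts, N, K]\n    m_sizes: torch.Tensor,  # [num_experts] rows per expert\n    gather_indices: torch.Tensor,  # [num_tokens * topk] expert-grouped slot indices\n    topk: int,\n) -> torch.Tensor:\n    total_tokens = gather_indices.numel()\n    N = W.shape[1]\n    y = torch.zeros(total_tokens, N, dtype = X.dtype, device = X.device)\n    start = 0\n    for expert_idx, m_size in enumerate(m_sizes.tolist()):\n        if m_size == 0:\n            continue\n        rows = torch.arange(start, start + m_size, device = X.device)\n        # PERMUTE_X: slot index // topk recovers the original token id (load side),\n        # while the output rows stay in contiguous expert-grouped order (store side),\n        # mirroring load_idx / store_idx in the kernel.\n        token_ids = gather_indices[rows].long() // topk\n        y[rows] = X[token_ids] @ W[expert_idx].T\n        start += m_size\n    return y\n\n\nif __name__ == \"__main__\":\n    num_tokens, K, N, E, topk = 16, 32, 64, 4, 2\n    X = torch.randn(num_tokens, K)\n    W = torch.randn(E, N, K)\n    expert_ids = torch.randint(0, E, (num_tokens * topk,))\n    m_sizes = torch.bincount(expert_ids, minlength = E).to(torch.int32)\n    gather_indices = torch.argsort(expert_ids, stable = True).to(torch.int32)\n    y = grouped_gemm_forward_reference(X, W, m_sizes, gather_indices, topk)\n    print(y.shape)  # (num_tokens * topk, N)\n"
  },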
  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/tuning.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\n\"\"\"\nManual tuning utils\n\"\"\"\n\nfrom collections import OrderedDict\nfrom dataclasses import asdict, dataclass, fields\nfrom itertools import product\nfrom typing import Optional\n\nimport pandas as pd\nimport torch\nimport triton\nfrom triton.runtime.errors import OutOfResources\n\nfrom .autotuning import (\n    BOOLS,\n    DEFAULT_K_BLOCK_SIZES,\n    DEFAULT_M_BLOCK_SIZES,\n    DEFAULT_N_BLOCK_SIZES,\n    DEFAULT_NUM_STAGES,\n    DEFAULT_NUM_WARPS,\n)\n\n\n@dataclass\nclass DeviceProperties:\n    NUM_SM: int\n    NUM_REGS: int\n    SIZE_SMEM: int\n    WARP_SIZE: int\n\n\n_DEVICE_PROPERTIES: Optional[DeviceProperties] = None\n\n\ndef get_device_properties():\n    global _DEVICE_PROPERTIES\n    if _DEVICE_PROPERTIES is None:\n        properties = triton.runtime.driver.active.utils.get_device_properties(\n            torch.cuda.current_device()\n        )\n        NUM_SM = properties[\"multiprocessor_count\"]\n        NUM_REGS = properties[\"max_num_regs\"]\n        SIZE_SMEM = properties[\"max_shared_mem\"]\n        WARP_SIZE = properties[\"warpSize\"]\n        _DEVICE_PROPERTIES = DeviceProperties(NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE)\n    return _DEVICE_PROPERTIES\n\n\n@dataclass\nclass KernelConfig:\n    BLOCK_SIZE_M: int = 32\n    BLOCK_SIZE_N: int = 32\n    BLOCK_SIZE_K: int = 32\n    num_warps: int = 4\n    num_stages: int = 2\n    flatten: bool = True\n    permute_x: bool = False\n    permute_y: bool = False\n    fuse_mul_post: bool = False\n    use_tma_store: bool = False\n\n    def to_string(self, include_tuning_params: bool = False, include_tma: bool = False):\n        s = []\n        if self.permute_x:\n            s.append(\"permute_x\")\n        if self.permute_y:\n            s.append(\"permute_y\")\n        if include_tuning_params:\n            s.append(\n                f\"BLOCK_SIZE_M={self.BLOCK_SIZE_M},BLOCK_SIZE_N={self.BLOCK_SIZE_N},BLOCK_SIZE_K={self.BLOCK_SIZE_K},num_warps={self.num_warps},num_stages={self.num_stages},flatten={self.flatten}\"\n            )\n        if include_tma:\n            for f in fields(self):\n                if f.name.startswith(\"use_tma_\"):\n                    if getattr(self, f.name):\n                        s.append(f.name)\n        return \",\".join(s)\n\n\n@dataclass\nclass KernelConfigForward(KernelConfig):\n    use_tma_load_w: bool = False\n    use_tma_load_x: bool = False\n\n\n@dataclass\nclass KernelConfigBackward_dW(KernelConfig):\n    use_tma_load_dy: bool = False\n    use_tma_load_x: bool = False\n\n\n@dataclass\nclass KernelConfigBackward_dX(KernelConfig):\n    use_tma_load_dy: bool = False\n    use_tma_load_w: bool = False\n\n\n@dataclass\nclass KernelResult:\n    torch_time: float\n    triton_time: float\n    speedup: float\n    kernel_config: KernelConfig\n\n    def to_dict(self):\n        return OrderedDict(\n            **asdict(self.kernel_config),\n            torch_time = self.torch_time,\n            triton_time = self.triton_time,\n            speedup = self.speedup,\n        )\n\n    @staticmethod\n    def to_dataframe(\n        results: list[\"KernelResult\"], sort_by: str = \"speedup\", ascending: bool = False\n    ):\n        df = pd.DataFrame([result.to_dict() for result in results])\n        df = df.sort_values(by = sort_by, ascending = ascending)\n        return df\n\n    @staticmethod\n    def to_csv(\n        results: list[\"KernelResult\"],\n       
 sort_by: str = \"speedup\",\n        ascending: bool = False,\n        filename: str = \"results.csv\",\n    ):\n        df = KernelResult.to_dataframe(results, sort_by, ascending)\n        df.to_csv(filename, index = False)\n\n    @staticmethod\n    def print_table(\n        results: list[\"KernelResult\"],\n        sort_by: str = \"speedup\",\n        ascending: bool = False,\n        num_results: int = 10,\n    ):\n        df = KernelResult.to_dataframe(results, sort_by, ascending)\n        print(df.head(num_results).to_string(index = False))\n\n\ndef get_kernel_configs(\n    BLOCK_M = DEFAULT_M_BLOCK_SIZES,\n    BLOCK_N = DEFAULT_N_BLOCK_SIZES,\n    BLOCK_K = DEFAULT_K_BLOCK_SIZES,\n    num_warps = DEFAULT_NUM_WARPS,\n    num_stages = DEFAULT_NUM_STAGES,\n    use_tma_loads = BOOLS,\n    fuse_permute = BOOLS,\n):\n    kernel_configs_fwd = []\n    kernel_configs_backward_dW = []\n    kernel_configs_backward_dX = []\n    for block_m, block_n, block_k, w, s, use_tma_load, permute in product(\n        BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, use_tma_loads, fuse_permute\n    ):\n        kernel_configs_fwd.append(\n            KernelConfigForward(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = w,\n                num_stages = s,\n                use_tma_load_x = use_tma_load,\n                use_tma_load_w = use_tma_load,\n                use_tma_store = False,\n                permute_x = permute,\n                permute_y = permute,\n            )\n        )\n        kernel_configs_backward_dW.append(\n            KernelConfigBackward_dW(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = w,\n                num_stages = s,\n                use_tma_load_dy = use_tma_load,\n                use_tma_load_x = use_tma_load,\n                use_tma_store = False,\n                permute_x = permute,\n                permute_y = permute,\n            )\n        )\n        kernel_configs_backward_dX.append(\n            KernelConfigBackward_dX(\n                BLOCK_SIZE_M = block_m,\n                BLOCK_SIZE_N = block_n,\n                BLOCK_SIZE_K = block_k,\n                num_warps = w,\n                num_stages = s,\n                use_tma_load_dy = use_tma_load,\n                use_tma_load_w = use_tma_load,\n                use_tma_store = False,\n                permute_x = permute,\n                permute_y = permute,\n            )\n        )\n\n    kernel_configs_fwd = prune_kernel_configs_fwd(kernel_configs_fwd)\n    kernel_configs_backward_dW = prune_kernel_configs_backward_dW(\n        kernel_configs_backward_dW\n    )\n    kernel_configs_backward_dX = prune_kernel_configs_backward_dX(\n        kernel_configs_backward_dX\n    )\n    return kernel_configs_fwd, kernel_configs_backward_dW, kernel_configs_backward_dX\n\n\ndef prune_kernel_configs_fwd(configs: list[KernelConfigForward]):\n    pruned_configs = []\n    for config in configs:\n        if config.use_tma_load_x and config.permute_x:\n            continue\n        if config.permute_x and config.permute_y:\n            continue\n        if config.use_tma_store and config.permute_y:\n            continue\n        pruned_configs.append(config)\n    return pruned_configs\n\n\ndef prune_kernel_configs_backward_dX(configs: list[KernelConfigBackward_dX]):\n    pruned_configs = []\n    for config in configs:\n 
       if config.use_tma_load_dy and config.permute_y:\n            continue\n        if config.permute_x and config.permute_y:\n            continue\n        if config.use_tma_store and config.permute_x:\n            continue\n        pruned_configs.append(config)\n    return pruned_configs\n\n\ndef prune_kernel_configs_backward_dW(configs: list[KernelConfigBackward_dW]):\n    pruned_configs = []\n    for config in configs:\n        if config.use_tma_load_dy and config.permute_y:\n            continue\n        if config.use_tma_load_x and config.permute_x:\n            continue\n        if config.permute_x and config.permute_y:\n            continue\n        pruned_configs.append(config)\n    return pruned_configs\n\n\nclass TritonTuningContext:\n    def __init__(self, kernel_config: KernelConfig):\n        self.kernel_config = kernel_config\n        self.success = True\n\n    def __enter__(self):\n        # Setup code can be added here if needed\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if exc_type is OutOfResources:\n            name = exc_value.name\n            required = exc_value.required\n            limit = exc_value.limit\n            print(\n                f\"Kernel config {self.kernel_config} failed: {name}, required: {required}, limit: {limit}\"\n            )\n            self.success = False\n        elif exc_type is not None:\n            print(\n                f\"Error running Triton grouped GEMM for kernel config: {self.kernel_config}: {exc_value}\"\n            )\n            self.success = False\n        # Return False to propagate exceptions, True to suppress them\n        return True\n"
  },
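  {
    "path": "unsloth/kernels/moe/grouped_gemm/kernels/tuning_example.py",
    "content": "\"\"\"\nIllustrative sketch only (file name and placement are examples, not imported by the library):\nshows how the config enumeration and pruning helpers in `grouped_gemm.kernels.tuning` fit together.\n`get_kernel_configs` builds the cartesian product of block sizes / warps / stages / TMA and permute\nflags, and the `prune_kernel_configs_*` helpers drop combinations the kernels do not run\n(e.g. a TMA load of X combined with a fused input permute).\nAssumes `grouped_gemm` is on the import path, as in the tests.\n\"\"\"\n\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigForward,\n    get_kernel_configs,\n    prune_kernel_configs_fwd,\n)\n\nif __name__ == \"__main__\":\n    # Full sweep with the library defaults; the three returned lists are already pruned.\n    fwd, bwd_dW, bwd_dX = get_kernel_configs()\n    print(f\"forward: {len(fwd)}, backward dW: {len(bwd_dW)}, backward dX: {len(bwd_dX)}\")\n\n    # A config that fuses the input permute cannot also TMA-load X, so pruning removes it.\n    configs = [\n        KernelConfigForward(\n            BLOCK_SIZE_M = 64,\n            BLOCK_SIZE_N = 64,\n            BLOCK_SIZE_K = 64,\n            num_warps = 4,\n            num_stages = 2,\n            use_tma_load_x = True,\n            use_tma_load_w = False,\n            use_tma_store = False,\n            permute_x = True,\n            permute_y = False,\n        )\n    ]\n    assert len(prune_kernel_configs_fwd(configs)) == 0\n"
  },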
  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/__init__.py",
    "content": ""
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nfrom dataclasses import dataclass\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers.models.llama4 import Llama4TextConfig\nfrom transformers.models.llama4.modeling_llama4 import Llama4TextMoe\n\nfrom ...interface import grouped_gemm\nfrom ...kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom ..moe_ops import (\n    get_routing_indices,\n    permute,\n    torch_grouped_gemm,\n    unpermute,\n)\n\n\"\"\"\nReference implementation of Llama4 MoE block using triton grouped gemm.\n\n`Llama4GroupedGemmTextMoe` is the HF `Llama4TextMoe` block implemented with a torch-native grouped gemm.\n`Llama4TritonTextMoe` is the HF `Llama4TextMoe` implemented with triton grouped gemm.\n\"\"\"\n\n\n@dataclass\nclass Llama4MoeResult:\n    token_counts_by_expert: torch.Tensor\n    gather_indices: torch.Tensor\n    topk_weights: torch.Tensor\n    hidden_states_after_weight_merge: torch.Tensor\n    first_gemm: torch.Tensor\n    intermediate: torch.Tensor\n    second_gemm: torch.Tensor\n    hidden_states_unpermute: torch.Tensor\n    shared_expert_out: torch.Tensor\n    final_out: torch.Tensor\n    router_logits: torch.Tensor = None\n\n\nclass Llama4GroupedGemmTextMoe(Llama4TextMoe):\n    EXPERT_WEIGHT_NAMES = [\"experts.gate_up_proj\", \"experts.down_proj\"]\n\n    def __init__(\n        self,\n        config: Llama4TextConfig,\n        overlap_router_shared = False,\n        verbose = False,\n        debug = False,\n    ):\n        super().__init__(config)\n        self.overlap_router_shared = overlap_router_shared\n        self.verbose = verbose\n        self.debug = debug\n\n        # Permute in-place expert weights\n        E, K, N = self.num_experts, self.hidden_dim, self.experts.expert_dim\n        assert self.experts.gate_up_proj.shape == torch.Size(\n            [E, K, 2 * N]\n        ), f\"{self.experts.gate_up_proj.shape} != {[E, K, 2 * N]}\"\n        permuted_shape = [E, 2 * N, K]\n        permuted_stride = [2 * N * K, K, 1]\n        if verbose:\n            print(\n                f\"Changing gate_up_proj from {self.experts.gate_up_proj.size()}:{self.experts.gate_up_proj.stride()} to {permuted_shape}:{permuted_stride}\"\n            )\n        with torch.no_grad():\n            self.experts.gate_up_proj.as_strided_(permuted_shape, permuted_stride)\n\n        if verbose:\n            print(\n                f\"{self.experts.gate_up_proj.shape}:{self.experts.gate_up_proj.stride()}\"\n            )\n\n        assert self.experts.down_proj.shape == torch.Size(\n            [E, N, K]\n        ), f\"{self.experts.down_proj.shape} != {[E, N, K]}\"\n        permuted_shape = [E, K, N]\n        permuted_stride = [K * N, N, 1]\n        if verbose:\n            print(\n                f\"Changing down_proj from {self.experts.down_proj.size()}:{self.experts.down_proj.stride()} to {permuted_shape}:{permuted_stride}\"\n            )\n\n        with torch.no_grad():\n            self.experts.down_proj.as_strided_(permuted_shape, permuted_stride)\n\n        if verbose:\n            print(f\"{self.experts.down_proj.shape}:{self.experts.down_proj.stride()}\")\n\n        if overlap_router_shared:\n            self.shared_expert_stream = torch.cuda.Stream()\n            self.default_event = torch.cuda.Event()\n            self.shared_expert_end_event = 
torch.cuda.Event()\n\n    @torch.no_grad\n    def copy_weights(self, other: Llama4TextMoe):\n        for name, param_to_copy in other.named_parameters():\n            if self.verbose:\n                print(f\"Copying {name} with shape {param_to_copy.shape}\")\n            param = self.get_parameter(name)\n\n            if any(n in name for n in self.EXPERT_WEIGHT_NAMES):\n                param_to_copy = param_to_copy.permute(0, 2, 1)\n\n            assert (\n                param.shape == param_to_copy.shape\n            ), f\"{param.shape} != {param_to_copy.shape}\"\n            param.copy_(param_to_copy)\n\n        return self\n\n    def check_weights(self, other: Llama4TextMoe):\n        for name, other_param in other.named_parameters():\n            if any(n in name for n in self.EXPERT_WEIGHT_NAMES):\n                other_param = other_param.permute(0, 2, 1)\n            param = self.get_parameter(name)\n            assert param.equal(other_param), f\"Param {name} not equal!\"\n            assert param.is_contiguous(), f\"{name} not contiguous!\"\n\n    def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:\n        assert x.shape[-1] == 2 * self.experts.expert_dim\n        gate_proj = x[..., : self.experts.expert_dim]\n        up_proj = x[..., self.experts.expert_dim :]\n        return self.experts.act_fn(gate_proj) * up_proj\n\n    def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        # router_logits: (batch * sequence_length, n_experts)\n        hidden_states = hidden_states.view(-1, self.hidden_dim)\n        router_logits = self.router(hidden_states)\n        routing_weights, selected_experts = torch.topk(\n            router_logits, self.top_k, dim = -1\n        )\n\n        routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)\n\n        return router_logits, routing_weights, selected_experts\n\n    def get_token_counts_and_gather_indices(\n        self, selected_experts: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        token_counts_by_expert, gather_indices = get_routing_indices(\n            selected_experts, self.num_experts\n        )\n        assert not token_counts_by_expert.requires_grad\n        assert not gather_indices.requires_grad\n        return token_counts_by_expert, gather_indices\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        if self.overlap_router_shared:\n            # Marker for all prior ops on default stream\n            self.default_event.record()\n\n        router_logits, routing_weights, selected_experts = self.run_router(\n            hidden_states\n        )\n        assert routing_weights.shape == (\n            num_tokens,\n            self.top_k,\n        ), f\"{routing_weights.shape} != {(num_tokens, self.top_k)}\"\n\n        if self.overlap_router_shared:\n            with torch.cuda.stream(self.shared_expert_stream):\n                # Ensure prior kernels on default stream complete\n                self.default_event.wait()\n\n                shared_expert_out = self.shared_expert(hidden_states)\n                # Ensure hidden states remains valid on this stream\n                hidden_states.record_stream(self.shared_expert_stream)\n\n                self.shared_expert_end_event.record()\n\n      
      # Ensure shared expert still valid on default stream\n            shared_expert_out.record_stream(torch.cuda.current_stream())\n            self.shared_expert_end_event.wait()\n        else:\n            shared_expert_out = self.shared_expert(hidden_states)\n\n        hidden_states = (\n            hidden_states.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n\n        if self.top_k > 1:\n            hidden_states = hidden_states.sum(dim = 1)\n        hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)\n\n        # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. Permute tokens from token order to expert order\n        hidden_states = permute(\n            hidden_states_after_weight_merge, gather_indices, self.top_k\n        )\n        assert hidden_states.shape == (total_tokens, hidden_dim)\n\n        # Start expert computation\n        first_gemm = torch_grouped_gemm(\n            X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert\n        )\n        assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)\n\n        intermediate = self.act_and_mul(first_gemm)\n        assert intermediate.shape == (total_tokens, self.experts.expert_dim)\n\n        # See comment above\n        second_gemm = torch_grouped_gemm(\n            X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert\n        )\n        assert second_gemm.shape == (total_tokens, hidden_dim)\n\n        # Post-processing\n        hidden_states_unpermute = unpermute(second_gemm, gather_indices)\n        assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)\n        # grouped_gemm_out = hidden_states.view(batch_size, sequence_length, hidden_dim)\n\n        final_out = hidden_states_unpermute + shared_expert_out\n\n        result = (\n            Llama4MoeResult(\n                token_counts_by_expert = token_counts_by_expert,\n                gather_indices = gather_indices,\n                topk_weights = routing_weights,\n                hidden_states_after_weight_merge = hidden_states_after_weight_merge,\n                first_gemm = first_gemm,\n                intermediate = intermediate,\n                second_gemm = second_gemm,\n                hidden_states_unpermute = hidden_states_unpermute,\n                shared_expert_out = shared_expert_out,\n                final_out = final_out,\n                router_logits = router_logits,\n            )\n            if self.debug\n            else (final_out, routing_weights)\n        )\n\n        return result\n\n\nclass Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):\n    def __init__(\n        self,\n        config: Llama4TextConfig,\n        overlap_router_shared = False,\n        permute_x: bool = False,\n        permute_y: bool = True,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n        dW_only: bool = False,\n        dX_only: bool = False,\n        verbose = False,\n    ):\n        super().__init__(config, overlap_router_shared = overlap_router_shared)\n        
assert not permute_x, \"Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights\"\n        self.permute_x = permute_x\n        self.permute_y = permute_y\n        self.autotune = autotune\n        if not autotune:\n            assert (\n                kernel_config_fwd is not None\n                and kernel_config_bwd_dW is not None\n                and kernel_config_bwd_dX is not None\n            ), \"Kernel configs must be provided if autotune is False\"\n        self.kernel_config_fwd = kernel_config_fwd\n        self.kernel_config_bwd_dW = kernel_config_bwd_dW\n        self.kernel_config_bwd_dX = kernel_config_bwd_dX\n        self.dW_only = dW_only\n        self.dX_only = dX_only\n\n    @torch.no_grad\n    def copy_weights(self, other: Llama4TextMoe):\n        for name, param_to_copy in other.named_parameters():\n            if self.verbose:\n                print(f\"Copying {name} with shape {param_to_copy.shape}\")\n            param = self.get_parameter(name)\n\n            if any(n in name for n in self.EXPERT_WEIGHT_NAMES):\n                param_to_copy = param_to_copy.permute(0, 2, 1)\n\n            assert (\n                param.shape == param_to_copy.shape\n            ), f\"{param.shape} != {param_to_copy.shape}\"\n            param.copy_(param_to_copy)\n\n        return self\n\n    def check_weights(self, other: Llama4TextMoe):\n        for name, other_param in other.named_parameters():\n            if any(n in name for n in self.EXPERT_WEIGHT_NAMES):\n                other_param = other_param.permute(0, 2, 1)\n            param = self.get_parameter(name)\n            assert param.equal(other_param), f\"Param {name} not equal!\"\n            assert param.is_contiguous(), f\"{name} not contiguous!\"\n\n    def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:\n        assert x.shape[-1] == 2 * self.experts.expert_dim\n        gate_proj = x[..., : self.experts.expert_dim]\n        up_proj = x[..., self.experts.expert_dim :]\n        return self.experts.act_fn(gate_proj) * up_proj\n\n    def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        # router_logits: (batch * sequence_length, n_experts)\n        hidden_states = hidden_states.view(-1, self.hidden_dim)\n        router_logits = self.router(hidden_states)\n        routing_weights, selected_experts = torch.topk(\n            router_logits, self.top_k, dim = -1\n        )\n\n        routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)\n\n        return router_logits, routing_weights, selected_experts\n\n    def get_token_counts_and_gather_indices(\n        self, selected_experts: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        token_counts_by_expert, gather_indices = get_routing_indices(\n            selected_experts, self.num_experts\n        )\n        assert not token_counts_by_expert.requires_grad\n        assert not gather_indices.requires_grad\n        return token_counts_by_expert, gather_indices\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        if self.overlap_router_shared:\n            # Marker for all prior ops on default stream\n            self.default_event.record()\n\n        router_logits, routing_weights, 
selected_experts = self.run_router(\n            hidden_states\n        )\n        assert routing_weights.shape == (\n            num_tokens,\n            self.top_k,\n        ), f\"{routing_weights.shape} != {(num_tokens, self.top_k)}\"\n\n        if self.overlap_router_shared:\n            with torch.cuda.stream(self.shared_expert_stream):\n                # Ensure prior kernels on default stream complete\n                self.default_event.wait()\n\n                shared_expert_out = self.shared_expert(hidden_states)\n                # Ensure hidden states remains valid on this stream\n                hidden_states.record_stream(self.shared_expert_stream)\n\n                self.shared_expert_end_event.record()\n\n            # Ensure shared expert still valid on default stream\n            shared_expert_out.record_stream(torch.cuda.current_stream())\n            self.shared_expert_end_event.wait()\n        else:\n            shared_expert_out = self.shared_expert(hidden_states)\n\n        hidden_states = (\n            hidden_states.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n\n        if self.top_k > 1:\n            hidden_states = hidden_states.sum(dim = 1)\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. Permute tokens from token order to expert order\n        hidden_states = permute(hidden_states, gather_indices, self.top_k)\n        assert hidden_states.shape == (total_tokens, hidden_dim)\n\n        # Start expert computation\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.experts.gate_up_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = self.permute_x,\n            permute_y = False,  # output of first grouped gemm should never be permuted\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = True,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n        hidden_states = self.act_and_mul(hidden_states)\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.experts.down_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = False,\n            permute_y = self.permute_y,\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = False,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n\n        # Post-processing\n        # 1. 
Unpermute from expert order to token order\n        if not self.permute_y:\n            hidden_states = unpermute(hidden_states, gather_indices)\n        hidden_states += shared_expert_out\n\n        return hidden_states, routing_weights\n"
  },
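  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe_example.py",
    "content": "\"\"\"\nIllustrative sketch only (file name, placement and the `build_triton_moe` helper are examples,\nnot part of the library): one way to swap an HF `Llama4TextMoe` block for the triton-backed\n`Llama4TritonTextMoe` reference layer. The config and block are assumed to come from an\nalready-loaded Llama 4 text model; assumes `grouped_gemm` is on the import path, as in the tests.\n\"\"\"\n\nfrom transformers.models.llama4 import Llama4TextConfig\nfrom transformers.models.llama4.modeling_llama4 import Llama4TextMoe\n\nfrom grouped_gemm.reference.layers.llama4_moe import Llama4TritonTextMoe\n\n\ndef build_triton_moe(hf_moe: Llama4TextMoe, config: Llama4TextConfig) -> Llama4TritonTextMoe:\n    # permute_x must stay False for Llama4: the router weights are multiplied into the hidden\n    # states before the first grouped gemm, so the fused input permute is not supported.\n    triton_moe = Llama4TritonTextMoe(\n        config,\n        permute_x = False,\n        permute_y = True,\n        autotune = True,\n    ).to(device = hf_moe.experts.gate_up_proj.device, dtype = hf_moe.experts.gate_up_proj.dtype)\n    # copy_weights transposes the expert weights into the layout expected by the grouped gemm;\n    # check_weights verifies the copy and that the buffers stay contiguous.\n    triton_moe.copy_weights(hf_moe)\n    triton_moe.check_weights(hf_moe)\n    return triton_moe\n"
  },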
  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nfrom dataclasses import dataclass\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import (\n    ACT2FN,\n    Qwen3MoeSparseMoeBlock,\n)\n\nfrom ...interface import grouped_gemm\nfrom ...kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom ..moe_ops import (\n    get_routing_indices,\n    permute,\n    torch_grouped_gemm,\n    unpermute,\n)\n\n\"\"\"\nReference implementation of HF Qwen3 MoE block using grouped gemm.\n\nThe Qwen3MoeGroupedGEMMBlock is a reference torch-native implementation.\nQwen3MoeFusedGroupedGEMMBlock is a version using the triton grouped gemm kernel.\n\nNOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.\n\"\"\"\n\n\n@dataclass\nclass GroupedGEMMResult:\n    token_counts_by_expert: torch.Tensor\n    gather_indices: torch.Tensor\n    topk_weights: torch.Tensor\n    first_gemm: torch.Tensor\n    intermediate: torch.Tensor\n    second_gemm: torch.Tensor\n    hidden_states_unpermute: torch.Tensor\n    hidden_states: torch.Tensor  # final output\n\n\nclass Qwen3MoeGroupedGEMMBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config,\n        gate: torch.Tensor,\n        gate_up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n    ):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.hidden_size = config.hidden_size\n        self.moe_intermediate_size = config.moe_intermediate_size\n\n        assert gate.shape == (config.num_experts, config.hidden_size)\n        assert gate_up_proj.shape == (\n            config.num_experts,\n            2 * config.moe_intermediate_size,\n            config.hidden_size,\n        )\n        assert down_proj.shape == (\n            config.num_experts,\n            config.hidden_size,\n            config.moe_intermediate_size,\n        )\n\n        # gating\n        self.gate = torch.nn.Parameter(gate)\n\n        # experts\n        self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad = True)\n        self.down_proj = torch.nn.Parameter(down_proj, requires_grad = True)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    @staticmethod\n    def extract_hf_weights(moe_block: Qwen3MoeSparseMoeBlock):\n        config: Qwen3MoeConfig = moe_block.experts[0].config\n        num_experts = config.num_experts\n\n        gate = moe_block.gate.weight.data\n        gate_proj = torch.stack(\n            [moe_block.experts[i].gate_proj.weight.data for i in range(num_experts)],\n            dim = 0,\n        )\n        up_proj = torch.stack(\n            [moe_block.experts[i].up_proj.weight.data for i in range(num_experts)],\n            dim = 0,\n        )\n        down_proj = torch.stack(\n            [moe_block.experts[i].down_proj.weight.data for i in range(num_experts)],\n            dim = 0,\n        )\n        gate_up_proj = torch.cat([gate_proj, up_proj], dim = 1)\n        return gate, gate_up_proj, down_proj\n\n    @classmethod\n    def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock):\n        config: Qwen3MoeConfig = 
moe_block.experts[0].config\n        gate, gate_up_proj, down_proj = cls.extract_hf_weights(moe_block)\n        return cls(config, gate, gate_up_proj, down_proj)\n\n    def check_weights(self, moe_block: Qwen3MoeSparseMoeBlock):\n        for i in range(self.num_experts):\n            assert self.gate_up_proj[i].equal(\n                torch.cat(\n                    [\n                        moe_block.experts[i].gate_proj.weight.data,\n                        moe_block.experts[i].up_proj.weight.data,\n                    ],\n                    dim = 0,\n                )\n            )\n            assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data)\n\n    def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:\n        assert x.shape[-1] == 2 * self.moe_intermediate_size\n        gate_proj = x[..., : self.moe_intermediate_size]\n        up_proj = x[..., self.moe_intermediate_size :]\n        return self.act_fn(gate_proj) * up_proj\n\n    def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = torch.nn.functional.linear(hidden_states, self.gate)\n\n        routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)\n        routing_weights, selected_experts = torch.topk(\n            routing_weights, self.top_k, dim = -1\n        )\n        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!\n            routing_weights /= routing_weights.sum(dim = -1, keepdim = True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        return router_logits, routing_weights, selected_experts\n\n    def get_token_counts_and_gather_indices(\n        self, selected_experts: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        token_counts_by_expert, gather_indices = get_routing_indices(\n            selected_experts, self.num_experts\n        )\n        assert not token_counts_by_expert.requires_grad\n        assert not gather_indices.requires_grad\n        return token_counts_by_expert, gather_indices\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        router_logits, routing_weights, selected_experts = self.run_router(\n            hidden_states\n        )\n\n        # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. 
Permute tokens from token order to expert order\n        hidden_states = permute(hidden_states, gather_indices, self.top_k)\n        assert hidden_states.shape == (total_tokens, hidden_dim)\n\n        # Start expert computation\n        first_gemm = torch_grouped_gemm(\n            X = hidden_states, W = self.gate_up_proj, m_sizes = token_counts_by_expert\n        )\n        assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)\n        intermediate = self.act_and_mul(first_gemm)\n        assert intermediate.shape == (total_tokens, self.moe_intermediate_size)\n        second_gemm = torch_grouped_gemm(\n            X = intermediate, W = self.down_proj, m_sizes = token_counts_by_expert\n        )\n        assert second_gemm.shape == (total_tokens, hidden_dim)\n\n        # Post-processing\n        # 1. Unpermute from expert order to token order\n        hidden_states_unpermute = unpermute(second_gemm, gather_indices)\n        assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)\n\n        # 2. Merge topk weights\n        hidden_states = (\n            hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n        hidden_states = hidden_states.sum(dim = 1)\n        assert hidden_states.shape == (num_tokens, hidden_dim)\n\n        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)\n        return GroupedGEMMResult(\n            token_counts_by_expert = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk_weights = routing_weights,\n            first_gemm = first_gemm,\n            intermediate = intermediate,\n            second_gemm = second_gemm,\n            hidden_states_unpermute = hidden_states_unpermute,\n            hidden_states = hidden_states,\n        ), router_logits\n\n\nclass Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):\n    def __init__(\n        self,\n        config: Qwen3MoeConfig,\n        gate: torch.Tensor,\n        gate_up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        permute_x: bool = True,\n        permute_y: bool = True,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n        dW_only: bool = False,\n        dX_only: bool = False,\n    ):\n        super().__init__(config, gate, gate_up_proj, down_proj)\n        self.permute_x = permute_x\n        self.permute_y = permute_y\n        self.autotune = autotune\n        if not autotune:\n            assert (\n                kernel_config_fwd is not None\n                and kernel_config_bwd_dW is not None\n                and kernel_config_bwd_dX is not None\n            ), \"Kernel configs must be provided if autotune is False\"\n        self.kernel_config_fwd = kernel_config_fwd\n        self.kernel_config_bwd_dW = kernel_config_bwd_dW\n        self.kernel_config_bwd_dX = kernel_config_bwd_dX\n        self.dW_only = dW_only\n        self.dX_only = dX_only\n\n    @classmethod\n    def from_hf(\n        cls,\n        moe_block: Qwen3MoeSparseMoeBlock,\n        permute_x: bool = True,\n        permute_y: bool = True,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n        dW_only: bool = False,\n        dX_only: 
bool = False,\n    ):\n        config: Qwen3MoeConfig = moe_block.experts[0].config\n        gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(\n            moe_block\n        )\n        return cls(\n            config,\n            gate,\n            gate_up_proj,\n            down_proj,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n            dW_only = dW_only,\n            dX_only = dX_only,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        router_logits, routing_weights, selected_experts = self.run_router(\n            hidden_states\n        )\n        # Pre-processing\n        # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. permute_x -> permutation will be fused in prologue of first grouped gemm\n        if not self.permute_x:\n            hidden_states = permute(hidden_states, gather_indices, self.top_k)\n        # Start expert computation\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.gate_up_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = self.permute_x,\n            permute_y = False,  # output of first grouped gemm should never be permuted\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = True,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n        hidden_states = self.act_and_mul(hidden_states)\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.down_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = False,\n            permute_y = self.permute_y,\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = False,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n\n        # Post-processing\n        # 1. Unpermute from expert order to token order\n        if not self.permute_y:\n            hidden_states = unpermute(hidden_states, gather_indices)\n\n        # 2. 
Merge topk weights\n        hidden_states = (\n            hidden_states.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n        hidden_states = hidden_states.sum(dim = 1)\n\n        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)\n        return hidden_states, router_logits\n"
  },
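  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe_example.py",
    "content": "\"\"\"\nIllustrative sketch only (file name, placement and the toy sizes are examples, not part of the\nlibrary): builds the torch-native `Qwen3MoeGroupedGEMMBlock` and the triton-backed\n`Qwen3MoeFusedGroupedGEMMBlock` from the same HF `Qwen3MoeSparseMoeBlock` and compares their\noutputs. Assumes a CUDA device and that `grouped_gemm` is on the import path, as in the tests.\n\"\"\"\n\nimport torch\nfrom transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\n\nfrom grouped_gemm.reference.layers.qwen3_moe import (\n    Qwen3MoeFusedGroupedGEMMBlock,\n    Qwen3MoeGroupedGEMMBlock,\n)\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    config = Qwen3MoeConfig(\n        hidden_size = 128,\n        moe_intermediate_size = 128,\n        num_experts = 8,\n        num_experts_per_tok = 2,\n        norm_topk_prob = True,\n    )\n    hf_block = Qwen3MoeSparseMoeBlock(config).to(\"cuda\", torch.bfloat16)\n\n    # Both reference blocks stack the per-expert weights into [E, 2 * intermediate, hidden]\n    # and [E, hidden, intermediate] buffers via extract_hf_weights.\n    ref_block = Qwen3MoeGroupedGEMMBlock.from_hf(hf_block)\n    fused_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(hf_block)\n\n    X = torch.randn(1, 64, config.hidden_size, device = \"cuda\", dtype = torch.bfloat16)\n    ref_result, ref_logits = ref_block(X)  # GroupedGEMMResult with all intermediates\n    fused_out, fused_logits = fused_block(X)\n    print((ref_result.hidden_states - fused_out).abs().max())\n"
  },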
  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/moe_block.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport torch\nfrom transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\n\nfrom ..interface import grouped_gemm\nfrom ..kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom .moe_ops import (\n    Qwen3MoeGroupedGEMMBlock,\n    permute,\n    unpermute,\n)\n\n\"\"\"\nReference implementation of MoE block using grouped gemm.\n\nThis is the same as the Qwen3MoeGroupedGEMMBlock but with triton grouped gemm in place of torch-native grouped gemm implementation.\n\nNOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.\n\"\"\"\n\n\nclass Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):\n    def __init__(\n        self,\n        config: Qwen3MoeConfig,\n        gate: torch.Tensor,\n        gate_up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        permute_x: bool = True,\n        permute_y: bool = True,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n        dW_only: bool = False,\n        dX_only: bool = False,\n    ):\n        super().__init__(config, gate, gate_up_proj, down_proj)\n        self.permute_x = permute_x\n        self.permute_y = permute_y\n        self.autotune = autotune\n        if not autotune:\n            assert (\n                kernel_config_fwd is not None\n                and kernel_config_bwd_dW is not None\n                and kernel_config_bwd_dX is not None\n            ), \"Kernel configs must be provided if autotune is False\"\n        self.kernel_config_fwd = kernel_config_fwd\n        self.kernel_config_bwd_dW = kernel_config_bwd_dW\n        self.kernel_config_bwd_dX = kernel_config_bwd_dX\n        self.dW_only = dW_only\n        self.dX_only = dX_only\n\n    @classmethod\n    def from_hf(\n        cls,\n        moe_block: Qwen3MoeSparseMoeBlock,\n        permute_x: bool = True,\n        permute_y: bool = True,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n        dW_only: bool = False,\n        dX_only: bool = False,\n    ):\n        config: Qwen3MoeConfig = moe_block.experts[0].config\n        gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(\n            moe_block\n        )\n        return cls(\n            config,\n            gate,\n            gate_up_proj,\n            down_proj,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n            dW_only = dW_only,\n            dX_only = dX_only,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n\n        hidden_states = 
hidden_states.view(-1, hidden_dim)\n\n        router_logits, routing_weights, selected_experts = self.run_router(\n            hidden_states\n        )\n        # Pre-processing\n        # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. permute_x -> permutation will be fused in prologue of first grouped gemm\n        if not self.permute_x:\n            hidden_states = permute(hidden_states, gather_indices, self.top_k)\n        # Start expert computation\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.gate_up_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = self.permute_x,\n            permute_y = False,  # output of first grouped gemm should never be permuted\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = True,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n        hidden_states = self.act_and_mul(hidden_states)\n        hidden_states = grouped_gemm(\n            X = hidden_states,\n            W = self.down_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = False,\n            permute_y = self.permute_y,\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = False,\n            dW_only = self.dW_only,\n            dX_only = self.dX_only,\n        )\n\n        # Post-processing\n        # 1. Unpermute from expert order to token order\n        if not self.permute_y:\n            hidden_states = unpermute(hidden_states, gather_indices)\n\n        # 2. Merge topk weights\n        hidden_states = (\n            hidden_states.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n        hidden_states = hidden_states.sum(dim = 1)\n\n        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)\n        return hidden_states, router_logits\n"
  },
  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef permute(X: torch.Tensor, gather_indices: torch.Tensor, topk: int):\n    \"\"\"\n    Scatters X to a new tensor with shape [total_tokens, hidden_dim] where total_tokens is num_tokens * topk,\n    permuting the tokens according to sorted_token_idx.\n\n    Helper for grouped gemm where hidden states need be ordered by expert.\n    X: [num_tokens, hidden_dim]\n    sorted_token_idx: [num_tokens * topk]\n    topk: int\n\n    Returns:\n        [total_tokens, hidden_dim]\n    \"\"\"\n    assert gather_indices.ndim == 1\n    X = X.view(-1, X.shape[-1])\n    # Shortcut for topk == 1\n    if topk == 1:\n        return X[gather_indices]\n\n    return X[gather_indices // topk]\n\n\ndef unpermute(X: torch.Tensor, gather_indices: torch.Tensor):\n    X = X.view(-1, X.shape[-1]) if X.ndim > 2 else X\n    unpermuted = torch.empty_like(X)\n    unpermuted.index_copy_(0, gather_indices, X)\n    return unpermuted.view_as(X)\n\n\ndef calculate_topk(\n    gating_output: torch.Tensor,\n    top_k: int,\n    use_sigmoid: bool,\n    renormalize: bool,\n    pre_act: bool = True,\n    post_act: bool = False,\n):\n    \"\"\"\n    If post_act is True, then activation function is run AFTER topk\n    If post_act is False, then activation function is run BEFORE topk\n\n    This is to align with triton_bench implementation (post_act) whereas most models use pre_act (e.g. llama4, deepseek)\n    \"\"\"\n    assert pre_act ^ post_act, \"only one of pre_act or post_act can be True\"\n\n    def _activation(gating_output: torch.Tensor):\n        if use_sigmoid:\n            scores = torch.sigmoid(gating_output.to(torch.float32)).to(\n                gating_output.dtype\n            )\n        else:\n            scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(\n                gating_output.dtype\n            )\n\n        return scores\n\n    if pre_act:\n        scores = _activation(gating_output)\n    else:\n        scores = gating_output\n\n    topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)\n\n    if post_act:\n        topk_weights = _activation(topk_weights)\n\n    if renormalize:\n        topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(\n            gating_output.dtype\n        )\n\n    return topk_weights, topk_ids\n\n\n@torch.no_grad()\ndef get_routing_indices(\n    selected_experts, num_experts, return_scatter_indices: bool = False\n):\n    \"\"\"\n    Returns:\n        token_counts_by_expert: [num_experts]\n        gather_indices: [num_tokens]\n        scatter_indices [Optional] (torch.Tensor):\n            Indices for unpermuting gathered inputs back to token order, shape ``(bs * seqlen * top_k,)``.\n    \"\"\"\n    # group tokens together by expert indices from 0 to num_experts and pass that to experts forward\n    token_counts_by_expert = torch.histc(\n        selected_experts.view(-1),\n        bins = num_experts,\n        min = 0,\n        max = num_experts,\n    )\n    # token_indices_experts_sorted shape (bs*slen*top_k,)\n    gather_indices = torch.argsort(selected_experts.view(-1), stable = True)\n    if return_scatter_indices:\n        scatter_indices = gather_indices.argsort()\n        return token_counts_by_expert, gather_indices, scatter_indices\n    else:\n        return token_counts_by_expert, gather_indices\n\n\ndef torch_grouped_gemm(X, W, m_sizes, 
transpose = True):\n    \"\"\"\n    X: [M, K] if forward, else [M, N]\n    W: [E, N, K]\n    m_sizes: [E]\n\n    Returns:\n        Y: [M, N] if forward, else [M, K]\n    \"\"\"\n    X = X.view(-1, X.shape[-1])\n    M, K = X.shape\n\n    assert m_sizes.ndim == 1\n    E = m_sizes.shape[0]\n\n    assert W.ndim == 3\n    assert W.shape[0] == E\n\n    N = W.shape[1]\n\n    result = torch.zeros((M, N), dtype = X.dtype, device = X.device)\n\n    m_start = 0\n    for g in range(E):\n        m_size = m_sizes[g]\n        if m_size > 0:\n            m_end = m_start + m_size\n\n            # Extract group input\n            # m_size x K\n            X_g = X[m_start:m_end]\n            # N x K\n            W_g = W[g]\n\n            # Y_g = X_g @ W_g.T -> [m_size, N]\n            W_g = W_g.T if transpose else W_g\n            Y_g = X_g @ W_g\n\n            result[m_start:m_end] = Y_g\n\n            m_start = m_end\n    return result\n"
  },
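  {
    "path": "unsloth/kernels/moe/grouped_gemm/reference/moe_ops_example.py",
    "content": "\"\"\"\nIllustrative sketch only (file name, placement and the toy sizes are examples, not part of the\nlibrary): shows what `permute` / `unpermute` and `torch_grouped_gemm` compute, using small CPU\ntensors and top_k = 1 so the permute round trip is exact. The routing quantities are computed\ninline here; in the layers they come from `get_routing_indices` (histogram + stable argsort).\n\"\"\"\n\nimport torch\n\nfrom grouped_gemm.reference.moe_ops import permute, torch_grouped_gemm, unpermute\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    num_tokens, hidden_dim, num_experts, topk = 8, 4, 3, 1\n\n    X = torch.randn(num_tokens, hidden_dim)\n    selected_experts = torch.randint(0, num_experts, (num_tokens, topk))\n\n    # Tokens per expert and the ordering that groups rows of the same expert together.\n    flat = selected_experts.view(-1)\n    token_counts_by_expert = torch.bincount(flat, minlength = num_experts)\n    gather_indices = torch.argsort(flat, stable = True)\n\n    # Reorder tokens into expert order, then undo it.\n    X_expert_order = permute(X, gather_indices, topk)\n    assert torch.equal(unpermute(X_expert_order, gather_indices), X)\n\n    # Grouped gemm: expert e multiplies its m_sizes[e] contiguous rows by W[e] (as X_g @ W[e].T).\n    N = 6\n    W = torch.randn(num_experts, N, hidden_dim)\n    Y = torch_grouped_gemm(X_expert_order, W, token_counts_by_expert)\n\n    # The first expert's slice matches a plain matmul.\n    m0 = int(token_counts_by_expert[0])\n    assert torch.allclose(Y[:m0], X_expert_order[:m0] @ W[0].T)\n    print(Y.shape)  # (num_tokens * topk, N)\n"
  },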
  {
    "path": "unsloth/kernels/moe/requirements.txt",
    "content": "torch\ngit+https://github.com/huggingface/transformers.git@main\npytest\npandas\nruff"
  },
  {
    "path": "unsloth/kernels/moe/tests/__init__.py",
    "content": ""
  },
  {
    "path": "unsloth/kernels/moe/tests/common.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport itertools\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfig,\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n    prune_kernel_configs_backward_dW,\n    prune_kernel_configs_backward_dX,\n    prune_kernel_configs_fwd,\n)\n\n\ndef print_delimiter(char = \"-\", length = 80):\n    print(char * length)\n\n\n@contextmanager\ndef delimiter_context():\n    print_delimiter()\n    yield\n    print_delimiter()\n\n\ndef make_inputs(M, N, K, E, topk, dtype, requires_grad = False):\n    X1 = (\n        torch.randn((M, K), device = \"cuda\", dtype = dtype, requires_grad = requires_grad)\n        / 10\n    )\n    X2 = (\n        torch.randn(\n            (M * topk, N), device = \"cuda\", dtype = dtype, requires_grad = requires_grad\n        )\n        / 10\n    )\n    W1 = (\n        torch.randn(\n            (E, 2 * N, K), device = \"cuda\", dtype = dtype, requires_grad = requires_grad\n        )\n        / 10\n    )\n    W2 = (\n        torch.randn((E, K, N), device = \"cuda\", dtype = dtype, requires_grad = requires_grad)\n        / 10\n    )\n    score = torch.randn((M, E), device = \"cuda\", dtype = dtype, requires_grad = requires_grad)\n    if requires_grad:\n        X1.retain_grad()\n        X2.retain_grad()\n        W1.retain_grad()\n        W2.retain_grad()\n        score.retain_grad()\n    return X1, X2, W1, W2, score\n\n\n@dataclass(kw_only = True)\nclass DataConfig:\n    seq_len: int\n    dtype: torch.dtype\n    device: str = \"cuda\"\n    bs: int = 1\n\n\n@dataclass(kw_only = True)\nclass ModelConfig:\n    hidden_size: int\n    intermediate_size: int\n    num_experts: int\n    topk: int\n    use_sigmoid: bool\n    renormalize: bool\n    pre_mul: bool = False\n    post_mul: bool = field(init = False)\n\n    def __post_init__(self):\n        self.post_mul = not self.pre_mul\n\n\n@dataclass(kw_only = True)\nclass GroupedGEMMTestConfig:\n    name: str = \"test\"\n    data_config: DataConfig\n    model_config: ModelConfig\n\n\nTOLERANCE = {\n    torch.bfloat16: (1e-3, 1e-3),\n    torch.float16: (1e-4, 1e-4),\n    torch.float32: (1e-5, 1e-5),\n}\n\n\n# from https://github.com/triton-lang/triton/blob/main/bench/triton_bench/testing.py\ndef assert_equal(ref, tri):\n    if isinstance(ref, torch.Tensor):\n        assert torch.all(ref == tri), f\"tensors not equal {ref} != {tri}\"\n    else:\n        assert ref == tri, f\"ref not equal to tri {ref} != {tri}\"\n\n\ndef assert_close(ref, tri, maxtol = None, rmstol = None, description = \"--\", verbose = True):\n    if tri.dtype.itemsize == 1:\n        ref_as_type = ref.to(tri.dtype)\n        if ref.dtype == tri.dtype:\n            assert torch.all(ref_as_type == tri)\n            return\n        ref = ref_as_type\n\n    if maxtol is None:\n        maxtol = 2e-2\n    if rmstol is None:\n        rmstol = 4e-3\n    \"\"\"\n    Compare reference values against obtained values.\n    \"\"\"\n\n    # cast to float32:\n    ref = ref.to(torch.float32).detach()\n    tri = tri.to(torch.float32).detach()\n    assert (\n        ref.shape == tri.shape\n    ), f\"Tensors must have same size {ref.shape = } {tri.shape = }\"\n\n    # deal with infinite elements:\n    inf_mask_ref = torch.isinf(ref)\n    inf_mask_tri = torch.isinf(tri)\n    assert torch.equal(\n        
inf_mask_ref, inf_mask_tri\n    ), \"Tensor must have same infinite elements\"\n    refn = torch.where(inf_mask_ref, 0, ref)\n    trin = torch.where(inf_mask_tri, 0, tri)\n\n    # normalise so that RMS calculation doesn't overflow:\n    eps = 1.0e-30\n    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)\n    refn *= multiplier\n    trin *= multiplier\n\n    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps\n\n    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))\n    max_err = torch.max(rel_err).item()\n    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()\n\n    if verbose:\n        print(\n            \"%s maximum relative error = %s (threshold = %s)\"\n            % (description, max_err, maxtol)\n        )\n        print(\n            \"%s RMS relative error = %s (threshold = %s)\"\n            % (description, rms_err, rmstol)\n        )\n\n    if max_err > maxtol:\n        bad_idxs = torch.nonzero(rel_err > maxtol)\n        num_nonzero = bad_idxs.size(0)\n        bad_idxs = bad_idxs[:1000]\n        print(\n            \"%d / %d mismatched elements (shape = %s) at coords %s\"\n            % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())\n        )\n\n        bad_idxs = bad_idxs.unbind(-1)\n        print(\"ref values: \", ref[*bad_idxs].cpu())\n        print(\"tri values: \", tri[*bad_idxs].cpu())\n\n    assert max_err <= maxtol\n    assert rms_err <= rmstol\n\n\ndef assert_indx_equal(ref, tri):\n    assert_equal(ref, tri[: len(ref)])\n    assert torch.all(tri[len(ref) :] == -1)\n\n\ndef get_kernel_test_configs(\n    BLOCK_SIZE_M = 32,\n    BLOCK_SIZE_N = 32,\n    BLOCK_SIZE_K = 32,\n    num_warps = 4,\n    num_stages = 2,\n) -> list[KernelConfig]:\n    configs_fwd = []\n    configs_bwd_dX = []\n    configs_bwd_dW = []\n\n    for permute_x in [False, True]:\n        for permute_y in [False, True]:\n            for use_tma_load_w in [True, False]:\n                for use_tma_load_x in [True, False]:\n                    for use_tma_store in [True, False]:\n                        configs_fwd.append(\n                            KernelConfigForward(\n                                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                                num_warps = num_warps,\n                                num_stages = num_stages,\n                                use_tma_load_w = use_tma_load_w,\n                                use_tma_load_x = use_tma_load_x,\n                                use_tma_store = use_tma_store,\n                                permute_x = permute_x,\n                                permute_y = permute_y,\n                            )\n                        )\n                        configs_bwd_dX.append(\n                            KernelConfigBackward_dX(\n                                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                                num_warps = num_warps,\n                                num_stages = num_stages,\n                                use_tma_load_dy = use_tma_load_x,\n                                use_tma_load_w = use_tma_load_w,\n                                permute_x = permute_x,\n                                permute_y = permute_y,\n                                use_tma_store = use_tma_store,\n                      
      )\n                        )\n                        configs_bwd_dW.append(\n                            KernelConfigBackward_dW(\n                                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                                num_warps = num_warps,\n                                num_stages = num_stages,\n                                use_tma_load_dy = use_tma_load_w,\n                                use_tma_load_x = use_tma_load_x,\n                                permute_x = permute_x,\n                                permute_y = permute_y,\n                                use_tma_store = use_tma_store,\n                            )\n                        )\n    configs_fwd = prune_kernel_configs_fwd(configs_fwd)\n    configs_bwd_dX = prune_kernel_configs_backward_dX(configs_bwd_dX)\n    configs_bwd_dW = prune_kernel_configs_backward_dW(configs_bwd_dW)\n    return configs_fwd, configs_bwd_dX, configs_bwd_dW\n\n\ndef remove_feature_flags(\n    kernel_configs: list[KernelConfig],\n    permute_x: bool = True,\n    permute_y: bool = True,\n    tma_loads: bool = True,\n    tma_store: bool = True,\n):\n    pruned_configs = []\n    for config in kernel_configs:\n        # Remove permute flags first:\n        if permute_x and config.permute_x:\n            continue\n        if permute_y and config.permute_y:\n            continue\n        if tma_loads:\n            if isinstance(config, KernelConfigForward):\n                if config.use_tma_load_w or config.use_tma_load_x:\n                    continue\n            if isinstance(config, KernelConfigBackward_dX):\n                if config.use_tma_load_dy or config.use_tma_load_w:\n                    continue\n            if isinstance(config, KernelConfigBackward_dW):\n                if config.use_tma_load_dy or config.use_tma_load_x:\n                    continue\n        if tma_store:\n            if config.use_tma_store:\n                continue\n        pruned_configs.append(config)\n    return pruned_configs\n\n\n# Test Configs\n\nTOPK = [1, 4]\nNUM_EXPERTS = [4, 16]\n\nTEST_MODEL_SIZES = [\n    (32, 32),  # Debug\n    (128, 128),  # Small\n    (512, 512),  # Medium\n]\n\nSMALL_MODEL_CONFIGS = [\n    ModelConfig(\n        topk = topk,\n        num_experts = num_experts,\n        hidden_size = model_size[0],\n        intermediate_size = model_size[1],\n        use_sigmoid = False,\n        renormalize = False,\n    )\n    for topk, num_experts, model_size in itertools.product(\n        TOPK, NUM_EXPERTS, TEST_MODEL_SIZES\n    )\n]\nLLAMA_MODEL_CONFIG = ModelConfig(\n    topk = 1,\n    num_experts = 16,\n    hidden_size = 5120,\n    intermediate_size = 8192,\n    use_sigmoid = True,\n    renormalize = False,\n)\nQWEN_MODEL_CONFIG = ModelConfig(\n    topk = 8,\n    num_experts = 128,\n    hidden_size = 2048,\n    intermediate_size = 768,\n    use_sigmoid = False,\n    renormalize = False,\n)\n\nSEQLENS = [128, 1024]\nDTYPE = [torch.bfloat16]\n\nDATA_CONFIGS = [\n    DataConfig(seq_len = seq_len, dtype = dtype)\n    for seq_len, dtype in itertools.product(SEQLENS, DTYPE)\n]\nKERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (\n    get_kernel_test_configs()\n)\n\nif __name__ == \"__main__\":\n    print(\n        KERNEL_CONFIGS_BWD_dX[0].to_string(\n            include_tuning_params = False, include_tma = False\n        )\n    )\n"
  },
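  {
    "path": "unsloth/kernels/moe/tests/common_example.py",
    "content": "\"\"\"\nIllustrative sketch only (file name and placement are examples, not used by the test suite):\nsmall demo of the helpers in tests/common.py. `assert_close` follows the triton-bench style\ncheck (maximum and RMS relative error against thresholds) rather than torch.allclose, and\n`remove_feature_flags` filters the pre-built kernel config lists down to a plain subset.\nAssumes `tests` and `grouped_gemm` are on the import path, as when running the test suite.\n\"\"\"\n\nimport torch\n\nfrom tests.common import KERNEL_CONFIGS_FWD, assert_close, remove_feature_flags\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n\n    ref = torch.randn(256, 256)\n    # A float16 round trip introduces a small relative error, well under the default thresholds.\n    tri = ref.to(torch.float16).to(torch.float32)\n    assert_close(ref, tri, description = \"fp16 round trip\")\n\n    plain_configs = remove_feature_flags(\n        KERNEL_CONFIGS_FWD,\n        permute_x = True,  # drop configs with a fused input permute\n        permute_y = True,  # drop configs with a fused output permute\n        tma_loads = True,  # drop configs using TMA loads\n        tma_store = True,  # drop configs using a TMA store\n    )\n    print(f\"{len(plain_configs)} / {len(KERNEL_CONFIGS_FWD)} forward configs remain\")\n"
  },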
  {
    "path": "unsloth/kernels/moe/tests/moe_utils.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nfrom dataclasses import dataclass, fields\n\nimport torch\nimport torch.nn as nn\nfrom huggingface_hub import HfApi\nfrom huggingface_hub.utils import _safetensors\nfrom transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\n\nfrom grouped_gemm.interface import grouped_gemm\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom grouped_gemm.reference.layers.qwen3_moe import (\n    GroupedGEMMResult,\n    Qwen3MoeGroupedGEMMBlock,\n)\nfrom grouped_gemm.reference.moe_ops import permute, unpermute\n\n\ndef rebind_experts_to_shared_buffer(\n    moe_block: Qwen3MoeSparseMoeBlock, config: Qwen3MoeConfig\n):\n    num_experts = config.num_experts\n    hidden_size = config.hidden_size\n    interm_size = config.moe_intermediate_size\n    device = moe_block.experts[0].down_proj.weight.device\n    dtype = moe_block.experts[0].down_proj.weight.dtype\n\n    buffer_up = torch.empty(\n        num_experts, interm_size, hidden_size, device = device, dtype = dtype\n    )\n    buffer_gate = torch.empty(\n        num_experts, interm_size, hidden_size, device = device, dtype = dtype\n    )\n    buffer_down = torch.empty(\n        num_experts, hidden_size, interm_size, device = device, dtype = dtype\n    )\n\n    # Step 2: Copy existing expert weights into buffers\n    for i, expert in enumerate(moe_block.experts):\n        buffer_up[i].copy_(expert.up_proj.weight.data)\n        buffer_gate[i].copy_(expert.gate_proj.weight.data)\n        buffer_down[i].copy_(expert.down_proj.weight.data)\n\n    # Step 3: Rebind expert weights to views in shared buffer\n    for i, expert in enumerate(moe_block.experts):\n        expert.up_proj.weight = torch.nn.Parameter(buffer_up[i])\n        expert.gate_proj.weight = torch.nn.Parameter(buffer_gate[i])\n        expert.down_proj.weight = torch.nn.Parameter(buffer_down[i])\n\n    return buffer_up, buffer_gate, buffer_down\n\n\ndef get_expert_metadata(model_id: str):\n    api = HfApi()\n    metadata: _safetensors.SafetensorsRepoMetadata = api.get_safetensors_metadata(\n        model_id\n    )\n    return metadata.files_metadata\n\n\ndef clone_experts(\n    moe_block: Qwen3MoeSparseMoeBlock, config: Qwen3MoeConfig, copy: bool = True\n):\n    down_projs = torch.empty(\n        config.num_experts, config.hidden_size, config.moe_intermediate_size\n    )\n    up_projs = torch.empty(\n        config.num_experts, config.moe_intermediate_size, config.hidden_size\n    )\n    gate_projs = torch.empty(\n        config.num_experts, config.moe_intermediate_size, config.hidden_size\n    )\n    for expert_idx, expert in enumerate(moe_block.experts):\n        down_projs[expert_idx].copy_(expert.down_proj.weight.data)\n        up_projs[expert_idx].copy_(expert.up_proj.weight.data)\n        gate_projs[expert_idx].copy_(expert.gate_proj.weight.data)\n    return gate_projs, up_projs, down_projs\n\n\n@dataclass\nclass ForwardResult:\n    output: torch.Tensor\n    router_logits: torch.Tensor\n    X: torch.Tensor\n    # When using grouped gemm MoE implementation to additional debugging / checking of intermediate results\n    grouped_gemm_result: GroupedGEMMResult = None\n\n\n@dataclass\nclass BackwardResult:\n    X_grad: torch.Tensor\n    gate_grad: torch.Tensor\n    
gate_proj_grad: torch.Tensor\n    up_proj_grad: torch.Tensor\n    down_proj_grad: torch.Tensor\n\n\ndef check_down_proj_grad(\n    moe_block: Qwen3MoeSparseMoeBlock,\n    grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,\n    atol: float,\n    rtol: float,\n):\n    for i, expert in enumerate(moe_block.experts):\n        ref_grad = expert.down_proj.weight.grad\n        assert ref_grad is not None\n        test_grad = grouped_gemm_block.down_proj.grad[i]\n        assert test_grad is not None\n        diff = (ref_grad - test_grad).abs().max()\n        if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):\n            print(f\"expert {i} down_proj_grad_diff: {diff.detach().cpu().item():.6f}\")\n\n\ndef check_gate_up_proj_grad(\n    moe_block: Qwen3MoeSparseMoeBlock,\n    grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,\n    atol: float,\n    rtol: float,\n):\n    moe_intermediate_size = grouped_gemm_block.moe_intermediate_size\n    for i, expert in enumerate(moe_block.experts):\n        ref_gate_proj_grad = expert.gate_proj.weight.grad\n        ref_up_proj_grad = expert.up_proj.weight.grad\n        assert ref_gate_proj_grad is not None\n        assert ref_up_proj_grad is not None\n\n        # Extract gradients\n        test_gate_proj_grad = grouped_gemm_block.gate_up_proj.grad[\n            i, :moe_intermediate_size\n        ]\n        test_up_proj_grad = grouped_gemm_block.gate_up_proj.grad[\n            i, moe_intermediate_size:\n        ]\n        assert test_gate_proj_grad is not None\n        assert test_up_proj_grad is not None\n\n        # Sanity check shapes\n        assert (\n            ref_gate_proj_grad.shape == test_gate_proj_grad.shape\n        ), f\"{ref_gate_proj_grad.shape} != {test_gate_proj_grad.shape}\"\n        assert (\n            ref_up_proj_grad.shape == test_up_proj_grad.shape\n        ), f\"{ref_up_proj_grad.shape} != {test_up_proj_grad.shape}\"\n\n        # Check gradients\n        diff = (ref_gate_proj_grad - test_gate_proj_grad).abs().max()\n        if not torch.allclose(\n            ref_gate_proj_grad, test_gate_proj_grad, atol = atol, rtol = rtol\n        ):\n            print(f\"expert {i} gate_proj_grad_diff: {diff.detach().cpu().item():.6f}\")\n        diff = (ref_up_proj_grad - test_up_proj_grad).abs().max()\n        if not torch.allclose(\n            ref_up_proj_grad, test_up_proj_grad, atol = atol, rtol = rtol\n        ):\n            print(f\"expert {i} up_proj_grad_diff: {diff.detach().cpu().item():.6f}\")\n\n\ndef check_gate_grad(\n    moe_block: Qwen3MoeSparseMoeBlock,\n    grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,\n    atol: float,\n    rtol: float,\n):\n    ref_grad = moe_block.gate.weight.grad\n    assert ref_grad is not None\n    test_grad = grouped_gemm_block.gate.grad\n    assert test_grad is not None\n    diff = (ref_grad - test_grad).abs().max()\n    if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):\n        print(f\"gate_grad_diff: {diff.detach().cpu().item():.6f}\")\n\n\ndef check_wgrad(\n    moe_block: Qwen3MoeSparseMoeBlock,\n    grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,\n    atol: float,\n    rtol: float,\n):\n    check_down_proj_grad(moe_block, grouped_gemm_block, atol, rtol)\n    check_gate_up_proj_grad(moe_block, grouped_gemm_block, atol, rtol)\n    check_gate_grad(moe_block, grouped_gemm_block, atol, rtol)\n\n\ndef check_tensor_allclose(\n    X_ref: torch.Tensor,\n    X_test: torch.Tensor,\n    atol: float,\n    rtol: float,\n    name: str,\n    verbose: bool = False,\n):\n    diff = (X_ref - 
X_test).abs().max()\n    if verbose:\n        print(f\"{name} diff: {diff.detach().cpu().item():.6f}\")\n    assert torch.allclose(\n        X_ref, X_test, atol = atol, rtol = rtol\n    ), f\"{name} diff: {diff.detach().cpu().item():.6f}\"\n\n\ndef check_expert_grads(\n    ref_result: BackwardResult,\n    test_result: BackwardResult,\n    atol: float,\n    rtol: float,\n    verbose: bool = False,\n):\n    fields_to_check = [f.name for f in fields(BackwardResult) if \"proj\" in f.name]\n    assert len(fields_to_check) == 3\n\n    for field in fields_to_check:\n        ref_grads = getattr(ref_result, field)\n        test_grads = getattr(test_result, field)\n        assert (\n            ref_grads.shape == test_grads.shape\n        ), f\"{field}: {ref_grads.shape} != {test_grads.shape}\"\n\n        # Test each expert\n        for i in range(ref_grads.shape[0]):\n            ref_grad = ref_grads[i]\n            test_grad = test_grads[i]\n            diff = (ref_grad - test_grad).abs().max()\n            assert torch.allclose(\n                ref_grad, test_grad, atol = atol, rtol = rtol\n            ), f\"{field}[{i}] diff: {diff.detach().cpu().item():.6f}\"\n\n        # Test all experts\n        diff = (ref_grads - test_grads).abs().max()\n        if verbose:\n            print(f\"{field} diff: {diff.detach().cpu().item():.6f}\")\n        assert torch.allclose(\n            ref_grads, test_grads, atol = atol, rtol = rtol\n        ), f\"{field} diff: {diff.detach().cpu().item():.6f}\"\n\n\ndef check_grads(\n    ref_result: BackwardResult,\n    test_result: BackwardResult,\n    atol: float,\n    rtol: float,\n    verbose: bool = False,\n):\n    check_tensor_allclose(\n        ref_result.X_grad, test_result.X_grad, atol, rtol, \"X.grad\", verbose\n    )\n    check_tensor_allclose(\n        ref_result.gate_grad, test_result.gate_grad, atol, rtol, \"gate.grad\", verbose\n    )\n    check_expert_grads(ref_result, test_result, atol, rtol, verbose)\n\n\ndef check_fwd(\n    ref_result: ForwardResult,\n    test_result: ForwardResult,\n    atol: float,\n    rtol: float,\n    verbose: bool = False,\n):\n    # First check hidden states (output)\n    ref_output = ref_result.output\n    test_output = test_result.output\n    diff = (ref_output - test_output).abs().max()\n    if verbose:\n        print(f\"output diff: {diff.detach().cpu().item():.6f}\")\n    assert torch.allclose(\n        ref_output, test_output, atol = atol, rtol = rtol\n    ), f\"output diff: {diff.detach().cpu().item():.6f}\"\n\n    # Check router logits\n    ref_router_logits = ref_result.router_logits\n    test_router_logits = test_result.router_logits\n    diff = (ref_router_logits - test_router_logits).abs().max()\n    if verbose:\n        print(f\"router_logits diff: {diff.detach().cpu().item():.6f}\")\n    assert torch.allclose(\n        ref_router_logits, test_router_logits, atol = atol, rtol = rtol\n    ), f\"router_logits diff: {diff.detach().cpu().item():.6f}\"\n\n\ndef check_grouped_gemm_results(\n    grouped_result: GroupedGEMMResult,\n    fused_result: GroupedGEMMResult,\n    permute_y: bool,\n    atol: float,\n    rtol: float,\n    verbose: bool = False,\n):\n    for field in fields(GroupedGEMMResult):\n        ref_value = getattr(grouped_result, field.name)\n        test_value = getattr(fused_result, field.name)\n        diff = (ref_value - test_value).abs().max()\n\n        # second_gemm in torch grouped gemm is not yet unpermuted so comparing the fused unpermuted second_gemm will result in error\n        # instead the 
hidden_states_unpermute should match since hidden_states_unpermute for the fused result is the same as second_gemm\n        if field.name == \"second_gemm\" and permute_y:\n            continue\n\n        if verbose:\n            print(f\"{field.name} diff: {diff.detach().cpu().item():.6f}\")\n\n        assert torch.allclose(\n            ref_value, test_value, atol = atol, rtol = rtol\n        ), f\"{field.name} diff: {diff.detach().cpu().item():.6f}\"\n\n\ndef run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False):\n    X = X.detach().clone().requires_grad_(True)\n    output, router_logits = model(X)\n    if is_grouped_gemm:\n        result = ForwardResult(\n            output = output.hidden_states,\n            router_logits = router_logits,\n            X = X,\n            grouped_gemm_result = output,\n        )\n    else:\n        result = ForwardResult(output = output, router_logits = router_logits, X = X)\n    return result\n\n\ndef run_backward(\n    model: nn.Module, grad_output: torch.Tensor, output: torch.Tensor, X: torch.Tensor\n):\n    output.backward(grad_output)\n    assert X.grad is not None\n    for name, param in model.named_parameters():\n        assert param.grad is not None, f\"{name} grad is None\"\n    if isinstance(model, Qwen3MoeSparseMoeBlock):\n        gate_grad = model.gate.weight.grad\n        gate_proj_grad = torch.stack(\n            [expert.gate_proj.weight.grad for expert in model.experts]\n        )\n        up_proj_grad = torch.stack(\n            [expert.up_proj.weight.grad for expert in model.experts]\n        )\n        down_proj_grad = torch.stack(\n            [expert.down_proj.weight.grad for expert in model.experts]\n        )\n    elif isinstance(model, Qwen3MoeGroupedGEMMBlock):\n        gate_grad = model.gate.grad\n        gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim = 1)\n        down_proj_grad = model.down_proj.grad\n    else:\n        raise ValueError(f\"Unsupported model type: {type(model)}\")\n    return BackwardResult(\n        X_grad = X.grad,\n        gate_grad = gate_grad,\n        gate_proj_grad = gate_proj_grad,\n        up_proj_grad = up_proj_grad,\n        down_proj_grad = down_proj_grad,\n    )\n\n\nclass Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):\n    \"\"\"\n    Reference implementation of MoE block using grouped gemm.\n\n    This is the same as the Qwen3MoeGroupedGEMMBlock but with triton grouped gemm in place of torch-native grouped gemm implementation.\n\n    NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.\n    See grouped_gemm/reference/moe_block.py for a cleaner implementation.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: Qwen3MoeConfig,\n        gate: torch.Tensor,\n        gate_up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        permute_x: bool = False,\n        permute_y: bool = False,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n    ):\n        super().__init__(config, gate, gate_up_proj, down_proj)\n        self.permute_x = permute_x\n        self.permute_y = permute_y\n        self.autotune = autotune\n        if not autotune:\n            assert (\n                kernel_config_fwd is not None\n                and kernel_config_bwd_dW is not None\n                and 
kernel_config_bwd_dX is not None\n            ), \"Kernel configs must be provided if autotune is False\"\n        self.kernel_config_fwd = kernel_config_fwd\n        self.kernel_config_bwd_dW = kernel_config_bwd_dW\n        self.kernel_config_bwd_dX = kernel_config_bwd_dX\n\n    @classmethod\n    def from_hf(\n        cls,\n        moe_block: Qwen3MoeSparseMoeBlock,\n        permute_x: bool = False,\n        permute_y: bool = False,\n        autotune: bool = True,\n        kernel_config_fwd: KernelConfigForward = None,\n        kernel_config_bwd_dW: KernelConfigBackward_dW = None,\n        kernel_config_bwd_dX: KernelConfigBackward_dX = None,\n    ):\n        config: Qwen3MoeConfig = moe_block.experts[0].config\n        gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(\n            moe_block\n        )\n        return cls(\n            config,\n            gate,\n            gate_up_proj,\n            down_proj,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n        )\n\n    def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Tensor:\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * sequence_length\n        total_tokens = num_tokens * self.top_k\n\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        router_logits, routing_weights, selected_experts = self.run_router(\n            hidden_states\n        )\n        # Pre-processing\n        # 1. Compute tokens per expert and indices for gathering tokens from token order to expert order\n        # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph\n        token_counts_by_expert, gather_indices = (\n            self.get_token_counts_and_gather_indices(selected_experts)\n        )\n\n        # 2. 
permute_x -> permutation will be fused in prologue of first grouped gemm\n        if not self.permute_x:\n            hidden_states = permute(hidden_states, gather_indices, self.top_k)\n            assert hidden_states.shape == (total_tokens, hidden_dim)\n\n        # Start expert computation\n        first_gemm = grouped_gemm(\n            X = hidden_states,\n            W = self.gate_up_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = self.permute_x,\n            permute_y = False,  # output of first grouped gemm should never be permuted\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = True,\n        )\n        assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)\n        intermediate = self.act_and_mul(first_gemm)\n        assert intermediate.shape == (total_tokens, self.moe_intermediate_size)\n        second_gemm = grouped_gemm(\n            X = intermediate,\n            W = self.down_proj,\n            m_sizes = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk = self.top_k,\n            permute_x = False,\n            permute_y = self.permute_y,\n            autotune = self.autotune,\n            kernel_config_fwd = self.kernel_config_fwd,\n            kernel_config_bwd_dW = self.kernel_config_bwd_dW,\n            kernel_config_bwd_dX = self.kernel_config_bwd_dX,\n            is_first_gemm = False,\n        )\n        assert second_gemm.shape == (total_tokens, hidden_dim)\n\n        # Post-processing\n        # 1. Unpermute from expert order to token order\n        if not self.permute_y:\n            hidden_states_unpermute = unpermute(second_gemm, gather_indices)\n            assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)\n        else:\n            hidden_states_unpermute = second_gemm\n\n        # 2. Merge topk weights\n        hidden_states = (\n            hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)\n            * routing_weights[..., None]\n        )\n        hidden_states = hidden_states.sum(dim = 1)\n        assert hidden_states.shape == (num_tokens, hidden_dim)\n\n        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)\n        return GroupedGEMMResult(\n            token_counts_by_expert = token_counts_by_expert,\n            gather_indices = gather_indices,\n            topk_weights = routing_weights,\n            first_gemm = first_gemm,\n            intermediate = intermediate,\n            second_gemm = second_gemm,\n            hidden_states_unpermute = hidden_states_unpermute,\n            hidden_states = hidden_states,\n        ), router_logits\n"
  },
  {
    "path": "unsloth/kernels/moe/tests/run_qwen3_moe_tests.sh",
    "content": "#!/bin/bash\n\nset -euo pipefail\n\nSEQLENS=(1024)  \nDTYPES=(bfloat16)\nPERMUTE_X=(false true)\nPERMUTE_Y=(false true)\nAUTOTUNE=(false true)\n\nfor SEQLEN in \"${SEQLENS[@]}\"; do\n    for DTYPE in \"${DTYPES[@]}\"; do\n        for PX in \"${PERMUTE_X[@]}\"; do\n            for PY in \"${PERMUTE_Y[@]}\"; do\n                for AT in \"${AUTOTUNE[@]}\"; do\n\n                    ARGS=()\n                    [[ \"$PX\" == \"true\" ]] && ARGS+=(\"--permute_x\")\n                    [[ \"$PY\" == \"true\" ]] && ARGS+=(\"--permute_y\")\n                    [[ \"$AT\" == \"true\" ]] && ARGS+=(\"--autotune\")\n\n                    ARGS+=(--seqlen \"$SEQLEN\" --dtype \"$DTYPE\")\n\n                    echo \"Running with args: ${ARGS[*]}\"\n                    if ! python -m tests.test_qwen3_moe \"${ARGS[@]}\"; then\n                        echo \"❌ Test failed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE\" >&2\n                    else\n                        echo \"✅ Test passed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE\"\n                    fi\n\n                done\n            done\n        done\n    done\ndone\n"
  },
  {
    "path": "unsloth/kernels/moe/tests/test_grouped_gemm.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nfrom dataclasses import asdict\n\nimport pytest\nimport torch\n\nfrom grouped_gemm.interface import (\n    grouped_gemm,\n    grouped_gemm_dW,\n    grouped_gemm_dX,\n    grouped_gemm_forward,\n)\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfig,\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom grouped_gemm.reference.moe_ops import (\n    calculate_topk,\n    get_routing_indices,\n    permute,\n    torch_grouped_gemm,\n    unpermute,\n)\n\nfrom .common import (\n    DATA_CONFIGS,\n    KERNEL_CONFIGS_FWD,\n    LLAMA_MODEL_CONFIG,\n    QWEN_MODEL_CONFIG,\n    SMALL_MODEL_CONFIGS,\n    TOLERANCE,\n    DataConfig,\n    KERNEL_CONFIGS_BWD_dW,\n    KERNEL_CONFIGS_BWD_dX,\n    ModelConfig,\n    make_inputs,\n)\n\nSEED = 0\n\n\n# Only certain combinations of permute_x, permute_y, use_W1 are valid.\n# use_W1 => first grouped GEMM in a fused MoE MLP\n# use_W2 => second grouped GEMM in a fused MoE MLP\n# permute_x => permute the input to the grouped GEMM, only done for the first grouped GEMM\n# permute_y => permute the output of the grouped GEMM, only done for the second grouped GEMM\n# fuse_mul_post => fuse the multiplication of topk weights in the epilogue of the second grouped GEMM; only used for inference, not currently tested\ndef check_valid_config(\n    permute_x, permute_y, use_W1, fuse_mul_post = False, is_backward = False, verbose = False\n):\n    use_W2 = not use_W1\n\n    if permute_x and permute_y:\n        if verbose:\n            print(f\"Skipping test: {permute_x = } {permute_y = }\")\n        return False\n    if use_W2 and permute_x:\n        if verbose:\n            print(f\"Skipping test: {permute_x = } {use_W2 = }\")\n        return False\n    if use_W1 and permute_y:\n        if verbose:\n            print(f\"Skipping test: {permute_y = } {use_W1 = }\")\n        return False\n    if fuse_mul_post and use_W1:\n        if verbose:\n            print(f\"Skipping test: {fuse_mul_post = } {use_W1 = }\")\n        return False\n    if is_backward and fuse_mul_post:\n        if verbose:\n            print(f\"Skipping test: {fuse_mul_post = } {is_backward = }\")\n        return False\n\n    return True\n\n\n\"\"\"\ngrouped_gemm_forward\n\npermute_x: typically in a fused MoE MLP, we can fuse the permutation of hidden states (X) from token order to expert grouped order needed for grouped GEMM by directly loading X in permuted order rather than launching a separate permutation kernel.\npermute_y: We can also fuse the unpermutation of tokens after the second grouped GEMM to restore to original token order.  This is fused into the second grouped GEMM by directly storing the output in unpermuted order.\nfuse_mul: We can also fuse the multiplication of topk weights in the epilogue of the second grouped GEMM.  
Note that this is only supported for inference and not training, although this may change in the future.\nuse_W1 tests the shapes for the first grouped GEMM in a fused MoE MLP\nuse_W2 = `not use_W1` tests the shapes for the second grouped GEMM in a fused MoE MLP\n\nGiven the above, only certain combinations are valid:\n- use_W1 is always False when permute_y is True since we only permute the second grouped GEMM\n- use_W2 is always False when permute_x is True since we only permute the first grouped GEMM\n- only one of permute_x and permute_y can be True\n- fuse_mul is only True if permute_y is also True\n\nSee `check_valid_config` for more details.\n\"\"\"\n\n\ndef _test_grouped_gemm_forward(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,  # W1 -> first grouped GEMM in a fused MoE MLP, not use_W1 -> second grouped GEMM in a fused MoE MLP\n    fuse_mul_post: bool = False,\n    flatten: bool = True,\n    # Manually tuned parameters\n    use_tma_load_w: bool = False,\n    use_tma_load_x: bool = False,\n    use_tma_store: bool = False,\n    BLOCK_SIZE_M: int = None,\n    BLOCK_SIZE_N: int = None,\n    BLOCK_SIZE_K: int = None,\n    num_warps: int = None,\n    num_stages: int = None,\n    # Autotuning parameters\n    autotune: bool = False,\n    num_autotune_configs: int = None,\n    # Flag to manually enable TMA store\n    allow_tma_store: bool = False,\n    use_autograd: bool = False,\n):\n    if not check_valid_config(\n        permute_x, permute_y, use_W1 = use_W1, fuse_mul_post = fuse_mul_post\n    ):\n        pytest.skip(\n            f\"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = } {fuse_mul_post = }\"\n        )\n\n    if use_tma_store and not allow_tma_store:\n        pytest.skip(\"TMA store needs to be debugged due to non-deterministic behavior\")\n\n    X1, X2, W1, W2, gating_output = make_inputs(\n        M = data_config.bs * data_config.seq_len,\n        N = model_config.intermediate_size,\n        K = model_config.hidden_size,\n        E = model_config.num_experts,\n        topk = model_config.topk,\n        dtype = data_config.dtype,\n    )\n    topk = model_config.topk\n    use_sigmoid = model_config.use_sigmoid\n    renormalize = model_config.renormalize\n\n    X = X1 if use_W1 else X2\n    num_tokens = data_config.bs * data_config.seq_len\n    E, K, N = W2.shape  # E = num_experts, K = hidden_size, N = intermediate_size\n    assert W1.shape == (E, 2 * N, K)\n    W = W1 if use_W1 else W2\n\n    if use_W1:\n        assert X.shape == (\n            num_tokens,\n            K,\n        ), f\"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}\"\n    else:\n        assert X.shape == (\n            num_tokens * topk,\n            N,\n        ), f\"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}\"\n\n    total_tokens = num_tokens * topk\n    output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)\n\n    topk_weights, topk_ids = calculate_topk(\n        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize\n    )\n    topk_weights = topk_weights.view(-1)  # num_tokens * topk\n    topk_ids = topk_ids.view(-1)  # num_tokens * topk\n\n    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)\n    assert len(gather_indices) == total_tokens\n    assert len(expert_token_counts) == E\n\n    atol, rtol = TOLERANCE[X.dtype]\n\n    Xperm = permute(X, gather_indices, topk)\n\n    Xref = Xperm\n\n    assert 
(\n        Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)\n    ), f\"Xperm.shape: {Xperm.shape}, total_tokens: {total_tokens}, K: {K}\"\n\n    ref_output = torch_grouped_gemm(X = Xref, W = W, m_sizes = expert_token_counts)\n\n    if permute_x:\n        X_test = X\n    else:\n        X_test = Xperm\n\n    # No need to run all configs for tests, otherwise takes too long\n    if autotune:\n        from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n\n        if num_autotune_configs is not None:\n            _autotuned_grouped_gemm_forward_kernel.configs = (\n                _autotuned_grouped_gemm_forward_kernel.configs[:num_autotune_configs]\n            )\n\n    # Use autograd.Function interface\n    if use_autograd:\n        from grouped_gemm.interface import grouped_gemm\n\n        kernel_config_fwd = KernelConfigForward(\n            BLOCK_SIZE_M = BLOCK_SIZE_M,\n            BLOCK_SIZE_N = BLOCK_SIZE_N,\n            BLOCK_SIZE_K = BLOCK_SIZE_K,\n            num_warps = num_warps,\n            num_stages = num_stages,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            fuse_mul_post = fuse_mul_post,\n            use_tma_load_w = use_tma_load_w,\n            use_tma_load_x = use_tma_load_x,\n            use_tma_store = use_tma_store,\n        )\n\n        test_output = grouped_gemm(\n            X = X_test,\n            W = W,\n            topk = topk,\n            m_sizes = expert_token_counts,\n            gather_indices = gather_indices,\n            topk_weights = topk_weights if fuse_mul_post else None,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            fuse_mul_post = fuse_mul_post,\n            kernel_config_fwd = kernel_config_fwd,\n            autotune = autotune,\n            is_first_gemm = use_W1,\n        )\n    # Use manual interface\n    else:\n        test_output = grouped_gemm_forward(\n            X = X_test,\n            W = W,\n            topk = topk,\n            m_sizes = expert_token_counts,\n            gather_indices = gather_indices,\n            topk_weights = topk_weights if fuse_mul_post else None,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            fuse_mul_post = fuse_mul_post,\n            use_tma_load_w = use_tma_load_w,\n            use_tma_load_x = use_tma_load_x,\n            use_tma_store = use_tma_store,\n            autotune = autotune,\n            BLOCK_SIZE_M = BLOCK_SIZE_M,\n            BLOCK_SIZE_N = BLOCK_SIZE_N,\n            BLOCK_SIZE_K = BLOCK_SIZE_K,\n            num_warps = num_warps,\n            num_stages = num_stages,\n            flatten = flatten,\n        )\n    assert ref_output.shape == output_shape\n    assert test_output.shape == output_shape\n\n    if permute_y:\n        ref_output = unpermute(ref_output, gather_indices)\n    if fuse_mul_post:\n        # if we don't permute_y, then test output is permuted with topk weights applied\n        # the ref output needs to be unpermuted before multiplying by topk weights since topk weights are in token order\n        if not permute_y:\n            ref_output = unpermute(ref_output, gather_indices)\n            test_output = unpermute(test_output, gather_indices)\n        ref_output = ref_output * topk_weights[:, None]\n\n    assert torch.allclose(\n        ref_output, test_output, atol = atol, rtol = rtol\n    ), f\"Grouped gemm forward failed: {(ref_output - test_output).abs().max().item():.6f}\"\n\n\n# NOTE: Fuse multiplication of topk 
weights is only supported for inference and not training, although this may change in the future; not currently tested.\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_FWD,\n    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_forward_manual(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfigForward,\n    use_W1: bool,\n):\n    _test_grouped_gemm_forward(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_FWD,\n    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_forward_manual_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfigForward,\n    use_W1: bool,\n):\n    _test_grouped_gemm_forward(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        use_autograd = True,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [10], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_forward_autotune(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    _test_grouped_gemm_forward(\n        data_config = data_config,\n        model_config = model_config,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        use_W1 = use_W1,\n        num_autotune_configs = num_autotune_configs,\n        autotune = True,\n        use_autograd = False,\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [10], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    
\"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_forward_autotune_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    _test_grouped_gemm_forward(\n        data_config = data_config,\n        model_config = model_config,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        use_W1 = use_W1,\n        num_autotune_configs = num_autotune_configs,\n        autotune = True,\n        use_autograd = True,\n    )\n\n\n\"\"\"\ngrouped_gemm_backward_dX\n\nuse_W1 test the shapes for the first grouped GEMM in a fused MoE MLP\nuse_W2 = `not use_W1` tests the shapes for the second grouped GEMM in a fused MoE MLP\n\nOnly certain combinations of permute_x, permute_y, and fuse_mul are supported.\n\nTypically in a fused MoE MLP, we can fuse the permutation of hidden states (X) from token order to expert grouped order needed for grouped GEMM by directly loading X in permuted order rather than launching a separate permutation kernel.\nWe can also fuse the unpermutation of tokens after the second grouped GEMM to restore to original token order.  
This is fused into the second grouped GEMM by directly storing the output in unpermuted order.\n\nHence the following conditions:\n- If use_W1 there are three cases:\n    - permute_x is False and topk > 1:\n        - dX_test is still in permuted order and has shape (total_tokens, K)\n        - it needs to be unpermuted and summed across topk before comparing to ref_grad\n    - permute_x is True:\n        - dX_test is already unpermuted and summed across topk with shape (num_tokens, K)\n        - no further processing is needed\n    - permute_x is False and topk == 1:\n        - dX_test needs to be unpermuted, no need to sum since topk == 1\n\n- If use_W2:\n    - permute_x is always False\n    - if permute_y:\n        - grad_output needs to be unpermuted before passing to grouped_gemm_dX\n        - dX_test is permuted and has shape (total_tokens, N)\n        - it needs to be unpermuted before comparing to ref_grad or can be compared directly to Xperm.grad\n    - if not permute_y:\n        - dX_test is not permuted and has shape (total_tokens, N)\n        - no further processing is needed\n\"\"\"\n\n\ndef _test_grouped_gemm_backward_dX(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool = False,\n    permute_y: bool = False,\n    use_tma_load_dy: bool = False,\n    use_tma_load_w: bool = False,\n    use_tma_store: bool = False,\n    use_W1: bool = True,\n    autotune: bool = False,\n    num_autotune_configs: int = None,\n    BLOCK_SIZE_M: int = None,\n    BLOCK_SIZE_N: int = None,\n    BLOCK_SIZE_K: int = None,\n    num_warps: int = None,\n    num_stages: int = None,\n    flatten: bool = True,\n    allow_tma_store: bool = False,\n    use_autograd: bool = False,\n    fuse_mul_post: bool = False,\n):\n    if not check_valid_config(permute_x, permute_y, use_W1 = use_W1, is_backward = True):\n        pytest.skip(\n            f\"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }\"\n        )\n\n    if use_tma_store and not allow_tma_store:\n        pytest.skip(\"TMA store needs to be debugged due to non-deterministic behavior\")\n\n    if (\n        autotune\n        and model_config.intermediate_size <= 128\n        and model_config.hidden_size <= 128\n    ):\n        pytest.skip(\"Skipping autotuning for small model configs\")\n\n    # Prevent OOM for large intermediate sizes\n    if model_config.intermediate_size > 2048:\n        model_config.intermediate_size = 1024\n    if model_config.hidden_size > 2048:\n        model_config.hidden_size = 1024\n\n    use_W2 = not use_W1\n    X1, X2, W1, W2, gating_output = make_inputs(\n        M = data_config.bs * data_config.seq_len,\n        N = model_config.intermediate_size,\n        K = model_config.hidden_size,\n        E = model_config.num_experts,\n        topk = model_config.topk,\n        dtype = data_config.dtype,\n        requires_grad = True,\n    )\n    topk = model_config.topk\n    num_experts = model_config.num_experts\n    use_sigmoid = model_config.use_sigmoid\n    renormalize = model_config.renormalize\n\n    X = X1 if use_W1 else X2\n    num_tokens = data_config.bs * data_config.seq_len\n    total_tokens = num_tokens * topk\n\n    E, K, N = W2.shape  # E = num_experts, K = hidden_size, N = intermediate_size\n    assert W1.shape == (E, 2 * N, K)\n    W = W1 if use_W1 else W2\n\n    if use_W1:\n        assert X.shape == (\n            num_tokens,\n            K,\n        ), f\"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}\"\n    else:\n        assert X.shape == (\n            total_tokens,\n            
N,\n        ), f\"X.shape: {X.shape}, total_tokens: {total_tokens}, N: {N}\"\n\n    W_test = W.detach().clone().requires_grad_(True)\n\n    topk_weights, topk_ids = calculate_topk(\n        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize\n    )\n    topk_weights = topk_weights.view(-1)  # num_tokens * topk\n    topk_ids = topk_ids.view(-1)  # num_tokens * topk\n\n    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)\n    assert len(gather_indices) == total_tokens\n    assert len(expert_token_counts) == num_experts\n\n    atol, rtol = TOLERANCE[X.dtype]\n    Xperm = permute(X, gather_indices, topk)\n\n    # Need to retain grad otherwise grad is not propagated\n    X.retain_grad()\n    W.retain_grad()\n    Xperm.retain_grad()\n\n    assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)\n\n    output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)\n    ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)\n    assert (\n        ref_output.shape == output_shape\n    ), f\"ref_output.shape: {ref_output.shape}, output_shape: {output_shape}\"\n\n    if permute_y:\n        ref_output = unpermute(ref_output, gather_indices)\n\n    grad_output = torch.randn_like(ref_output)\n    ref_output.backward(grad_output)\n\n    assert X.grad is not None\n    assert W.grad is not None\n\n    ref_grad = Xperm.grad\n\n    if autotune:\n        # No need to run all configs for autotuning\n        from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dX_kernel\n\n        if num_autotune_configs is not None:\n            _autotuned_grouped_gemm_dX_kernel.configs = (\n                _autotuned_grouped_gemm_dX_kernel.configs[:num_autotune_configs]\n            )\n\n    if use_autograd:\n        from grouped_gemm.interface import grouped_gemm\n\n        if not autotune:\n            kernel_config_fwd = KernelConfigForward()\n            kernel_config_bwd_dX = KernelConfigBackward_dX(\n                use_tma_load_dy = use_tma_load_dy,\n                use_tma_load_w = use_tma_load_w,\n                use_tma_store = use_tma_store,\n                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                num_warps = num_warps,\n                num_stages = num_stages,\n            )\n            kernel_config_bwd_dW = KernelConfigBackward_dW()\n        else:\n            from grouped_gemm.kernels.backward import (\n                _autotuned_grouped_gemm_dW_kernel,\n                _autotuned_grouped_gemm_dX_kernel,\n            )\n            from grouped_gemm.kernels.forward import (\n                _autotuned_grouped_gemm_forward_kernel,\n            )\n\n            if num_autotune_configs is not None:\n                _autotuned_grouped_gemm_dX_kernel.configs = (\n                    _autotuned_grouped_gemm_dX_kernel.configs[:num_autotune_configs]\n                )\n                _autotuned_grouped_gemm_forward_kernel.configs = (\n                    _autotuned_grouped_gemm_forward_kernel.configs[\n                        :num_autotune_configs\n                    ]\n                )\n\n            kernel_config_fwd = None\n            kernel_config_bwd_dX = None\n        X_ = (\n            X.detach().clone().requires_grad_(True)\n            if permute_x\n            else Xperm.detach().clone().requires_grad_(True)\n        )\n        test_output = grouped_gemm(\n            X = X_,\n            
W = W_test,\n            m_sizes = expert_token_counts,\n            gather_indices = gather_indices,\n            topk = topk,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            autotune = autotune,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dX = kernel_config_bwd_dX,\n            is_first_gemm = use_W1,\n            dX_only = True,\n        )\n        assert (\n            test_output.shape == ref_output.shape\n        ), f\"test_output.shape: {test_output.shape}, ref_output.shape: {ref_output.shape}\"\n        assert torch.allclose(\n            test_output, ref_output, atol = atol, rtol = rtol\n        ), f\"Grouped gemm backward_dX forward outputs mismatch: {(test_output - ref_output).abs().max().item():.6f}\"\n        test_output.backward(grad_output)\n        assert X_.grad is not None\n\n        # NOTE: need to handle grad differently in this case due to errors arising from how torch autograd handles unpermute and sum reduction\n        # the grad of Xperm unpermuted and reduced across topk should match X_.grad\n        # However, both will have a numerical difference with that of ref_grad\n        # This is due to the fact that torch autograd handles unpermute and sum reduction differently, see: https://discuss.pytorch.org/t/permute-unpermute-gradient/219557\n        if permute_x and use_W1:\n            X_grad_unperm = unpermute(Xperm.grad, gather_indices)\n            manual_grad_check = X_grad_unperm.view(num_tokens, topk, K).sum(dim = 1)\n            assert (\n                manual_grad_check.shape == X_.grad.shape\n            ), f\"manual_grad_check.shape: {manual_grad_check.shape}, X_.grad.shape: {X_.grad.shape}\"\n            assert torch.allclose(\n                manual_grad_check, X_.grad, atol = atol, rtol = rtol\n            ), f\"Grouped gemm backward_dX forward outputs mismatch: {(manual_grad_check - X_.grad).abs().max().item():.6f}\"\n            manual_diff = (X_.grad - manual_grad_check).abs().max().item()\n            autograd_diff = (X_.grad - X.grad).abs().max().item()\n            print(f\"manual_diff: {manual_diff:.6f}, autograd_diff: {autograd_diff:.6f}\")\n        else:\n            assert torch.allclose(\n                X_.grad, ref_grad, atol = atol, rtol = rtol\n            ), f\"Grouped gemm backward_dX forward outputs mismatch: {(X_.grad - ref_grad).abs().max().item():.6f}\"\n        return\n    else:\n        dX_test = grouped_gemm_dX(\n            dY = grad_output,\n            W = W_test,\n            gather_indices = gather_indices,\n            m_sizes = expert_token_counts,\n            topk = topk,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            use_tma_load_w = use_tma_load_w,\n            use_tma_load_dy = use_tma_load_dy,\n            use_tma_store = use_tma_store,\n            autotune = autotune,\n            BLOCK_SIZE_M = BLOCK_SIZE_M,\n            BLOCK_SIZE_N = BLOCK_SIZE_N,\n            BLOCK_SIZE_K = BLOCK_SIZE_K,\n            num_warps = num_warps,\n            num_stages = num_stages,\n            flatten = flatten,\n            # debug=True,\n        )\n\n    # if permute_x and use_W1 (first grouped GEMM) then the kernel should have unpermuted the dX\n    # therefore we need to unpermute the ref_grad to compare to the output of the kernel\n    if permute_x and use_W1:\n        ref_grad = unpermute(ref_grad, gather_indices)\n\n    assert (\n        ref_grad.shape == dX_test.shape\n    ), f\"Grouped gemm manual 
backward_dX outputs mismatch: ref_grad: {ref_grad.shape}, dX_test: {dX_test.shape}\"\n    diff = (ref_grad - dX_test).abs().max().item()\n\n    assert torch.allclose(\n        ref_grad, dX_test, atol = atol, rtol = rtol\n    ), f\"Grouped gemm manual backward_dX outputs mismatch: {diff:.6f}\"\n\n    if permute_x and use_W1:\n        # Show that reduction results in diffs\n        # First calculate X.grad manually by backpropping through unpermuted ref_grad\n        dX_ref_check = ref_grad.view(num_tokens, topk, K).sum(dim = 1)\n        # Do the same for the actual output of the kernel\n        dX_test_check = dX_test.view(num_tokens, topk, K).sum(dim = 1)\n        # Show diffs for each combination\n        diff_ref_check = (X.grad - dX_ref_check).abs().max().item()\n        diff_test_check = (X.grad - dX_test_check).abs().max().item()\n        diff_check_test = (dX_ref_check - dX_test_check).abs().max().item()\n        print(\n            f\"diff_ref_check: {diff_ref_check:.6f}, diff_test_check: {diff_test_check:.6f}, diff_check_test: {diff_check_test:.6f}\"\n        )\n\n\n# NOTE: We reduce the size of the Llama4 model configs to prevent OOM\n# Important to note that for the full model size (5120, 8192), the tests do result in diffs on the order of 1e-2.\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_BWD_dX,\n    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dX_manual(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfigBackward_dX,\n    use_W1: bool,\n):\n    _test_grouped_gemm_backward_dX(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        use_autograd = False,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_BWD_dX,\n    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dX_manual_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfigBackward_dX,\n    use_W1: bool,\n):\n    _test_grouped_gemm_backward_dX(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        use_autograd = True,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [20], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else 
\"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dX_autotune(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    # TMA loads / stores will be autotuned\n    _test_grouped_gemm_backward_dX(\n        data_config = data_config,\n        model_config = model_config,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        use_W1 = use_W1,\n        autotune = True,\n        use_autograd = False,\n        num_autotune_configs = num_autotune_configs,\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [20], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dX_autotune_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    # TMA loads / stores will be autotuned\n    _test_grouped_gemm_backward_dX(\n        data_config = data_config,\n        model_config = model_config,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        use_W1 = use_W1,\n        autotune = True,\n        use_autograd = True,\n        num_autotune_configs = num_autotune_configs,\n    )\n\n\ndef _test_grouped_gemm_backward_dW(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    use_tma_load_dy: bool = False,\n    use_tma_load_x: bool = False,\n    use_tma_store: bool = False,\n    BLOCK_SIZE_M: int = None,\n    BLOCK_SIZE_N: int = None,\n    BLOCK_SIZE_K: int = None,\n    num_warps: int = None,\n    num_stages: int = None,\n    flatten: bool = True,\n    autotune: bool = False,\n    num_autotune_configs: int = None,\n    allow_tma_store: bool = False,\n    debug: bool = False,\n    fuse_mul_post: bool = False,  # Unused for backward_dW\n    use_autograd: bool = False,\n):\n    if not check_valid_config(\n        permute_x,\n        permute_y,\n        fuse_mul_post = fuse_mul_post,\n        use_W1 = use_W1,\n        is_backward = True,\n    ):\n        pytest.skip(\n            f\"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }\"\n        )\n\n    if use_tma_store and not allow_tma_store:\n        pytest.skip(\"TMA store 
needs to be debugged due to non-deterministic behavior\")\n\n    X1, X2, W1, W2, gating_output = make_inputs(\n        M = data_config.bs * data_config.seq_len,\n        N = model_config.intermediate_size,\n        K = model_config.hidden_size,\n        E = model_config.num_experts,\n        topk = model_config.topk,\n        dtype = data_config.dtype,\n        requires_grad = True,\n    )\n    topk = model_config.topk\n    num_experts = model_config.num_experts\n    use_sigmoid = model_config.use_sigmoid\n    renormalize = model_config.renormalize\n\n    X = X1 if use_W1 else X2\n    num_tokens = data_config.bs * data_config.seq_len\n    E, K, N = W2.shape  # E = num_experts, K = hidden_size, N = intermediate_size\n    assert W1.shape == (E, 2 * N, K)\n    W = W1 if use_W1 else W2\n\n    if use_W1:\n        assert X.shape == (\n            num_tokens,\n            K,\n        ), f\"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}\"\n    else:\n        assert X.shape == (\n            num_tokens * topk,\n            N,\n        ), f\"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}\"\n\n    total_tokens = num_tokens * topk\n    output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)\n\n    X_test = X.detach().clone().requires_grad_(True)\n    W_test = W.detach().clone().requires_grad_(True)\n\n    topk_weights, topk_ids = calculate_topk(\n        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize\n    )\n    topk_weights = topk_weights.view(-1)  # num_tokens * topk\n    topk_ids = topk_ids.view(-1)  # num_tokens * topk\n\n    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)\n    assert len(gather_indices) == total_tokens\n    assert len(expert_token_counts) == num_experts\n\n    atol, rtol = TOLERANCE[X.dtype]\n    Xperm = permute(X, gather_indices, topk)\n    Xperm_test = Xperm.detach().clone().requires_grad_(True)\n\n    # Need to retain grad otherwise grad is not propagated\n    X.retain_grad()\n    W.retain_grad()\n    Xperm.retain_grad()\n    assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)\n\n    output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)\n\n    ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)\n    assert ref_output.shape == output_shape\n\n    # if permute_y then the assumption is that the output of grouped_gemm was unpermuted on store\n    # Therefore we have to unpermute before backpropping to ensure proper alignment\n    if permute_y:\n        ref_output = unpermute(ref_output, gather_indices)\n\n    grad_output = torch.randn_like(ref_output)\n    ref_output.backward(grad_output)\n    assert X.grad is not None\n    assert W.grad is not None\n\n    # Test backward kernel directly\n    X_ = X_test if permute_x else Xperm_test\n\n    if debug:\n        torch.set_printoptions(precision = 4)\n        for i in range(num_experts):\n            print(f\"Expert {i} weight grad:\\n{W.grad[i, :5, :5]}\")\n\n    if autotune:\n        from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel\n\n        if num_autotune_configs is not None:\n            _autotuned_grouped_gemm_dW_kernel.configs = (\n                _autotuned_grouped_gemm_dW_kernel.configs[:num_autotune_configs]\n            )\n\n    if use_autograd:\n        from grouped_gemm.interface import grouped_gemm\n\n        if not autotune:\n            kernel_config_fwd = KernelConfigForward(\n                # Only care about backward_dW 
config\n                use_tma_load_w = False,\n                use_tma_load_x = False,\n                use_tma_store = False,\n                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                num_warps = num_warps,\n                num_stages = num_stages,\n            )\n            kernel_config_bwd_dW = KernelConfigBackward_dW(\n                use_tma_load_dy = use_tma_load_dy,\n                use_tma_load_x = use_tma_load_x,\n                use_tma_store = use_tma_store,\n                BLOCK_SIZE_M = BLOCK_SIZE_M,\n                BLOCK_SIZE_N = BLOCK_SIZE_N,\n                BLOCK_SIZE_K = BLOCK_SIZE_K,\n                num_warps = num_warps,\n                num_stages = num_stages,\n            )\n        else:\n            from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel\n            from grouped_gemm.kernels.forward import (\n                _autotuned_grouped_gemm_forward_kernel,\n            )\n\n            if num_autotune_configs is not None:\n                _autotuned_grouped_gemm_forward_kernel.configs = (\n                    _autotuned_grouped_gemm_forward_kernel.configs[\n                        :num_autotune_configs\n                    ]\n                )\n                _autotuned_grouped_gemm_dW_kernel.configs = (\n                    _autotuned_grouped_gemm_dW_kernel.configs[:num_autotune_configs]\n                )\n            kernel_config_fwd = None\n            kernel_config_bwd_dW = None\n\n        test_output = grouped_gemm(\n            X = X_,\n            W = W_test,\n            m_sizes = expert_token_counts,\n            gather_indices = gather_indices,\n            topk = topk,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            kernel_config_fwd = kernel_config_fwd,\n            kernel_config_bwd_dW = kernel_config_bwd_dW,\n            autotune = autotune,\n            is_first_gemm = use_W1,\n            dW_only = True,\n        )\n        assert (\n            test_output.shape == ref_output.shape\n        ), f\"Grouped gemm autograd backward_dW outputs mismatch: {test_output.shape} != {ref_output.shape}\"\n        assert torch.allclose(\n            test_output, ref_output, atol = atol, rtol = rtol\n        ), f\"Grouped gemm autograd backward_dW forward outputs mismatch: {test_output.shape} != {ref_output.shape}\"\n        test_output.backward(grad_output)\n        assert W_test.grad is not None\n        dW_test = W_test.grad\n    else:\n        dW_test = grouped_gemm_dW(\n            dY = grad_output,\n            X = X_,\n            m_sizes = expert_token_counts,\n            gather_indices = gather_indices,\n            topk = topk,\n            permute_x = permute_x,\n            permute_y = permute_y,\n            use_tma_load_dy = use_tma_load_dy,\n            use_tma_load_x = use_tma_load_x,\n            use_tma_store = use_tma_store,\n            BLOCK_SIZE_M = BLOCK_SIZE_M,\n            BLOCK_SIZE_N = BLOCK_SIZE_N,\n            BLOCK_SIZE_K = BLOCK_SIZE_K,\n            num_warps = num_warps,\n            num_stages = num_stages,\n            flatten = flatten,\n            autotune = autotune,\n            debug = debug,\n        )\n    assert (\n        W.grad.shape == dW_test.shape\n    ), f\"Grouped gemm manual backward_dW outputs mismatch: W.grad: {W.grad.shape}, dW_test: {dW_test.shape}\"\n\n    if debug:\n        with torch.no_grad():\n            if not torch.allclose(W.grad, 
dW_test, atol = atol, rtol = rtol):\n                print(f\"Ref Wgrad sum: {W.grad.sum().item():.4f}\")\n            print(f\"Test Wgrad sum: {dW_test.sum().item():.4f}\")\n\n            for i in range(num_experts):\n                print(f\"Expert {i} weight grad:\\n{W.grad[i, :5, :5]}\")\n                print(f\"Expert {i} dW_test:\\n{dW_test[i, :5, :5]}\")\n                expert_diff = (W.grad[i, :, :] - dW_test[i, :, :]).abs().max().item()\n                print(f\"Expert {i} diff: {expert_diff:.6f}\")\n\n            diff = (W.grad - dW_test).abs().max().item()\n            assert (\n                False\n            ), f\"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}\"\n    else:\n        diff = (W.grad - dW_test).abs().max().item()\n        assert torch.allclose(\n            W.grad, dW_test, atol = atol, rtol = rtol\n        ), f\"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}\"\n\n\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_BWD_dW,\n    ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dW_manual(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfig,\n    use_W1: bool,\n    debug: bool = False,\n):\n    _test_grouped_gemm_backward_dW(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        use_autograd = False,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"kernel_config\",\n    KERNEL_CONFIGS_BWD_dW,\n    ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dW_manual_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    kernel_config: KernelConfig,\n    use_W1: bool,\n    debug: bool = False,\n):\n    _test_grouped_gemm_backward_dW(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        use_autograd = True,\n        **asdict(kernel_config),\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [20], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} 
intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dW_autotune(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    _test_grouped_gemm_backward_dW(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        autotune = True,\n        use_autograd = False,\n        num_autotune_configs = num_autotune_configs,\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_autotune_configs\", [20], ids = lambda x: f\"num_autotune_configs={x}\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True, False], ids = lambda x: \"permute_x\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [True, False], ids = lambda x: \"permute_y\" if x else \"\"\n)\n@pytest.mark.parametrize(\n    \"model_config\",\n    [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],\n    ids = lambda x: f\"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}\",\n)\n@pytest.mark.parametrize(\n    \"data_config\", DATA_CONFIGS, ids = lambda x: f\"seq_len={x.seq_len} dtype={x.dtype}\"\n)\n@pytest.mark.parametrize(\"use_W1\", [True, False], ids = lambda x: f\"use_W1={x}\")\ndef test_grouped_gemm_backward_dW_autotune_autograd(\n    data_config: DataConfig,\n    model_config: ModelConfig,\n    permute_x: bool,\n    permute_y: bool,\n    use_W1: bool,\n    num_autotune_configs: int,\n):\n    _test_grouped_gemm_backward_dW(\n        data_config = data_config,\n        model_config = model_config,\n        use_W1 = use_W1,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        autotune = True,\n        use_autograd = True,\n        num_autotune_configs = num_autotune_configs,\n    )\n"
  },
  {
    "path": "unsloth/kernels/moe/tests/test_llama4_moe.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport argparse\nimport sys\nfrom contextlib import contextmanager\nfrom functools import partial\n\nimport pytest\nimport torch\nfrom transformers import AutoConfig\nfrom transformers.models.llama4 import Llama4Config, Llama4TextConfig\nfrom transformers.models.llama4.modeling_llama4 import Llama4TextMoe\n\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom grouped_gemm.reference.layers.llama4_moe import (\n    Llama4GroupedGemmTextMoe,\n    Llama4TritonTextMoe,\n)\n\nTOLERANCES = {\n    torch.bfloat16: (1e-2, 1e-2),\n    torch.float16: (1e-3, 1e-3),\n    torch.float: (1e-5, 1e-5),\n}\n\nLLAMA4_SCOUT_ID = \"meta-llama/Llama-4-Scout-17B-16E\"\nSEED = 42\nSEQ_LENS = [1024]\nDTYPES = [torch.bfloat16]\n# Reduce the number of autotuning configs to prevent excessive runtime\nNUM_AUTOTUNE_CONFIGS = 50\n\n\n@contextmanager\ndef annotated_context(prelude, epilogue = \"Passed!\", char = \"-\", num_chars = 80):\n    print(char * num_chars)\n    print(prelude)\n    yield\n    print(epilogue)\n    print(char * num_chars)\n\n\ndef get_text_config(model_id):\n    config: Llama4Config = AutoConfig.from_pretrained(model_id)\n    return config.text_config\n\n\ndef prep_triton_kernel_traits(autotune):\n    if not autotune:\n        kernel_config_fwd = KernelConfigForward()\n        kernel_config_bwd_dW = KernelConfigBackward_dW()\n        kernel_config_bwd_dX = KernelConfigBackward_dX()\n    else:\n        from grouped_gemm.kernels.backward import (\n            _autotuned_grouped_gemm_dW_kernel,\n            _autotuned_grouped_gemm_dX_kernel,\n        )\n        from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n\n        # Hack to reduce number of autotuning configs\n        _autotuned_grouped_gemm_forward_kernel.configs = (\n            _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n        _autotuned_grouped_gemm_dW_kernel.configs = (\n            _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n        _autotuned_grouped_gemm_dX_kernel.configs = (\n            _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n\n        kernel_config_fwd = None\n        kernel_config_bwd_dW = None\n        kernel_config_bwd_dX = None\n\n    return kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX\n\n\ndef sparse_to_dense(t: torch.Tensor):\n    t = t.sum(dim = 0).view(-1)\n    return t\n\n\n@torch.no_grad()\ndef _check_diff(\n    t1: torch.Tensor,\n    t2: torch.Tensor,\n    atol,\n    rtol,\n    precision = \".6f\",\n    verbose = False,\n    msg = \"\",\n):\n    t2 = t2.view_as(t1)\n    diff = t1.sub(t2).abs().max().item()\n    if verbose:\n        if msg == \"\":\n            msg = \"diff\"\n        print(f\"{msg}: {diff:{precision}}\")\n    assert torch.allclose(t1, t2, atol = atol, rtol = rtol)\n\n\ndef run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):\n    y.backward(grad_output)\n    for name, param in module.named_parameters():\n        assert param.grad is not None, f\"{name} missing grad!\"\n\n\ndef _check_grads(\n    m1: torch.nn.Module,\n    m2: torch.nn.Module,\n    atol,\n    rtol,\n    precision = \".6f\",\n    verbose = False,\n    msg = \"\",\n):\n    for name, param in m1.named_parameters():\n        
_check_diff(\n            param.grad,\n            m2.get_parameter(name).grad,\n            atol = atol,\n            rtol = rtol,\n            precision = precision,\n            verbose = verbose,\n            msg = f\"{msg}:{name}.grad\",\n        )\n\n\n@pytest.fixture\ndef model_config():\n    return AutoConfig.from_pretrained(LLAMA4_SCOUT_ID).text_config\n\n\n@pytest.mark.parametrize(\n    \"overlap_router_shared\",\n    [False, True],\n    ids = lambda x: \"overlap_router_shared\" if x else \"no_overlap\",\n)\n@pytest.mark.parametrize(\n    \"permute_y\", [False, True], ids = lambda x: \"permute_y\" if x else \"no_permute_y\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [False], ids = lambda x: \"permute_x\" if x else \"no_permute_x\"\n)  # Llama4 does not support permute_x\n@pytest.mark.parametrize(\n    \"autotune\", [True], ids = lambda x: \"autotune\" if x else \"manual\"\n)\n@pytest.mark.parametrize(\"seqlen\", SEQ_LENS, ids = lambda x: f\"seqlen={x}\")\n@pytest.mark.parametrize(\"dtype\", DTYPES, ids = str)\ndef test_llama4_ref(\n    dtype: torch.dtype,\n    seqlen,\n    autotune: bool,\n    permute_x: bool,\n    permute_y: bool,\n    overlap_router_shared: bool,\n    model_config: Llama4TextConfig,  # test fixture\n    bs: int = 1,\n    device = \"cuda\",\n    precision = \".6f\",\n    verbose = False,\n):\n    torch.manual_seed(\n        SEED\n    )  # Should not be needed when running using pytest -- autouse fixture in conftest.py\n    device = \"cuda\"\n    hidden_dim = model_config.hidden_size\n    atol, rtol = TOLERANCES[dtype]\n    check_diff = partial(\n        _check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose\n    )\n    check_grads = partial(\n        _check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose\n    )\n\n    # Reference op -- HF\n    llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)\n\n    # Torch grouped gemm impl\n    llama4_gg_ref = Llama4GroupedGemmTextMoe(\n        model_config, overlap_router_shared = overlap_router_shared\n    ).to(dtype = dtype, device = device)\n    llama4_gg_ref.copy_weights(llama4_ref)\n    llama4_gg_ref.check_weights(llama4_ref)\n\n    x_ref = torch.randn(\n        bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True\n    )\n    x_torch_gg = x_ref.detach().clone().requires_grad_()\n    x_triton = x_ref.detach().clone().requires_grad_()\n\n    y_ref, routing_ref = llama4_ref(x_ref)\n    y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)\n    assert y_ref.shape == y_torch_gg.shape, f\"{y_ref.shape} != {y_torch_gg.shape}\"\n    with annotated_context(\"Testing torch grouped gemm Llama4TextMoe\"):\n        check_diff(y_ref, y_torch_gg, msg = \"y_torch_gg\")\n        check_diff(\n            sparse_to_dense(routing_ref), routing_torch_gg, msg = \"routing_torch_gg\"\n        )\n\n    kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (\n        prep_triton_kernel_traits(autotune)\n    )\n\n    llama4_triton = Llama4TritonTextMoe(\n        model_config,\n        overlap_router_shared = overlap_router_shared,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        autotune = autotune,\n        kernel_config_fwd = kernel_config_fwd,\n        kernel_config_bwd_dW = kernel_config_bwd_dW,\n        kernel_config_bwd_dX = kernel_config_bwd_dX,\n    ).to(device = device, dtype = dtype)\n    llama4_triton.copy_weights(llama4_ref)\n    llama4_triton.check_weights(llama4_ref)\n\n    y_triton, 
routing_triton = llama4_triton(x_triton)\n    with annotated_context(\"Testing triton grouped gemm Llama4TextMoe forward\"):\n        check_diff(y_ref, y_triton, msg = \"y_triton\")\n        check_diff(sparse_to_dense(routing_ref), routing_triton, msg = \"routing_triton\")\n\n    ref_grad = torch.randn_like(y_ref)\n    run_backwards(y_ref, ref_grad, llama4_ref)\n    run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)\n    with annotated_context(\"Testing torch group gemm Llama4TextMoe backward\"):\n        check_grads(llama4_ref, llama4_gg_ref, msg = \"torch_gg\")\n\n    run_backwards(y_triton, ref_grad, llama4_triton)\n    with annotated_context(\"Testing triton group gemm Llama4TextMoe backward\"):\n        check_grads(llama4_ref, llama4_triton, msg = \"triton\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--seqlen\", type = int, default = 1024)\n    parser.add_argument(\n        \"--dtype\", type = str, choices = [\"bfloat16\", \"float16\"], default = \"bfloat16\"\n    )\n    args = parser.parse_args()\n    args.dtype = getattr(torch, args.dtype)\n    args_dict = vars(args)\n\n    model_id = LLAMA4_SCOUT_ID\n\n    text_config: Llama4TextConfig = get_text_config(model_id)\n    for overlap in [False, True]:\n        test_llama4_ref(\n            seqlen = args.seqlen,\n            model_config = text_config,\n            dtype = args.dtype,\n            autotune = True,\n            permute_x = False,\n            permute_y = True,\n            overlap_router_shared = overlap,\n            verbose = True,\n        )\n"
  },
  {
    "path": "unsloth/kernels/moe/tests/test_qwen3_moe.py",
    "content": "# SPDX-License-Identifier: GNU Affero General Public License v3.0\n# Copyright 2023-present the Unsloth team. All rights reserved.\n\nimport argparse\nfrom contextlib import contextmanager\n\nimport pytest\nimport torch\nfrom transformers import AutoConfig\nfrom transformers.models.qwen3_moe import Qwen3MoeConfig\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\n\nfrom grouped_gemm.kernels.tuning import (\n    KernelConfigBackward_dW,\n    KernelConfigBackward_dX,\n    KernelConfigForward,\n)\nfrom grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeGroupedGEMMBlock\n\nfrom .moe_utils import (\n    Qwen3MoeFusedGroupedGEMMBlock,\n    check_fwd,\n    check_grads,\n    check_grouped_gemm_results,\n    run_backward,\n    run_forward,\n)\n\n\"\"\"\nQwen3 MoE tests\n\nNOTE: Test this as a module and NOT with pytest as running with pytest results in random numerical errors: python -m tests.test_qwen3_moe --permute_x --permute_y --autotune NOT pytest -sv tests/test_qwen3_moe.py\nMore specifically, all tests pass when run individually, but some will fail randomly (even with the same seed) when the entire test is run as a parametrized test suite using pytest, likely due to how pytest interacts with triton / autotuning.\n\nSee tests/run_qwen3_moe_tests.sh for a script that runs all the tests\n\nThe tests run the following:\nHuggingface's Qwen3 MoE block (Qwen3MoeSparseMoeBlock)\nTorch-native grouped gemm version of MoE block (Qwen3MoeGroupedGEMMBlock), which is the HF block with the expert computation replaced with a torch-native grouped gemm\nTriton kernel grouped gemm version of MoE block (Qwen3MoeFusedGroupedGEMMBlock), which is the HF block with the expert computation replaced with the fused triton grouped gemm kernel\n\nThe tests check the following:\n- HF MoE block vs torch grouped gemm MoE block (sanity check)\n- torch grouped gemm MoE block vs fused grouped gemm MoE block -- this allows us to test each of the intermediate results for easier debugging\n- HF MoE block vs fused grouped gemm MoE block -- this is the actual test\n\nBoth forward and backward passes are tests:\n- forward: output of the moe block\n- backwards:\n    - X: gradient of the input to the moe block\n    - gate.weight: gradient of the gate weights (router weights)\n    - gate_proj: gradient of concatenated gate projections\n    - up_proj: gradient of the concatenated up projections\n    - down_proj: gradient of the concatenated down projections\n\nAdditionally, for the torch grouped gemm and triton grouped gemm versions, the intermediate outputs of the forward pass are checked:\n- first_gemm: output of the first grouped gemm (X @ fused_gate_proj)\n- intermediate: output of silu_mul(first_gemm)\n- second_gemm: output of the second grouped gemm (intermediate @ down_proj)\n- hidden_states_unpermute: output of the second_gemm after unpermuting back to token order (from expert grouped order); in the case where the permutation is fused in the triton kernel, this is the same as second_gemm\n- hidden_states: output with the topk_weights applied\n\"\"\"\n\nTOLERANCES = {\n    torch.bfloat16: (1e-2, 1e-2),\n    torch.float16: (1e-3, 1e-3),\n    torch.float: (1e-5, 1e-5),\n}\n\n\n@pytest.fixture(scope = \"module\")\ndef model_id():\n    return \"Qwen/Qwen3-30B-A3B\"\n\n\n@pytest.fixture(scope = \"module\")\ndef config(model_id: str):\n    return AutoConfig.from_pretrained(model_id)\n\n\n@contextmanager\ndef annotated_context(prelude, epilogue = \"Passed!\", char = \"-\", num_chars 
= 80):\n    print(char * num_chars)\n    print(prelude)\n    yield\n    print(epilogue)\n    print(char * num_chars)\n\n\nSEED = 42\nSEQ_LENS = [1024]\nDTYPES = [torch.bfloat16]\n\n# Reduce the number of autotuning configs to prevent excessive runtime\nNUM_AUTOTUNE_CONFIGS = 50\n\n\n@pytest.mark.parametrize(\n    \"permute_y\", [True], ids = lambda x: \"permute_y\" if x else \"no_permute_y\"\n)\n@pytest.mark.parametrize(\n    \"permute_x\", [True], ids = lambda x: \"permute_x\" if x else \"no_permute_x\"\n)\n@pytest.mark.parametrize(\n    \"autotune\", [True], ids = lambda x: \"autotune\" if x else \"manual\"\n)\n@pytest.mark.parametrize(\"seqlen\", SEQ_LENS, ids = lambda x: f\"seqlen={x}\")\n@pytest.mark.parametrize(\"dtype\", DTYPES, ids = str)\ndef test_qwen3_moe(\n    config: Qwen3MoeConfig,\n    seqlen: int,\n    dtype: torch.dtype,\n    permute_x: bool,\n    permute_y: bool,\n    autotune: bool,\n):\n    torch.manual_seed(\n        SEED\n    )  # Should not be needed when running using pytest -- autouse fixture in conftest.py\n    device = \"cuda\"\n    hidden_size = config.hidden_size\n    bs = 1\n    atol, rtol = TOLERANCES[dtype]\n    # Reference op -- HF\n    moe_block = Qwen3MoeSparseMoeBlock(config).to(device, dtype)\n\n    # Torch-native grouped gemm version of MoE Block -- for sanity checking\n    grouped_gemm_block = Qwen3MoeGroupedGEMMBlock.from_hf(moe_block).to(device, dtype)\n    grouped_gemm_block.check_weights(moe_block)\n\n    if not autotune:\n        kernel_config_fwd = KernelConfigForward()\n        kernel_config_bwd_dW = KernelConfigBackward_dW()\n        kernel_config_bwd_dX = KernelConfigBackward_dX()\n    else:\n        from grouped_gemm.kernels.backward import (\n            _autotuned_grouped_gemm_dW_kernel,\n            _autotuned_grouped_gemm_dX_kernel,\n        )\n        from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel\n\n        # Hack to reduce number of autotuning configs\n        _autotuned_grouped_gemm_forward_kernel.configs = (\n            _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n        _autotuned_grouped_gemm_dW_kernel.configs = (\n            _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n        _autotuned_grouped_gemm_dX_kernel.configs = (\n            _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS]\n        )\n\n        kernel_config_fwd = None\n        kernel_config_bwd_dW = None\n        kernel_config_bwd_dX = None\n\n    # Triton kernel grouped gemm version of MoE Block -- this is what we're testing\n    fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(\n        moe_block,\n        permute_x = permute_x,\n        permute_y = permute_y,\n        autotune = autotune,\n        kernel_config_fwd = kernel_config_fwd,\n        kernel_config_bwd_dW = kernel_config_bwd_dW,\n        kernel_config_bwd_dX = kernel_config_bwd_dX,\n    ).to(device, dtype)\n    fused_gemm_block.check_weights(moe_block)\n\n    X = torch.randn(\n        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True\n    )\n\n    # Forward\n    ref_result = run_forward(moe_block, X, is_grouped_gemm = False)\n    grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)\n    fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)\n\n    with annotated_context(\n        \"Testing forward pass\",\n        epilogue = \"Passed forward tests!\",\n        char = \"=\",\n        num_chars = 100,\n    ):\n     
   # Sanity checks\n\n        with annotated_context(\n            \"Checking HF vs torch grouped gemm MoE forward outputs...\"\n        ):\n            check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)\n\n        with annotated_context(\n            \"Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs...\"\n        ):\n            # We implement a custom check for grouped gemm results to test each of the intermediate results for easier debugging\n            check_grouped_gemm_results(\n                grouped_result.grouped_gemm_result,\n                fused_result.grouped_gemm_result,\n                permute_y = permute_y,\n                atol = atol,\n                rtol = rtol,\n                verbose = False,\n            )\n        # Actual test\n        with annotated_context(\n            \"Checking HF vs fused grouped gemm MoE forward outputs...\"\n        ):\n            check_fwd(ref_result, fused_result, atol, rtol, verbose = True)\n\n    # Backward\n    grad_output = torch.randn_like(ref_result.output)\n    ref_backward_result = run_backward(\n        moe_block, grad_output, output = ref_result.output, X = ref_result.X\n    )\n    grouped_backward_result = run_backward(\n        grouped_gemm_block,\n        grad_output,\n        output = grouped_result.output,\n        X = grouped_result.X,\n    )\n    fused_backward_result = run_backward(\n        fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X\n    )\n\n    with annotated_context(\n        \"Testing backward pass\",\n        epilogue = \"Passed backward tests!\",\n        char = \"=\",\n        num_chars = 100,\n    ):\n        # Sanity checks\n        with annotated_context(\"Checking HF vs torch grouped gemm MoE grads...\"):\n            check_grads(\n                ref_backward_result, grouped_backward_result, atol, rtol, verbose = False\n            )\n        with annotated_context(\n            \"Checking torch grouped gemm MoE vs fused grouped gemm MoE grads...\"\n        ):\n            check_grads(\n                grouped_backward_result,\n                fused_backward_result,\n                atol,\n                rtol,\n                verbose = False,\n            )\n\n        # Actual test\n        with annotated_context(\"Checking HF vs fused grouped gemm MoE grads...\"):\n            check_grads(\n                ref_backward_result, fused_backward_result, atol, rtol, verbose = True\n            )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--seqlen\", type = int, default = 1024)\n    parser.add_argument(\n        \"--dtype\", type = str, choices = [\"bfloat16\", \"float16\"], default = \"bfloat16\"\n    )\n    parser.add_argument(\"--permute_x\", action = \"store_true\")\n    parser.add_argument(\"--permute_y\", action = \"store_true\")\n    parser.add_argument(\"--autotune\", action = \"store_true\")\n    args = parser.parse_args()\n    args.dtype = getattr(torch, args.dtype)\n    args_dict = vars(args)\n\n    model_id = \"Qwen/Qwen3-30B-A3B\"\n    config = AutoConfig.from_pretrained(model_id)\n    atol, rtol = TOLERANCES[args.dtype]\n\n    print(\n        f\"Testing {model_id} with seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, autotune={args.autotune}, atol={atol}, rtol={rtol}\"\n    )\n    test_qwen3_moe(config, **args_dict)\n"
  },
  {
    "path": "unsloth/kernels/rms_layernorm.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom .utils import calculate_settings, torch_gpu_device\n\n\n@triton.jit\ndef _rms_layernorm_forward(\n    Y,\n    Y_row_stride: tl.constexpr,\n    X,\n    X_row_stride: tl.constexpr,\n    W,\n    W_row_stride: tl.constexpr,\n    r,\n    r_row_stride: tl.constexpr,\n    n_cols: tl.constexpr,\n    eps: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Fast RMS Layernorm kernel\n    Inspiration from a Triton tutorial:\n    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n    \"\"\"\n    row_idx = tl.program_id(0)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n\n    Y += row_idx * Y_row_stride\n    X += row_idx * X_row_stride\n    r += row_idx * r_row_stride\n\n    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n    W_row = tl.load(W + col_offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n    # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm\n    eps_f32 = tl.full((), eps, tl.float32)\n    inv_var = tl.math.rsqrt(row_var + eps_f32)\n    tl.store(r, inv_var)\n    normed = X_row * inv_var\n    normed = normed.to(W_row.dtype)  # Exact copy from HF\n    output = normed * W_row\n    tl.store(Y + col_offsets, output, mask = mask)\n\n\ndef _rms_layernorm_backward(\n    dY,\n    dY_row_stride: tl.constexpr,\n    dX,\n    dX_row_stride: tl.constexpr,\n    X,\n    X_row_stride: tl.constexpr,\n    W,\n    W_row_stride: tl.constexpr,\n    r,\n    r_row_stride: tl.constexpr,\n    # dW, dW_row_stride,\n    n_cols: tl.constexpr,\n    eps: tl.constexpr,\n    GEMMA: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Fast RMS Layernorm kernel for the backward pass\n    Inspiration from a Triton tutorial:\n    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n    \"\"\"\n    row_idx = tl.program_id(0)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n\n    dY += row_idx * dY_row_stride\n    X += row_idx * X_row_stride\n    r += row_idx * r_row_stride\n\n    if GEMMA:\n        dX += row_idx * dY_row_stride\n    else:\n        dX = dY\n\n    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n    # Get saved row variance\n    inv_var = tl.load(r).to(tl.float32)\n    normed = X_row * inv_var\n\n    if GEMMA:\n        dY_W = dY_row * (W_row + 1.0)\n    else:\n        dY_W = dY_row * W_row\n\n    rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n    output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)\n    tl.store(dX + col_offsets, output, mask = 
mask)\n\n\n_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)\n_rms_layernorm_backward = triton.heuristics(\n    {\n        \"GEMMA\": lambda args: bool(args[\"GEMMA\"]),\n    }\n)(_rms_layernorm_backward)\n\n\n@triton.jit\ndef _gemma_rms_layernorm_forward(\n    Y,\n    Y_row_stride: tl.constexpr,\n    X,\n    X_row_stride: tl.constexpr,\n    W,\n    W_row_stride: tl.constexpr,\n    r,\n    r_row_stride: tl.constexpr,\n    n_cols: tl.constexpr,\n    eps: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31\n    # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33\n    # exactly. Essentially all in float32!\n    row_idx = tl.program_id(0)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n\n    Y += row_idx * Y_row_stride\n    X += row_idx * X_row_stride\n    r += row_idx * r_row_stride\n\n    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n    # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm\n    eps_f32 = tl.full((), eps, tl.float32)\n    inv_var = tl.math.rsqrt(row_var + eps_f32)\n    tl.store(r, inv_var)\n    normed = X_row * inv_var\n    output = normed * (W_row + 1.0)\n\n    tl.store(Y + col_offsets, output, mask = mask)\n\n\nclass Fast_RMS_Layernorm(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float, gemma: bool = False):\n        shape = X.shape\n        dim: int = shape[-1]\n        X = X.reshape(-1, dim)\n        n_rows: int\n        n_cols: int\n        n_rows, n_cols = X.shape\n        BLOCK_SIZE: int\n        num_warps: int\n        BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n        device = X.device\n\n        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)\n        r = torch.empty(n_rows, dtype = torch.float32, device = device)\n\n        fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward\n        with torch_gpu_device(device):\n            fx[(n_rows,)](\n                Y,\n                Y.stride(0),\n                X,\n                X.stride(0),\n                W,\n                W.stride(0),\n                r,\n                r.stride(0),\n                n_cols,\n                eps,\n                BLOCK_SIZE = BLOCK_SIZE,\n                num_warps = num_warps,\n            )\n        ctx.eps = eps\n        ctx.BLOCK_SIZE = BLOCK_SIZE\n        ctx.num_warps = num_warps\n        ctx.GEMMA = gemma\n        ctx.save_for_backward(X, W, r)\n        return Y.view(*shape)\n\n    @staticmethod\n    def backward(ctx, dY: torch.Tensor):\n        shape = dY.shape\n        dim: int = shape[-1]\n        dY = dY.reshape(-1, dim)\n        X, W, r = ctx.saved_tensors\n        n_rows: int\n        n_cols: int\n        n_rows, n_cols = dY.shape\n        # dW = X\n        dX = torch.empty_like(dY) if ctx.GEMMA else dY\n\n        with torch_gpu_device(dY.device):\n            _rms_layernorm_backward[(n_rows,)](\n                dY,\n                dY.stride(0),\n                dX,\n                dX.stride(0),\n                X,\n                X.stride(0),\n                W,\n                W.stride(0),\n                r,\n                r.stride(0),\n                # dW, dW.stride(0),\n                n_cols,\n  
              ctx.eps,\n                GEMMA = ctx.GEMMA,\n                BLOCK_SIZE = ctx.BLOCK_SIZE,\n                num_warps = ctx.num_warps,\n            )\n        dX = dX.view(*shape)\n        return dX, None, None, None\n\n\n# [TODO] Unsure why RMS Layernorm is not torch.compiling properly\n@torch.compiler.disable\ndef fast_rms_layernorm(layernorm, X: torch.Tensor, gemma: bool = False):\n    W: torch.Tensor = layernorm.weight\n    eps: float = (\n        layernorm.variance_epsilon\n        if hasattr(layernorm, \"variance_epsilon\")\n        else layernorm.eps\n    )\n    out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)\n    return out\n\n\nfrom transformers.models.llama.modeling_llama import LlamaRMSNorm\n\n\nclass Unsloth_LlamaRMSNorm(LlamaRMSNorm):\n    def forward(self, X):\n        return fast_rms_layernorm(self, X, gemma = False)\n\n\ntry:\n    from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm\n\n    class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):\n        def forward(self, X):\n            return fast_rms_layernorm(self, X, gemma = False)\n\n\nexcept:\n    pass\n\n\ndef patch_rms_layernorm():\n    import transformers.models.llama.modeling_llama\n\n    transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm\n    try:\n        import transformers.models.mllama.modeling_mllama\n\n        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = (\n            Unsloth_MllamaTextRMSNorm\n        )\n    except:\n        pass\n    return\n\n\ndef unpatch_rms_layernorm():\n    import transformers.models.llama.modeling_llama\n\n    transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm\n    try:\n        import transformers.models.mllama.modeling_mllama\n\n        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm\n    except:\n        pass\n    return\n\n\ndef test_rms_layernorm(\n    dim = 1024,\n    eps = 1e-5,\n    dtype = torch.float16,\n    bsz = 21,\n    random_state = 3407,\n    seqlen = 3341,\n):\n    from transformers.models.llama.modeling_llama import LlamaRMSNorm\n\n    layernorm = LlamaRMSNorm((dim,), eps = eps).to(\"cuda\")\n    torch.cuda.manual_seed(random_state)\n    torch.manual_seed(random_state)\n    torch.nn.init.uniform_(layernorm.weight)\n    X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = \"cuda\")\n    XX = X.clone()\n    X.requires_grad_(True)\n    XX.requires_grad_(True)\n    Y = layernorm(X)\n    YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = \"cuda\", requires_grad = True)\n    Y.backward(YY)\n    correct_grad = X.grad.clone()\n    # from unsloth.kernels import fast_rms_layernorm\n    Y = fast_rms_layernorm(layernorm, XX)\n    Y.backward(YY)\n    assert torch.amax(correct_grad - XX.grad).item() <= 0.05\n\n\ndef testing_suite_layernorm():\n    for dim in [512, 1024, 2048]:\n        for dtype in [torch.float16, torch.bfloat16]:\n            with torch.autocast(device_type = \"cuda\", dtype = dtype):\n                for seqlen in [3341, 2048, 349]:\n                    for random_state in [3407, 42]:\n                        test_rms_layernorm(\n                            dim = dim,\n                            eps = 1e-5,\n                            dtype = dtype,\n                            bsz = 21,\n                            random_state = random_state,\n                            seqlen = seqlen,\n                        )\n"
  },
  {
    "path": "unsloth/kernels/rope_embedding.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom ..device_type import DEVICE_COUNT\nfrom .utils import calculate_settings, torch_gpu_device, torch_device_stream\n\n\ndef _rope_embedding_QK(\n    Q,\n    Q_batch_stride,\n    Q_head_stride,\n    Q_seq_stride,\n    K,\n    K_batch_stride,\n    K_head_stride,\n    K_seq_stride,\n    cos,\n    cos_row_stride,\n    sin,\n    sin_row_stride,\n    rope_embedding_indices,\n    seqlen,\n    head_dim: tl.constexpr,\n    n_heads_K: tl.constexpr,\n    BACKWARD_PASS: tl.constexpr,\n    HAS_ROPE_INDICES: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    row_position = tl.program_id(0)\n    head_position = tl.program_id(1)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    half_head_dim = head_dim // 2\n    mask = col_offsets < half_head_dim\n\n    if HAS_ROPE_INDICES:\n        rot_position = tl.load(\n            rope_embedding_indices + row_position,\n            eviction_policy = \"evict_first\",\n        ).to(tl.int32)\n    else:\n        rot_position = row_position % seqlen\n\n    cos_ptr = cos + rot_position * cos_row_stride\n    sin_ptr = sin + rot_position * sin_row_stride\n    sin1 = tl.load(\n        sin_ptr + col_offsets,\n        mask = mask,\n        other = 0,\n    )\n    cos1 = tl.load(\n        cos_ptr + col_offsets,\n        mask = mask,\n        other = 0,\n    )\n    if BACKWARD_PASS:\n        sin1 = -sin1\n\n    batch_id = row_position // seqlen\n    seq_index = row_position - batch_id * seqlen\n\n    q_ptr = (\n        Q\n        + batch_id * Q_batch_stride\n        + head_position * Q_head_stride\n        + seq_index * Q_seq_stride\n    )\n    q0 = tl.load(q_ptr + col_offsets, mask = mask, other = 0)\n    q1 = tl.load(q_ptr + half_head_dim + col_offsets, mask = mask, other = 0)\n    tl.store(q_ptr + col_offsets, q0 * cos1 - q1 * sin1, mask = mask)\n    tl.store(q_ptr + half_head_dim + col_offsets, q1 * cos1 + q0 * sin1, mask = mask)\n\n    if head_position < n_heads_K:\n        k_ptr = (\n            K\n            + batch_id * K_batch_stride\n            + head_position * K_head_stride\n            + seq_index * K_seq_stride\n        )\n        k0 = tl.load(k_ptr + col_offsets, mask = mask, other = 0)\n        k1 = tl.load(k_ptr + half_head_dim + col_offsets, mask = mask, other = 0)\n        tl.store(k_ptr + col_offsets, k0 * cos1 - k1 * sin1, mask = mask)\n        tl.store(k_ptr + half_head_dim + col_offsets, k1 * cos1 + k0 * sin1, mask = mask)\n\n\n_rope_embedding_QK = triton.jit(_rope_embedding_QK)\n_rope_embedding_QK = triton.heuristics(\n    {\n        \"BACKWARD_PASS\": lambda args: bool(args[\"BACKWARD_PASS\"]),\n        \"HAS_ROPE_INDICES\": lambda args: bool(args[\"HAS_ROPE_INDICES\"]),\n    }\n)(_rope_embedding_QK)\n\n\nROPE_GROUP_SIZE: 
int = 4\n\n\ndef _rope_embedding(\n    Q,\n    Q_row_stride: tl.constexpr,\n    cos,\n    cos_row_stride: tl.constexpr,\n    sin,\n    sin_row_stride: tl.constexpr,\n    seqlen,\n    head_dim: tl.constexpr,\n    n_heads: tl.constexpr,\n    BACKWARD_PASS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Calculates the RoPE Embedding quickly\n    RoPE is Q * cos + rotate_half(Q) * sin\n    See our blog post for more info\n    \"\"\"\n    ROPE_GROUP_SIZE = 4\n    row_position = tl.program_id(0)\n    group_head_position = tl.program_id(1)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    half_head_dim = head_dim // 2\n    mask = col_offsets < half_head_dim\n\n    sin1 = tl.load(\n        sin\n        + (row_position % seqlen) * sin_row_stride\n        + half_head_dim * 0\n        + col_offsets,\n        mask = mask,\n        other = 0,\n    )\n    cos1 = tl.load(\n        cos\n        + (row_position % seqlen) * cos_row_stride\n        + half_head_dim * 0\n        + col_offsets,\n        mask = mask,\n        other = 0,\n    )\n\n    if BACKWARD_PASS:\n        # See our blog post for more info.\n        sin1 = -sin1\n\n    # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8\n    head_start = group_head_position * ROPE_GROUP_SIZE\n    head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n    # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)\n    for k in range(head_start, head_end):\n        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n        offs_q2 = (\n            row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n        )\n\n        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n\n        tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)\n        tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)\n\n\n_rope_embedding = triton.jit(_rope_embedding)\n_rope_embedding = triton.heuristics(\n    {\n        \"BACKWARD_PASS\": lambda args: bool(args[\"BACKWARD_PASS\"]),\n    }\n)(_rope_embedding)\n\n\nclass Fast_RoPE_Embedding(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, Q, cos, sin):\n        cos, sin = cos.squeeze(), sin.squeeze()\n        batch: int\n        seq_len: int\n        n_heads: int\n        head_dim: int\n        batch, seq_len, n_heads, head_dim = Q.shape\n        Q = Q.reshape(batch * seq_len, n_heads * head_dim)\n        n_rows: int\n        n_cols: int\n        n_rows, n_cols = Q.shape\n        assert seq_len <= cos.shape[0]\n\n        # [TODO] Changing blocksize to head_dim//2 seems to have\n        # some concurrency / un-deterministic issues.\n        BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2)  # (head_dim//2)\n\n        # group_size = 4 # 4 or 8, too large group_size can hurt performance.\n        div: int\n        mod: int\n        div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n        n_groups: int = div + (mod != 0)\n\n        with torch_gpu_device(Q.device):\n            _rope_embedding[\n                (\n                    n_rows,\n                    n_groups,\n                )\n            ](\n                Q,\n                Q.stride(0),\n                cos,\n                cos.stride(0),\n                sin,\n                sin.stride(0),\n                seq_len,\n                head_dim,\n                n_heads,\n                
BACKWARD_PASS = False,\n                BLOCK_SIZE = BLOCK_SIZE,\n                num_warps = num_warps,\n            )\n        ctx.BLOCK_SIZE = BLOCK_SIZE\n        ctx.num_warps = num_warps\n        ctx.n_groups = n_groups\n        ctx.cos = cos\n        ctx.sin = sin\n        return Q.reshape(batch, seq_len, n_heads, head_dim)\n\n    @staticmethod\n    def backward(ctx, dY):\n        batch: int\n        seq_len: int\n        n_heads: int\n        head_dim: int\n        batch, seq_len, n_heads, head_dim = dY.shape\n        dY = dY.reshape(batch * seq_len, n_heads * head_dim)\n        n_rows: int\n        n_cols: int\n        n_rows, n_cols = dY.shape\n\n        cos = ctx.cos\n        sin = ctx.sin\n\n        with torch_gpu_device(dY.device):\n            _rope_embedding[\n                (\n                    n_rows,\n                    ctx.n_groups,\n                )\n            ](\n                dY,\n                dY.stride(0),\n                cos,\n                cos.stride(0),\n                sin,\n                sin.stride(0),\n                seq_len,\n                head_dim,\n                n_heads,\n                BACKWARD_PASS = True,\n                BLOCK_SIZE = ctx.BLOCK_SIZE,\n                num_warps = ctx.num_warps,\n            )\n        dY = dY.reshape(batch, seq_len, n_heads, head_dim)\n        return (\n            dY,\n            None,\n            None,\n        )\n\n\n# [TODO] Unsure why RoPE Embedding is not torch.compiling properly\n@torch.compiler.disable\ndef fast_rope_embedding(\n    Q,\n    K,\n    cos,\n    sin,\n    rope_embedding_indices = None,\n):\n    if rope_embedding_indices is not None:\n        Q_out, K_out = Fast_RoPE_Embedding_QK.apply(\n            Q, K, cos, sin, rope_embedding_indices\n        )\n    else:\n        Q_out = Fast_RoPE_Embedding.apply(\n            Q.transpose(1, 2).contiguous(), cos, sin\n        ).transpose(1, 2)\n        K_out = Fast_RoPE_Embedding.apply(\n            K.transpose(1, 2).contiguous(), cos, sin\n        ).transpose(1, 2)\n    if DEVICE_COUNT > 1:\n        torch_device_stream(Q.device).synchronize()\n    return Q_out, K_out\n\n\nclass Fast_RoPE_Embedding_QK(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, Q, K, cos, sin, rope_indices):\n        has_indices = rope_indices is not None\n        cos, sin = cos.squeeze(), sin.squeeze()\n\n        batch, n_heads_Q, seq_len, head_dim = Q.shape\n        _, n_heads_K, _, _ = K.shape\n\n        # Inplace rotary embedding is generally fine\n        Q_out = Q.clone() if not Q.is_contiguous() else Q\n        K_out = K.clone() if not K.is_contiguous() else K\n\n        if has_indices:\n            # TRL's rotary indices are always in int32, so casting is just for safety\n            rope_ptr = rope_indices.reshape(-1).to(dtype = torch.int32, device = Q.device)\n        else:\n            rope_ptr = cos.new_empty(1, dtype = torch.int32)\n\n        BLOCK_SIZE, num_warps = calculate_settings(head_dim)\n\n        Q_batch_stride, Q_head_stride, Q_seq_stride = (\n            Q_out.stride(0),\n            Q_out.stride(1),\n            Q_out.stride(2),\n        )\n        K_batch_stride, K_head_stride, K_seq_stride = (\n            K_out.stride(0),\n            K_out.stride(1),\n            K_out.stride(2),\n        )\n\n        with torch_gpu_device(Q.device):\n            _rope_embedding_QK[(batch * seq_len, n_heads_Q)](\n                Q_out,\n                Q_batch_stride,\n                Q_head_stride,\n                Q_seq_stride,\n          
      K_out,\n                K_batch_stride,\n                K_head_stride,\n                K_seq_stride,\n                cos,\n                cos.stride(0),\n                sin,\n                sin.stride(0),\n                rope_ptr,\n                seq_len,\n                head_dim = head_dim,\n                n_heads_K = n_heads_K,\n                BACKWARD_PASS = False,\n                HAS_ROPE_INDICES = has_indices,\n                BLOCK_SIZE = BLOCK_SIZE,\n                num_warps = num_warps,\n            )\n\n        ctx.block_size = BLOCK_SIZE\n        ctx.num_warps = num_warps\n        ctx.has_indices = has_indices\n        ctx.cos = cos\n        ctx.sin = sin\n        ctx.rope_indices = rope_ptr if has_indices else None\n        ctx.seq_len = seq_len\n        ctx.n_heads_Q = n_heads_Q\n        ctx.n_heads_K = n_heads_K\n\n        return (\n            Q_out,\n            K_out,\n        )\n\n    @staticmethod\n    def backward(ctx, dQ, dK):\n        batch, _, _, head_dim = dQ.shape\n\n        rope_ptr = (\n            ctx.rope_indices\n            if ctx.has_indices\n            else ctx.cos.new_empty(1, dtype = torch.int32)\n        )\n\n        # Inplace rotary embedding is generally fine\n        dQ_out = dQ.clone() if not dQ.is_contiguous() else dQ\n        dK_out = dK.clone() if not dK.is_contiguous() else dK\n\n        Q_batch_stride, Q_head_stride, Q_seq_stride = (\n            dQ_out.stride(0),\n            dQ_out.stride(1),\n            dQ_out.stride(2),\n        )\n        K_batch_stride, K_head_stride, K_seq_stride = (\n            dK_out.stride(0),\n            dK_out.stride(1),\n            dK_out.stride(2),\n        )\n\n        with torch_gpu_device(dQ.device):\n            _rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)](\n                dQ_out,\n                Q_batch_stride,\n                Q_head_stride,\n                Q_seq_stride,\n                dK_out,\n                K_batch_stride,\n                K_head_stride,\n                K_seq_stride,\n                ctx.cos,\n                ctx.cos.stride(0),\n                ctx.sin,\n                ctx.sin.stride(0),\n                rope_ptr,\n                ctx.seq_len,\n                head_dim = head_dim,\n                n_heads_K = ctx.n_heads_K,\n                BACKWARD_PASS = True,\n                HAS_ROPE_INDICES = ctx.has_indices,\n                BLOCK_SIZE = ctx.block_size,\n                num_warps = ctx.num_warps,\n            )\n\n        return (dQ_out, dK_out, None, None, None)\n\n\nclass Slow_RoPE_Embedding(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, Q, cos, sin, position_ids):\n        if position_ids is not None:\n            # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.\n            cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]\n            sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]\n            cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n            sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n\n        # Q * cos + rotate_half(Q) * sin\n        half = Q.shape[-1] // 2\n        RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)\n        Q *= cos\n        Q.addcmul_(RH_Q, sin)\n        # RH_Q *= sin\n        # Q += RH_Q\n        ctx.save_for_backward(cos, sin)\n        return Q\n\n    @staticmethod\n    def backward(ctx, dY):\n        cos, sin = ctx.saved_tensors\n        # Q * cos + rotate_half.T(Q) * sin\n        half = dY.shape[-1] // 
2\n        RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)\n        dY *= cos\n        dY.addcmul_(RH_dY, sin)\n        # RH_dY *= sin\n        # dY += RH_dY\n        return dY, None, None, None\n\n\ndef inplace_rope_embedding(Q, K, cos, sin, position_ids):\n    Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)\n    K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)\n    torch_device_stream(Q.device).synchronize()\n    return Q, K\n"
  },
  {
    "path": "unsloth/kernels/swiglu.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport triton\nimport triton.language as tl\nimport torch\nfrom .utils import calculate_settings, torch_gpu_device\n\n# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31\nNUM_INT32_ELEMENTS = 2**31\nSAFE_INT32_BUFFER_MULTIPLIER = 4\nBLOCK_SIZE = 1024\nINT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER\n\n\n@triton.jit\ndef _fg_kernel(\n    e,\n    g,\n    h,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    # f = e * sigmoid(e)\n    f_row = e_row * tl.sigmoid(e_row)  # e_row / (1 + tl.exp(-e_row))\n    f_row = f_row.to(g_row.dtype)  # Exact copy from HF\n    # h = f * g\n    h_row = f_row * g_row\n\n    # Store h\n    tl.store(h + offsets, h_row, mask = mask)\n\n\ndef swiglu_fg_kernel(e, g):\n    batch, seq_len, hd = e.shape\n    n_elements = e.numel()\n    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(e.device):\n        _fg_kernel[grid](\n            e,\n            g,\n            h,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return h\n\n\n@triton.jit\ndef _DWf_DW_dfg_kernel(\n    DW,\n    e,\n    g,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n    LONG_INDEXING: tl.constexpr,\n):\n    \"\"\"\n    e = e.float()\n    se = 1.0 / (1.0 + torch.exp(-e))\n    f = (se * e).to(dtype)\n    h = f * g\n    df = DW * f\n    dg = DW * g\n    de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)\n    \"\"\"\n    block_idx = tl.program_id(0)\n    if LONG_INDEXING:\n        offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(\n            tl.int64\n        )\n        n_elements = tl.cast(n_elements, tl.int64)\n    else:\n        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)\n    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)\n\n    # e = e.float()\n    # se = 1.0 / (1.0 + torch.exp(-e))\n    se_row = tl.sigmoid(e_row)  # 1.0 / (1.0 + tl.exp(-e_row))\n    # f = (se * e).to(dtype)\n    f_row 
= se_row * e_row\n    f_row = f_row.to(DW_row.dtype)\n    # h = f * g\n    h_row = f_row * g_row\n    # df = DW * f\n    df_row = DW_row * f_row\n    # dg = DW * g\n    dg_row = DW_row * g_row\n    # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)\n    de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n    de_row = de_row.to(DW_row.dtype)\n\n    # Store derivatives in buffers\n    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g\n    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f\n    tl.store(g + offsets, de_row, mask = mask)  # de\n\n\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n    batch_seq_len, hd = e.shape  # Flattened to 2D, so 1st dim is bsz * seq_len\n    n_elements = e.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    with torch_gpu_device(e.device):\n        _DWf_DW_dfg_kernel[grid](\n            DW,\n            e,\n            g,\n            n_elements,\n            BLOCK_SIZE = BLOCK_SIZE,\n            LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,\n        )\n    return DW, e, g\n"
  },
  {
    "path": "unsloth/kernels/utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport triton\nimport ctypes\n\nMAX_FUSED_SIZE: int = 65536\nnext_power_of_2 = triton.next_power_of_2\nimport functools\nfrom typing import Optional\n\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n)\nfrom .fp8 import weight_dequant, fp8_linear\nimport functools\n\n# torch.cuda.amp.custom_fwd is deprecated >= 2.4\nimport torch\n\ntorch_Tensor = torch.Tensor\nfrom unsloth_zoo.utils import Version\n\nif DEVICE_TYPE == \"xpu\" and Version(torch.__version__) < Version(\"2.6.0\"):\n    raise RuntimeError(\n        \"Intel xpu currently supports unsloth with torch.version >= 2.6.0\"\n    )\n\nif Version(torch.__version__) < Version(\"2.4.0\"):\n    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd\n    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd\nelse:\n    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = \"cuda\")\n    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = \"cuda\")\n\nif DEVICE_TYPE == \"xpu\":\n    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = \"xpu\")\n    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = \"xpu\")\n\n\n# tl.math.tanh now is libdevice.tanh\nimport triton\nimport triton.language as tl\n\nif Version(triton.__version__) >= Version(\"3.0.0\"):\n    if DEVICE_TYPE == \"xpu\":\n        triton_tanh = tl.extra.intel.libdevice.tanh\n    else:\n        from triton.language.extra import libdevice\n\n        triton_tanh = libdevice.tanh\n    triton_cast = tl.cast\nelse:\n    triton_tanh = tl.math.tanh\n\n    # No casting in old Triton versions\n    @triton.jit\n    def triton_cast(x, dtype):\n        return x.to(dtype)\n\n\n@functools.lru_cache(1)\ndef is_cdna():\n    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (\n        \"gfx940\",\n        \"gfx941\",\n        \"gfx942\",\n        \"gfx950\",  # CDNA4 (MI350/MI355X)\n    )\n\n\n@functools.lru_cache(1)\ndef is_rdna():\n    \"\"\"Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4).\"\"\"\n    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (\n        \"gfx1100\",\n        \"gfx1101\",\n        \"gfx1200\",\n        \"gfx1201\",\n    )\n\n\ndef calculate_settings(\n    n: int,\n) -> (\n    int,\n    int,\n):\n    BLOCK_SIZE: int = next_power_of_2(n)\n    if BLOCK_SIZE > MAX_FUSED_SIZE:\n        raise RuntimeError(\n            f\"Cannot launch Triton kernel since n = {n} exceeds \"\n            f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\"\n        )\n    num_warps: int = 4\n    if BLOCK_SIZE >= 32768:\n        num_warps = 32\n    elif BLOCK_SIZE >= 8192:\n        num_warps = 16\n    elif BLOCK_SIZE >= 2048:\n        num_warps = 8\n    return BLOCK_SIZE, num_warps\n\n\nHAS_CUDA_STREAM = False\nimport bitsandbytes as bnb\n\n# 
https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files\nHAS_CUDA_STREAM = Version(bnb.__version__) > Version(\"0.43.3\")\nget_ptr = bnb.functional.get_ptr\n\nif DEVICE_TYPE == \"xpu\":\n    HAS_XPU_STREAM = True\n\nif DEVICE_COUNT > 1:\n    if DEVICE_TYPE in (\"cuda\", \"hip\"):\n        torch_gpu_device = torch.cuda.device\n    elif DEVICE_TYPE == \"xpu\":\n        torch_gpu_device = torch.xpu.device\nelse:\n    from contextlib import nullcontext\n\n    def torch_gpu_device(device):\n        return nullcontext()\n\n\n# INTEL GPU Specific Logic\nif DEVICE_TYPE == \"xpu\":\n    _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream\n# NVIDIA GPU Default Logic\nelse:\n    _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream\n\nc_void_p = ctypes.c_void_p\n\n\ndef _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:\n    return c_void_p(_gpu_getCurrentRawStream(tensor.device.index))\n\n\n# Get array of CUDA streams and other buffers\nglobal CUDA_STREAMS\nglobal XPU_STREAMS\nglobal WEIGHT_BUFFERS\nglobal ABSMAX_BUFFERS\n\n# INTEL GPU Specific Logic\nif DEVICE_TYPE == \"xpu\":\n    _XPU_STREAMS = {\n        (index := torch.xpu.device(i).idx): ctypes.c_void_p(\n            torch._C._xpu_getCurrentRawStream(index)\n        )\n        for i in range(DEVICE_COUNT)\n    }\n    XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)\n    WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)\n    ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)\n    for k, v in _XPU_STREAMS.items():\n        XPU_STREAMS[k] = v\n    XPU_STREAMS = tuple(XPU_STREAMS)\n    del _XPU_STREAMS\nelse:\n    # NVIDIA GPU Default Logic\n    _CUDA_STREAMS = {\n        (index := torch.cuda.device(i).idx): ctypes.c_void_p(\n            torch._C._cuda_getCurrentRawStream(index)\n        )\n        for i in range(DEVICE_COUNT)\n    }\n    CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)\n    WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)\n    ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)\n    for k, v in _CUDA_STREAMS.items():\n        CUDA_STREAMS[k] = v\n    CUDA_STREAMS = tuple(CUDA_STREAMS)\n    del _CUDA_STREAMS\n\n# Bitsandbytes operations\nctypes_c_int = ctypes.c_int\nctypes_c_int32 = ctypes.c_int32\ncdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32\ncdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4\ncdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4\n\nif DEVICE_TYPE == \"xpu\":\n    # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115\n    # for xpu, inference gemv using above link\n    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16\n    cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16\nelse:\n    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16\n    cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16\n\n\ntorch_device_stream = (\n    torch.xpu.current_stream if DEVICE_TYPE == \"xpu\" else torch.cuda.current_stream\n)\n\ntorch_mm = torch.mm\ntorch_mv = torch.mv\ntorch_matmul = torch.matmul\ntorch_addmm = torch.addmm\ntorch_empty = torch.empty\ntorch_float32 = torch.float32\ntorch_float16 = torch.float16\ntorch_bfloat16 = torch.bfloat16\n\n\n# Check whether torchao can be imported to get Float8Tensor\nif importlib.util.find_spec(\"torchao\") is not 
None:\n    try:\n        from torchao.quantization import Float8Tensor\n    except:\n        import torchao\n\n        if Version(torchao.__version__) >= Version(\"0.15.0\"):\n            print(\n                f\"Unsloth: `from torchao.quantization import Float8Tensor` failed on version={torchao.__version__}\"\n            )\n        Float8Tensor = type(None)\nelse:\n    Float8Tensor = type(None)\n\n\ndef QUANT_STATE(W):\n    return getattr(W, \"quant_state\", None)\n\n\ndef get_lora_parameters(proj):\n    \"\"\"\n    Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).\n    If QAT is enabled, additionally fake quantize the base layer and lora weights.\n    \"\"\"\n    # For DPO or disabled adapters\n    base_layer = getattr(\n        proj, \"base_layer\", proj\n    )  # (proj.base_layer if hasattr(proj, \"base_layer\") else proj)\n    W = base_layer.weight\n\n    # Optionally apply fake quantization to base layer weights for QAT\n    if hasattr(base_layer, \"weight_fake_quantizer\"):\n        weight_fake_quantizer = getattr(base_layer, \"weight_fake_quantizer\", None)\n        if weight_fake_quantizer is not None:\n            W = weight_fake_quantizer(W)\n\n    # Get quant state for 4bit or FP8\n    W_quant = getattr(W, \"quant_state\", None)\n    if W_quant is None:\n        W_quant = getattr(base_layer, \"weight_scale_inv\", None)\n        if W_quant is None:\n            W_quant = getattr(base_layer, \"weight_scale\", None)\n\n    if getattr(base_layer, \"quant_method\", None) == \"fp8\":\n        # we need to somehow store and pass this information :)\n        W.block_size = getattr(base_layer, \"block_size\", [128, 128])\n        W_quant.block_size = W.block_size\n\n    # if not hasattr(proj, \"disable_adapters\") or proj.disable_adapters or proj.merged:\n    if getattr(proj, \"disable_adapters\", True) or proj.merged:\n        return W, W_quant, None, None, None\n\n    adapter = getattr(proj, \"active_adapters\", None)\n    if adapter is None:\n        adapter = getattr(proj, \"active_adapter\", (\"default\"))\n    adapter = adapter[0]\n\n    # Optionally apply fake quantization to lora weights for QAT\n    lora_A_linear = proj.lora_A[adapter]\n    lora_B_linear = proj.lora_B[adapter]\n    A = lora_A_linear.weight\n    B = lora_B_linear.weight\n    if hasattr(lora_A_linear, \"weight_fake_quantizer\"):\n        lora_A_fake_quantizer = getattr(lora_A_linear, \"weight_fake_quantizer\", None)\n        if lora_A_fake_quantizer is not None:\n            A = lora_A_fake_quantizer(A)\n    if hasattr(lora_B_linear, \"weight_fake_quantizer\"):\n        lora_B_fake_quantizer = getattr(lora_B_linear, \"weight_fake_quantizer\", None)\n        if lora_B_fake_quantizer is not None:\n            B = lora_B_fake_quantizer(B)\n\n    return (\n        W,\n        W_quant,\n        A,\n        B,\n        proj.scaling[adapter],\n    )\n\n\ndef get_lora_parameters_bias(proj):\n    # For DPO or disabled adapters\n    base_layer = getattr(\n        proj, \"base_layer\", proj\n    )  # (proj.base_layer if hasattr(proj, \"base_layer\") else proj)\n    W = base_layer.weight\n\n    # Get quant state for 4bit or FP8\n    W_quant = getattr(W, \"quant_state\", None)\n    if W_quant is None:\n        W_quant = getattr(base_layer, \"weight_scale_inv\", None)\n        if W_quant is None:\n            W_quant = getattr(base_layer, \"weight_scale\", None)\n\n    # if not hasattr(proj, \"disable_adapters\") or proj.disable_adapters or proj.merged:\n    if getattr(proj, 
\"disable_adapters\", True) or proj.merged:\n        return W, W_quant, None, None, None, base_layer.bias\n\n    if getattr(base_layer, \"quant_method\", None) == \"fp8\":\n        # we need to somehow store and pass this information :)\n        W.block_size = getattr(base_layer, \"block_size\", [128, 128])\n        W_quant.block_size = W.block_size\n\n    adapter = getattr(proj, \"active_adapters\", None)\n    if adapter is None:\n        adapter = getattr(proj, \"active_adapter\", (\"default\"))\n    adapter = adapter[0]\n\n    return (\n        W,\n        W_quant,\n        proj.lora_A[adapter].weight,\n        proj.lora_B[adapter].weight,\n        proj.scaling[adapter],\n        base_layer.bias,\n    )\n\n\ndef _maybe_fake_quantize_activations(\n    X: torch.Tensor, proj: torch.nn.Module\n) -> torch.Tensor:\n    \"\"\"\n    If QAT is enabled, fake quantize the input activations.\n    Otherwise, just return the input activations as is.\n    Weights are fake quantized separately in `get_lora_parameters`.\n    \"\"\"\n    base_layer = getattr(proj, \"base_layer\", proj)\n    activation_fake_quantizer = getattr(base_layer, \"activation_fake_quantizer\", None)\n    if activation_fake_quantizer is not None:\n        X = activation_fake_quantizer(X)\n    return X\n\n\n# INTEL GPU Specific Logic\nif DEVICE_TYPE == \"xpu\" and HAS_XPU_STREAM:\n\n    @torch.inference_mode\n    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):\n        # TODO: After adding XPU BNB support, check this function\n        if isinstance(W, Float8Tensor):\n            return W.dequantize()\n        if quant_state is None:\n            return W\n        if W.dtype == torch.float8_e4m3fn:\n            return weight_dequant(W, quant_state)\n        if type(quant_state) is not list:\n            # New quant_state as a class\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            # Old quant_state as a list of lists\n            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        global XPU_STREAMS\n        device = W.device\n        device_index = device.index\n        XPU_STREAM = XPU_STREAMS[device_index]\n\n        n_elements_absmax = absmax.numel()\n        # Create weight matrix\n        if use_global_buffer:\n            # Use same buffers for faster inference\n            size = shape[0] * shape[1]\n            global WEIGHT_BUFFERS\n            global ABSMAX_BUFFERS\n            WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]\n            ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]\n            if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:\n                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(\n                    size, dtype = dtype, device = device, requires_grad = False\n                )\n                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(\n                    n_elements_absmax,\n                    dtype = torch.float32,\n                    device = device,\n                    
requires_grad = False,\n                )\n\n            if size > WEIGHT_BUFFER.numel():\n                WEIGHT_BUFFER.resize_(size)\n            if n_elements_absmax > ABSMAX_BUFFER.numel():\n                ABSMAX_BUFFER.resize_(n_elements_absmax)\n\n            out = WEIGHT_BUFFER[:size].view(shape)\n            out_absmax = ABSMAX_BUFFER[:n_elements_absmax]\n        else:\n            if out is None:\n                out = torch_empty(\n                    shape, dtype = dtype, device = device, requires_grad = False\n                )\n            else:\n                assert out.shape == shape\n                assert out.dtype == dtype\n            out_absmax = torch_empty(\n                n_elements_absmax,\n                dtype = torch_float32,\n                device = device,\n                requires_grad = False,\n            )\n\n        # NF4 dequantization of statistics\n        ptr_out_absmax = get_ptr(out_absmax)\n        with torch_gpu_device(device):\n            cdequantize_blockwise_fp32(\n                get_ptr(code2),\n                get_ptr(absmax),\n                get_ptr(absmax2),\n                ptr_out_absmax,\n                ctypes_c_int(blocksize2),\n                ctypes_c_int(n_elements_absmax),\n                XPU_STREAM,\n            )\n            out_absmax += offset\n\n            # Dequantize W\n            fx = (\n                cdequantize_blockwise_fp16_nf4\n                if dtype == torch_float16\n                else cdequantize_blockwise_bf16_nf4\n            )\n            fx(\n                get_ptr(None),\n                get_ptr(W),\n                ptr_out_absmax,\n                get_ptr(out),\n                ctypes_c_int(blocksize),\n                ctypes_c_int(out.numel()),\n                XPU_STREAM,\n            )\n        # Careful returning transposed data\n        is_transposed = True if W.shape[0] == 1 else False\n        return out.t() if is_transposed else out\n\n# NVIDIA GPU Default Logic\nelif DEVICE_TYPE in (\"cuda\", \"hip\") and HAS_CUDA_STREAM:\n\n    @torch.inference_mode\n    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):\n        if isinstance(W, Float8Tensor):\n            return W.dequantize()\n        if quant_state is None:\n            return W\n        if W.dtype == torch.float8_e4m3fn:\n            return weight_dequant(W, quant_state)\n        if type(quant_state) is not list:\n            # New quant_state as a class\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            # Old quant_state as a list of lists\n            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        pass\n        global CUDA_STREAMS\n        device = W.device\n        device_index = device.index\n        CUDA_STREAM = CUDA_STREAMS[device_index]\n\n        n_elements_absmax = absmax.numel()\n\n        # Create weight matrix\n        if use_global_buffer:\n            # Use same buffers for faster inference\n            size = shape[0] * shape[1]\n           
 global WEIGHT_BUFFERS\n            global ABSMAX_BUFFERS\n            WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]\n            ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]\n            if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:\n                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(\n                    size, dtype = dtype, device = device, requires_grad = False\n                )\n                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(\n                    n_elements_absmax,\n                    dtype = torch_float32,\n                    device = device,\n                    requires_grad = False,\n                )\n\n            if size > WEIGHT_BUFFER.numel():\n                WEIGHT_BUFFER.resize_(size)\n            if n_elements_absmax > ABSMAX_BUFFER.numel():\n                ABSMAX_BUFFER.resize_(n_elements_absmax)\n\n            out = WEIGHT_BUFFER[:size].view(shape)\n            out_absmax = ABSMAX_BUFFER[:n_elements_absmax]\n        else:\n            if out is None:\n                out = torch_empty(\n                    shape, dtype = dtype, device = device, requires_grad = False\n                )\n            else:\n                assert out.shape == shape\n                assert out.dtype == dtype\n            out_absmax = torch_empty(\n                n_elements_absmax,\n                dtype = torch_float32,\n                device = device,\n                requires_grad = False,\n            )\n        pass\n\n        # NF4 dequantization of statistics\n        ptr_out_absmax = get_ptr(out_absmax)\n        with torch_gpu_device(device):\n            cdequantize_blockwise_fp32(\n                get_ptr(code2),\n                get_ptr(absmax),\n                get_ptr(absmax2),\n                ptr_out_absmax,\n                ctypes_c_int(blocksize2),\n                ctypes_c_int(n_elements_absmax),\n                CUDA_STREAM,\n            )\n            out_absmax += offset\n\n            # Dequantize W\n            fx = (\n                cdequantize_blockwise_fp16_nf4\n                if dtype == torch_float16\n                else cdequantize_blockwise_bf16_nf4\n            )\n            fx(\n                get_ptr(None),\n                get_ptr(W),\n                ptr_out_absmax,\n                get_ptr(out),\n                ctypes_c_int(blocksize),\n                ctypes_c_int(out.numel()),\n                CUDA_STREAM,\n            )\n        pass\n        # Careful returning transposed data\n        is_transposed = True if W.shape[0] == 1 else False\n        return out.t() if is_transposed else out\n\n    pass\nelse:\n\n    @torch.inference_mode\n    def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):\n        if isinstance(W, Float8Tensor):\n            return W.dequantize()\n        if quant_state is None:\n            return W\n        if W.dtype == torch.float8_e4m3fn:\n            return weight_dequant(W, quant_state)\n        if type(quant_state) is not list:\n            # New quant_state as a class\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            # 
Old quant_state as a list of lists\n            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        pass\n\n        n_elements_absmax = absmax.numel()\n        device = W.device\n\n        # Create weight matrix\n        if out is None:\n            out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)\n        else:\n            assert out.shape == shape\n            assert out.dtype == dtype\n        out_absmax = torch_empty(\n            n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False\n        )\n\n        # Do dequantization\n        ptr_out_absmax = get_ptr(out_absmax)\n        cdequantize_blockwise_fp32(\n            get_ptr(code2),\n            get_ptr(absmax),\n            get_ptr(absmax2),\n            ptr_out_absmax,\n            ctypes_c_int(blocksize2),\n            ctypes_c_int(n_elements_absmax),\n        )\n        out_absmax += offset\n\n        fx = (\n            cdequantize_blockwise_fp16_nf4\n            if dtype == torch_float16\n            else cdequantize_blockwise_bf16_nf4\n        )\n        fx(\n            get_ptr(None),\n            get_ptr(W),\n            ptr_out_absmax,\n            get_ptr(out),\n            ctypes_c_int(blocksize),\n            ctypes_c_int(out.numel()),\n        )\n\n        # Careful returning transposed data\n        is_transposed = True if W.shape[0] == 1 else False\n        return out.t() if is_transposed else out\n\n    pass\n\n\n# INTEL GPU Specific Logic\nif DEVICE_TYPE == \"xpu\" and HAS_XPU_STREAM:\n\n    def fast_gemv(X, W, quant_state, out = None):\n        if quant_state is None:\n            return torch_matmul(X, W, out = out)\n        # For fast X @ W where seq_len == 1\n        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469\n        _, q_len, hd = X.shape\n        # assert(q_len == 1)\n\n        if type(quant_state) is not list:\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            stats = quant_state.code\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (\n                quant_state\n            )\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        global XPU_STREAMS\n        device = W.device\n        device_index = device.index\n        XPU_STREAM = XPU_STREAMS[device_index]\n\n        # assert(dtype == X.dtype)\n        bout = shape[0]\n\n        if out is None:\n            out = torch_empty(\n                (\n                    1,\n                    1,\n                    bout,\n                ),\n                dtype = dtype,\n                device = device,\n            )\n        # else:\n        #     assert(out.shape == (1, 1, bout,))\n        # pass\n\n        if DEVICE_TYPE == \"xpu\":\n            m = 1\n            n = shape[0]\n        else:\n            n = 1\n            m = shape[0]\n        k = shape[1]\n        lda = shape[0]\n        ldc = shape[0]\n       
 ldb = (hd + 1) // 2\n        m = ctypes_c_int32(m)\n        n = ctypes_c_int32(n)\n        k = ctypes_c_int32(k)\n        lda = ctypes_c_int32(lda)\n        ldb = ctypes_c_int32(ldb)\n        ldc = ctypes_c_int32(ldc)\n\n        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)\n        with torch_gpu_device(device):\n            cdequantize_blockwise_fp32(\n                get_ptr(code2),\n                get_ptr(absmax),\n                get_ptr(absmax2),\n                get_ptr(df),\n                ctypes_c_int(blocksize2),\n                ctypes_c_int(df.numel()),\n                XPU_STREAM,\n            )\n            df += offset\n            absmax = df\n\n            fx = (\n                cgemm_4bit_inference_naive_fp16\n                if dtype == torch_float16\n                else cgemm_4bit_inference_naive_bf16\n            )\n\n            blocksize = ctypes_c_int32(blocksize)\n            fx(\n                m,\n                n,\n                k,\n                get_ptr(X),\n                get_ptr(W),\n                get_ptr(absmax),\n                get_ptr(stats),\n                get_ptr(out),\n                lda,\n                ldb,\n                ldc,\n                blocksize,\n                XPU_STREAM,\n            )\n\n        return out\n\nelif DEVICE_TYPE in (\"cuda\", \"hip\") and HAS_CUDA_STREAM:\n\n    def fast_gemv(X, W, quant_state, out = None):\n        if quant_state is None:\n            return torch_matmul(X, W, out = out)\n        # For fast X @ W where seq_len == 1\n        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469\n        _, q_len, hd = X.shape\n        # assert(q_len == 1)\n\n        if type(quant_state) is not list:\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            stats = quant_state.code\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (\n                quant_state\n            )\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        pass\n        global CUDA_STREAMS\n        device = W.device\n        device_index = device.index\n        CUDA_STREAM = CUDA_STREAMS[device_index]\n\n        # assert(dtype == X.dtype)\n        bout = shape[0]\n\n        if out is None:\n            out = torch_empty(\n                (\n                    1,\n                    1,\n                    bout,\n                ),\n                dtype = dtype,\n                device = device,\n            )\n        # else:\n        #     assert(out.shape == (1, 1, bout,))\n        # pass\n\n        n = 1\n        m = shape[0]\n        k = shape[1]\n        lda = shape[0]\n        ldc = shape[0]\n        ldb = (hd + 1) // 2\n        m = ctypes_c_int32(m)\n        n = ctypes_c_int32(n)\n        k = ctypes_c_int32(k)\n        lda = ctypes_c_int32(lda)\n        ldb = ctypes_c_int32(ldb)\n        ldc = ctypes_c_int32(ldc)\n\n        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)\n        with torch_gpu_device(device):\n            
cdequantize_blockwise_fp32(\n                get_ptr(code2),\n                get_ptr(absmax),\n                get_ptr(absmax2),\n                get_ptr(df),\n                ctypes_c_int(blocksize2),\n                ctypes_c_int(df.numel()),\n                CUDA_STREAM,\n            )\n            df += offset\n            absmax = df\n\n            fx = (\n                cgemm_4bit_inference_naive_fp16\n                if dtype == torch_float16\n                else cgemm_4bit_inference_naive_bf16\n            )\n\n            blocksize = ctypes_c_int32(blocksize)\n            fx(\n                m,\n                n,\n                k,\n                get_ptr(X),\n                get_ptr(W),\n                get_ptr(absmax),\n                get_ptr(stats),\n                get_ptr(out),\n                lda,\n                ldb,\n                ldc,\n                blocksize,\n                CUDA_STREAM,\n            )\n        pass\n\n        return out\n\n    pass\nelse:\n\n    def fast_gemv(X, W, quant_state, out = None):\n        if quant_state is None:\n            return torch_matmul(X, W, out = out)\n        # For fast X @ W where seq_len == 1\n        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469\n        _, q_len, hd = X.shape\n        # assert(q_len == 1)\n\n        if type(quant_state) is not list:\n            # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n            absmax = quant_state.absmax\n            shape = quant_state.shape\n            dtype = quant_state.dtype\n            blocksize = quant_state.blocksize\n            stats = quant_state.code\n            offset = quant_state.offset\n            state2 = quant_state.state2\n            absmax2 = state2.absmax\n            code2 = state2.code\n            blocksize2 = state2.blocksize\n        else:\n            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (\n                quant_state\n            )\n            offset, state2 = compressed_stats\n            absmax2, code2, blocksize2, _, _, _, _ = state2\n        pass\n        # assert(dtype == X.dtype)\n        bout = shape[0]\n        device = W.device\n\n        if out is None:\n            out = torch_empty(\n                (\n                    1,\n                    1,\n                    bout,\n                ),\n                dtype = dtype,\n                device = device,\n            )\n        # else:\n        #     assert(out.shape == (1, 1, bout,))\n        # pass\n\n        n = 1\n        m = shape[0]\n        k = shape[1]\n        lda = shape[0]\n        ldc = shape[0]\n        ldb = (hd + 1) // 2\n        m = ctypes_c_int32(m)\n        n = ctypes_c_int32(n)\n        k = ctypes_c_int32(k)\n        lda = ctypes_c_int32(lda)\n        ldb = ctypes_c_int32(ldb)\n        ldc = ctypes_c_int32(ldc)\n\n        df = torch_empty(absmax.shape, dtype = torch_float32, device = device)\n        cdequantize_blockwise_fp32(\n            get_ptr(code2),\n            get_ptr(absmax),\n            get_ptr(absmax2),\n            get_ptr(df),\n            ctypes_c_int(blocksize2),\n            ctypes_c_int(df.numel()),\n        )\n        df += offset\n        absmax = df\n\n        fx = (\n            cgemm_4bit_inference_naive_fp16\n            if dtype == torch_float16\n            else cgemm_4bit_inference_naive_bf16\n        )\n\n        blocksize = ctypes_c_int32(blocksize)\n        fx(\n            m,\n            n,\n            k,\n            
get_ptr(X),\n            get_ptr(W),\n            get_ptr(absmax),\n            get_ptr(stats),\n            get_ptr(out),\n            lda,\n            ldb,\n            ldc,\n            blocksize,\n        )\n\n        return out\n\n    pass\n\n\ndef fast_linear_forward(proj, X, temp_lora = None, out = None):\n    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)\n    bsz, q_len, in_dim = X.shape\n    if q_len != 1:\n        return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)\n\n    if W_quant is None:\n        out = torch_matmul(X, W.t(), out = out)\n    elif W.dtype == torch.float8_e4m3fn:\n        out = fp8_linear(X, W, W_quant, bias)\n    elif bsz == 1 and q_len == 1:\n        out = fast_gemv(X, W, W_quant, out = out)\n    else:\n        W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)\n        out = torch_matmul(X, W, out = out)\n\n    # Add in LoRA weights\n    if lora_A is not None:\n        out_dim = out.shape[2]\n        dtype = X.dtype\n\n        if not hasattr(lora_A, \"_fast_lora\"):\n            lora_A._fast_lora = lora_A.to(dtype)\n            lora_B._fast_lora = lora_B.to(dtype)\n\n        if bsz == 1:\n            out = out.view(out_dim)\n            temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)\n            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)\n        else:\n            out = out.view(bsz, out_dim)\n            temp_lora = torch_mm(\n                X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora\n            )\n            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)\n        out = out.view(bsz, 1, out_dim)\n\n    if bias is not None:\n        out += bias\n\n    return out\n\n\ndef matmul_lora(X, W, W_quant, A, B, s, out = None):\n    dtype = X.dtype\n\n    if X.dim() == 3:\n        batch, seq_len, d = X.shape\n        X = X.view(-1, X.shape[-1])\n        reshape = True\n    else:\n        reshape = False\n\n    if isinstance(W, Float8Tensor):\n        assert W.ndim == 2\n        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:\n            # In the backward pass, rowwise scaled becomes colwise scaled after we\n            # transpose the weight tensor. Use this case to detect backward.\n            # TODO: would be simpler if we simply don't call `matmul_lora` in backward\n            W = W.dequantize()\n        else:\n            W = W.contiguous()\n        out = torch_matmul(X, W.t(), out = out)\n    elif W.dtype == torch.float8_e4m3fn:\n        out = fp8_linear(X, W, W_quant)\n    else:\n        W = fast_dequantize(W, W_quant, use_global_buffer = True)\n        out = torch_matmul(X, W.t(), out = out)\n    if W_quant is not None:\n        del W\n\n    if A is not None:\n        # LoRA is enabled\n        A, B = A.t(), B.t()\n        XA = torch_matmul(X, A.to(dtype))\n        out.addmm_(XA, B.to(dtype), alpha = s)\n        # out += (X @ A.to(dtype)) @ (s * B.to(dtype))\n\n    return out.view(batch, seq_len, -1) if reshape else out\n"
  },
  {
    "path": "unsloth/models/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import FastLlamaModel\nfrom .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel\nfrom .mistral import FastMistralModel\nfrom .qwen2 import FastQwen2Model\nfrom .qwen3 import FastQwen3Model\nfrom .qwen3_moe import FastQwen3MoeModel\nfrom .granite import FastGraniteModel\nfrom .sentence_transformer import FastSentenceTransformer\n\ntry:\n    from .falcon_h1 import FastFalconH1Model\nexcept:\n    # transformers_version < 4.53.0 does not have falcon_h1 so silently skip it for now\n    pass\nfrom .dpo import PatchDPOTrainer, PatchKTOTrainer\nfrom ._utils import is_bfloat16_supported, is_vLLM_available, __version__\nfrom .rl import PatchFastRL, vLLMSamplingParams\n"
  },
  {
    "path": "unsloth/models/_utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__version__ = \"2026.3.8\"\n\n__all__ = [\n    \"SUPPORTS_BFLOAT16\",\n    \"is_bfloat16_supported\",\n    \"is_vLLM_available\",\n    \"prepare_model_for_kbit_training\",\n    \"xformers\",\n    \"xformers_attention\",\n    \"xformers_version\",\n    \"__version__\",\n    \"importlib_version\",\n    \"HAS_FLASH_ATTENTION\",\n    \"HAS_FLASH_ATTENTION_SOFTCAPPING\",\n    \"USE_MODELSCOPE\",\n    \"platform_system\",\n    \"resolve_hip_gpu_stats_name\",\n    \"patch_tokenizer\",\n    \"get_statistics\",\n    \"Unsloth_Offloaded_Gradient_Checkpointer\",\n    \"offload_to_disk\",\n    \"offload_input_embeddings\",\n    \"offload_output_embeddings\",\n    \"unsloth_offloaded_gradient_checkpoint\",\n    \"torch_compile_options\",\n    \"patch_linear_scaling\",\n    \"patch_llama_rope_scaling\",\n    \"create_boolean_mask\",\n    \"torch_amp_custom_fwd\",\n    \"torch_amp_custom_bwd\",\n    # \"accelerate_old_send_to_device\",\n    # \"accelerate_new_send_to_device\",\n    \"patch_gradient_accumulation_fix\",\n    \"patch_compiling_bitsandbytes\",\n    \"patch_regional_compilation\",\n    \"patch_layernorm\",\n    \"patch_torch_compile\",\n    \"patch_model_and_tokenizer\",\n    \"patch_unsloth_gradient_checkpointing\",\n    \"unpatch_unsloth_gradient_checkpointing\",\n    \"patch_gradient_checkpointing\",\n    \"unpatch_gradient_checkpointing\",\n    \"HAS_CUT_CROSS_ENTROPY\",\n    \"EMPTY_LOGITS\",\n    \"fused_linear_cross_entropy\",\n    \"unsloth_fused_ce_loss\",\n    \"patch_unsloth_smart_gradient_checkpointing\",\n    \"unpatch_unsloth_smart_gradient_checkpointing\",\n    \"apply_unsloth_gradient_checkpointing\",\n    \"patch_compiled_autograd\",\n    \"process_vision_info\",\n    \"unsloth_compile_transformers\",\n    \"prefer_flex_attn_if_supported\",\n    \"patch_fast_lora\",\n    \"validate_loftq_config\",\n    \"RaiseUninitialized\",\n    \"fast_inference_setup\",\n    \"patch_peft_fast_inference\",\n    \"error_out_no_vllm\",\n    \"dequantize_module_weight\",\n    \"patch_hf_quantizer\",\n    \"verify_fp8_support_if_applicable\",\n    \"_get_inference_mode_context_manager\",\n    \"hf_login\",\n    \"is_moe_model\",\n    \"get_moe_target_parameters\",\n    \"make_fast_generate_wrapper\",\n]\n\nimport torch\nfrom typing import Union, Optional, List, Any, Callable, Tuple, Iterator\nfrom platform import system as platform_system\n\nplatform_system = platform_system()\nimport numpy as np\nimport contextlib\nimport re\nfrom dataclasses import dataclass, field\nimport functools\nimport textwrap\nimport logging\nimport warnings, subprocess, inspect, psutil, os, math\nfrom unsloth_zoo.utils import Version, get_quant_type\nfrom importlib.metadata import version as importlib_version\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    
ALLOW_PREQUANTIZED_MODELS,\n)\nfrom ..import_fixes import UNSLOTH_ENABLE_LOGGING\nfrom unsloth_zoo.log import logger\nfrom unsloth_zoo.tokenizer_utils import (\n    patch_tokenizer as _patch_tokenizer,\n)\nfrom unsloth_zoo.rl_environments import (\n    check_python_modules,\n    create_locked_down_function,\n    execute_with_time_limit,\n    Benchmarker,\n)\nfrom unsloth_zoo.patching_utils import (\n    patch_compiling_bitsandbytes,\n    patch_layernorm,\n    patch_torch_compile,\n    patch_model_and_tokenizer,\n    patch_compiled_autograd,\n)\nfrom unsloth_zoo.gradient_checkpointing import (\n    Unsloth_Offloaded_Gradient_Checkpointer,\n    unsloth_offloaded_gradient_checkpoint,\n    patch_unsloth_gradient_checkpointing,\n    unpatch_unsloth_gradient_checkpointing,\n    Unsloth_Gradient_Checkpointer,\n    unsloth_gradient_checkpoint,\n    patch_gradient_checkpointing,\n    unpatch_gradient_checkpointing,\n    patch_unsloth_smart_gradient_checkpointing,\n    unpatch_unsloth_smart_gradient_checkpointing,\n)\nfrom unsloth_zoo.loss_utils import (\n    HAS_CUT_CROSS_ENTROPY,\n    fused_linear_cross_entropy,\n    _unsloth_get_batch_samples,\n    unsloth_fused_ce_loss,\n)\nfrom unsloth_zoo.vision_utils import (\n    process_vision_info,\n)\nfrom unsloth_zoo.compiler import (\n    get_transformers_model_type,\n    unsloth_compile_transformers as _unsloth_compile_transformers,\n)\nfrom unsloth_zoo.training_utils import (\n    prepare_model_for_training,\n)\n\n\ndef resolve_hip_gpu_stats_name(gpu_stats):\n    name = str(getattr(gpu_stats, \"name\", \"\") or \"\").strip()\n    name = re.sub(r\"\\s*\\([^)]*\\)\\s*$\", \"\", name).strip()\n    normalized_name = name.lower().strip(\". \")\n    if normalized_name and normalized_name not in (\"amd radeon graphics\",):\n        return name + \". \"\n\n    try:\n        torch_name = str(torch.cuda.get_device_name(0) or \"\").strip()\n        torch_name = re.sub(r\"\\s*\\([^)]*\\)\\s*$\", \"\", torch_name).strip()\n    except Exception:\n        torch_name = \"\"\n    normalized_torch_name = torch_name.lower().strip(\". \")\n    if normalized_torch_name and normalized_torch_name not in (\"amd radeon graphics\",):\n        return torch_name + \". \"\n\n    arch_name = \"\"\n    for key in (\"gcnArchName\", \"gcn_arch_name\", \"arch_name\", \"gfx_arch_name\"):\n        value = getattr(gpu_stats, key, None)\n        if value is not None and str(value).strip():\n            arch_name = str(value).strip()\n            break\n\n    if arch_name:\n        arch_name = arch_name.strip()\n        match = re.search(r\"(gfx[0-9a-z]+)\", arch_name, flags = re.I)\n        if match:\n            return f\"AMD {match.group(1).lower()} GPU. \"\n    return \"AMD GPU. \"\n\n\nfrom unsloth_zoo.temporary_patches import (\n    TEMPORARY_PATCHES,\n)\n\n\ndef apply_unsloth_gradient_checkpointing(\n    use_gradient_checkpointing, max_seq_length, dtype\n):\n    \"\"\"\n    Apply gradient checkpointing with smart heuristics.\n\n    For seq < 512, the overhead of gradient offloading in gc=\"unsloth\" mode\n    is not worth it. 
Benchmarks show standard gc is faster for small sequences.\n\n    Args:\n        use_gradient_checkpointing: \"unsloth\", True, False, or None\n        max_seq_length: The maximum sequence length\n        dtype: The model dtype for patching\n\n    Returns:\n        The effective use_gradient_checkpointing value (may change from \"unsloth\" to True)\n    \"\"\"\n    if use_gradient_checkpointing == \"unsloth\":\n        # Gradient offloading overhead is not worth it for small sequences.\n        # Benchmarks show crossover point is around seq_len 384-512.\n        # For seq < 512, standard gradient checkpointing is faster.\n        if max_seq_length < 512:\n            unpatch_unsloth_smart_gradient_checkpointing()\n            return True\n        else:\n            patch_unsloth_smart_gradient_checkpointing(dtype = dtype)\n            return \"unsloth\"\n    elif use_gradient_checkpointing in (True, False):\n        # User explicitly set True or False - unpatch any previous \"unsloth\" patching\n        unpatch_unsloth_smart_gradient_checkpointing()\n        return use_gradient_checkpointing\n    return use_gradient_checkpointing\n\n\ndef prefer_flex_attn_if_supported(model_class, config):\n    if os.environ.get(\"UNSLOTH_ENABLE_FLEX_ATTENTION\", \"1\") == \"0\":\n        return None\n    try:\n        from transformers.utils.import_utils import is_torch_flex_attn_available\n\n        if not is_torch_flex_attn_available():\n            return None\n        if model_class is None or not getattr(\n            model_class, \"_supports_flex_attn\", False\n        ):\n            return None\n        # GPT-OSS, Mllama and Gemma3N use eager/sdpa attention during\n        # inference since flex attention returns incorrect results or errors out.\n        # GPT-OSS: left padding issues cause incorrect outputs.\n        # Mllama: _update_causal_mask uses make_flex_block_causal_mask which\n        # creates BlockMask with Q_LEN=KV_LEN=total_seq_len, but during\n        # decode q_len=1, causing ValueError. 
Needs transformers update.\n        # Gemma3N: timm vision wrappers (eg Gemma3nVisionConfig) do not\n        # support flex_attention.\n        # NemotronH: hybrid Mamba-2 + Transformer model that does not\n        # support flex_attention (raises NotImplementedError from transformers).\n        model_type = getattr(config, \"model_type\", \"\") if config else \"\"\n        if model_type in (\"gpt_oss\", \"mllama\", \"nemotron_h\") or str(\n            model_type\n        ).startswith(\"gemma3n\"):\n            return None\n        if config is not None:\n            setattr(config, \"_attn_implementation\", \"flex_attention\")\n            if hasattr(config, \"attn_implementation\"):\n                setattr(config, \"attn_implementation\", \"flex_attention\")\n        return \"flex_attention\"\n    except Exception:\n        return None\n\n\ndef _run_temporary_patches(phase):\n    import inspect\n\n    for temporary_patch in TEMPORARY_PATCHES:\n        try:\n            sig = inspect.signature(temporary_patch)\n            if \"phase\" in sig.parameters:\n                temporary_patch(phase = phase)\n            else:\n                temporary_patch()\n        except (ValueError, TypeError):\n            temporary_patch()\n\n\n_run_temporary_patches(\"init\")\n\n# =============================================\n# Disable some warnings which can get annoying\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"torch\")\nwarnings.filterwarnings(action = \"ignore\", category = FutureWarning, module = \"torch\")\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"huggingface_hub\")\nwarnings.filterwarnings(\n    action = \"ignore\", category = FutureWarning, module = \"huggingface_hub\"\n)\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"trl\")\nwarnings.filterwarnings(action = \"ignore\", category = FutureWarning, module = \"trl\")\nwarnings.filterwarnings(action = \"ignore\", category = FutureWarning, module = \"xformers\")\nwarnings.filterwarnings(action = \"ignore\", category = RuntimeWarning, module = \"subprocess\")\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"transformers\")\nwarnings.filterwarnings(action = \"ignore\", category = FutureWarning, module = \"accelerate\")\nwarnings.filterwarnings(\n    action = \"ignore\", category = RuntimeWarning, module = \"multiprocessing\"\n)\nwarnings.filterwarnings(action = \"ignore\", category = RuntimeWarning, module = \"multiprocess\")\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"triton\")\nwarnings.filterwarnings(action = \"ignore\", category = UserWarning, module = \"bitsandbytes\")\n\n# Stop \"Special tokens have been added in the vocabulary, ...\"\nlogging.getLogger(\"transformers.tokenization_utils_base\").setLevel(logging.CRITICAL + 1)\n\nTORCHAO_MSG = \"Error: torchao not found, please install with `pip install torchao`\"\n\n\n# Ignore logging messages\nclass HideLoggingMessage(logging.Filter):\n    __slots__ = (\"text\",)\n\n    def __init__(self, text):\n        self.text = text\n\n    def filter(self, x):\n        return not (self.text in x.getMessage())\n\n\n# Replace warning messages (analogous to HideLoggingMessage but for warnings.warn)\nclass ReplaceWarningMessage:\n    \"\"\"\n    Intercepts warnings.warn calls and replaces matching messages with Unsloth branded ones.\n    Uses a list of registered (match_text, replacement, category) rules checked in order.\n    
\"\"\"\n\n    _rules = []\n    _original_showwarning = None\n    _installed = False\n\n    @classmethod\n    def add_rule(cls, match_text, replacement, category = None):\n        cls._rules.append((match_text, replacement, category))\n        if not cls._installed:\n            cls._install()\n\n    @classmethod\n    def _install(cls):\n        cls._original_showwarning = warnings.showwarning\n        cls._installed = True\n\n        def _patched_showwarning(\n            message, category, filename, lineno, file = None, line = None\n        ):\n            msg_str = str(message)\n            for match_text, replacement, match_category in cls._rules:\n                if match_text in msg_str and (\n                    match_category is None or category is match_category\n                ):\n                    print(replacement)\n                    return\n            cls._original_showwarning(message, category, filename, lineno, file, line)\n\n        warnings.showwarning = _patched_showwarning\n\n\n# Stop vLLM messages\nif not UNSLOTH_ENABLE_LOGGING:\n    try:\n        from vllm.worker.worker import logger as vllm_worker_logger\n\n        vllm_worker_logger.addFilter(HideLoggingMessage(\"Sleep mode freed\"))\n        del vllm_worker_logger\n    except:\n        pass\n    try:\n        from vllm.v1.worker.gpu_worker import logger as vllm_gpu_worker_logger\n\n        vllm_gpu_worker_logger.addFilter(HideLoggingMessage(\"Sleep mode freed\"))\n        del vllm_gpu_worker_logger\n    except:\n        pass\n    try:\n        from vllm.executor.executor_base import logger as vllm_executor_logger\n\n        vllm_executor_logger.addFilter(HideLoggingMessage(\"to fall asleep\"))\n        vllm_executor_logger.addFilter(HideLoggingMessage(\"to wake up\"))\n        vllm_executor_logger.addFilter(HideLoggingMessage(\"Executor is not sleeping\"))\n        del vllm_executor_logger\n    except:\n        pass\n    try:\n        from vllm.v1.executor.abstract import logger as vllm_v1_executor_logger\n\n        vllm_v1_executor_logger.addFilter(HideLoggingMessage(\"to fall asleep\"))\n        vllm_v1_executor_logger.addFilter(HideLoggingMessage(\"to wake up\"))\n        vllm_v1_executor_logger.addFilter(\n            HideLoggingMessage(\"Executor is not sleeping\")\n        )\n        del vllm_v1_executor_logger\n    except:\n        pass\n    try:\n        from vllm.core.block.prefix_caching_block import (\n            logger as vllm_prefix_caching_logger,\n        )\n\n        vllm_prefix_caching_logger.addFilter(HideLoggingMessage(\"reset prefix cache\"))\n        del vllm_prefix_caching_logger\n    except:\n        pass\n    try:\n        from vllm.v1.core.block_pool import logger as vllm_block_pool_logger\n\n        vllm_block_pool_logger.addFilter(HideLoggingMessage(\"reset prefix cache\"))\n        del vllm_block_pool_logger\n    except:\n        pass\n    try:\n        from vllm.lora.models import logger as vllm_lora_model_logger\n\n        vllm_lora_model_logger.addFilter(\n            HideLoggingMessage(\n                \"Regarding multimodal models, vLLM currently only supports adding\"\n            )\n        )\n        del vllm_lora_model_logger\n    except:\n        pass\n    try:\n        from vllm.attention.utils.fa_utils import (\n            logger as vllm_attention_utils_fa_utils_logger,\n        )\n\n        vllm_attention_utils_fa_utils_logger.addFilter(\n            HideLoggingMessage(\"Cannot use FA version\")\n        )\n        del vllm_attention_utils_fa_utils_logger\n    
except:\n        pass\n\n# The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.\nfrom transformers.training_args import logger as transformers_training_args_logger\n\ntransformers_training_args_logger.addFilter(HideLoggingMessage(\"The speedups\"))\n# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.\ntransformers_training_args_logger.addFilter(HideLoggingMessage(\"torch.distributed\"))\n# average_tokens_across_devices is set to True but it is invalid when world size is1\ntransformers_training_args_logger.addFilter(\n    HideLoggingMessage(\"average_tokens_across_devices\")\n)\ndel transformers_training_args_logger\n\n# No label_names provided for model class\nfrom transformers.trainer import logger as transformers_trainer_logger\n\ntransformers_trainer_logger.addFilter(HideLoggingMessage(\"No label_names\"))\n\n# The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config.\ntransformers_trainer_logger.addFilter(HideLoggingMessage(\"The tokenizer has new\"))\ndel transformers_trainer_logger\n\n# Using the default loss: `ForCausalLMLoss`.\ntry:\n    from transformers.modeling_utils import logger as transformers_modeling_utils_logger\n\n    transformers_modeling_utils_logger.addFilter(HideLoggingMessage(\"ForCausalLMLoss\"))\n    del transformers_modeling_utils_logger\nexcept:\n    pass\n\n# The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\ntry:\n    from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger\n\n    accelerate_utils_modeling_logger.addFilter(\n        HideLoggingMessage(\"The model weights are not tied\")\n    )\n    del accelerate_utils_modeling_logger\nexcept:\n    pass\n\n# Setting `pad_token_id` to `eos_token_id`\ntry:\n    from transformers.generation.utils import (\n        logger as transformers_generation_utils_logger,\n    )\n\n    transformers_generation_utils_logger.addFilter(\n        HideLoggingMessage(\"Setting `pad_token_id` to `eos_token_id`\")\n    )\n    # \"You have set `compile_config`\n    transformers_generation_utils_logger.addFilter(HideLoggingMessage(\"compile_config\"))\n    del transformers_generation_utils_logger\nexcept:\n    pass\n\n# The following generation flags are not valid and may be ignored:\ntry:\n    from transformers.generation.configuration_utils import (\n        logger as configuration_logger,\n    )\n\n    configuration_logger.addFilter(HideLoggingMessage(\"following generation flags\"))\n    del configuration_logger\nexcept:\n    pass\n\n# Gemma3 It is strongly recommended to train Gemma3 models with the `eager`\ntry:\n    from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger\n\n    gemma3_logger.addFilter(HideLoggingMessage(\"strongly recommended\"))\n    del gemma3_logger\nexcept:\n    pass\n\n# Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.\ntry:\n    from huggingface_hub.file_download import logger as hub_logger\n\n    hub_logger.addFilter(HideLoggingMessage(\"hf_xet\"))\n    del hub_logger\nexcept:\n    pass\n\n# MXFP4 quantization requires triton >= 3.4.0\ntry:\n    from transformers.quantizers.quantizer_mxfp4 import logger as mxfp4_logger\n\n    mxfp4_logger.addFilter(HideLoggingMessage(\"requires triton\"))\n    del mxfp4_logger\nexcept:\n    pass\n\n# You passed `quantization_config` or equivalent parameters\ntry:\n    warnings.filterwarnings(\n        
action = \"ignore\",\n        message = r\".*quantization_config.*\",\n        category = UserWarning,\n        append = True,\n    )\nexcept:\n    pass\n\n# UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead\n# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463\ntry:\n    warnings.filterwarnings(\n        action = \"ignore\",\n        message = r\".*Logical operators 'and' and 'or'.*\",\n        category = UserWarning,\n        append = True,\n    )\nexcept:\n    pass\n\n# Using a slow image processor as `use_fast`\ntry:\n    from transformers.processing_utils import logger as processing_utils_logger\n\n    processing_utils_logger.addFilter(HideLoggingMessage(\"`use_fast`\"))\n    del processing_utils_logger\nexcept:\n    pass\n\n# Using a slow image processor as `use_fast`\ntry:\n    from transformers.models.auto.image_processing_auto import (\n        logger as processing_utils_logger,\n    )\n\n    processing_utils_logger.addFilter(HideLoggingMessage(\"`use_fast`\"))\n    del processing_utils_logger\nexcept:\n    pass\n\n# `use_cache=True` is incompatible with gradient checkpointing\ntry:\n    from transformers.trainer import logger as trainer_logger\n\n    trainer_logger.addFilter(HideLoggingMessage(\"`use_cache=True`\"))\n    del trainer_logger\nexcept:\n    pass\n\n# `use_cache=True` is incompatible with gradient checkpointing\ntry:\n    from transformers.utils.generic import logger as trainer_logger\n\n    trainer_logger.addFilter(HideLoggingMessage(\"`use_cache=True`\"))\n    del trainer_logger\nexcept:\n    pass\n\n# We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')\ntry:\n    from transformers.modeling_utils import logger as modeling_utils_logger\n\n    modeling_utils_logger.addFilter(HideLoggingMessage(\"anti-pattern\"))\n    del modeling_utils_logger\nexcept:\n    pass\n\n# Errors out on\n# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint\nfrom transformers.modeling_utils import logger as transformers_logger\n\n\nclass _RaiseUninitialized(logging.Handler):\n    def __init__(self):\n        super().__init__()\n\n    def emit(self, record):\n        record_lower = str(record).lower()\n        if (\n            (\"some weights of\" in record_lower)\n            and (\"score.weight\" not in record_lower)\n            and (\"classifier.weight\" not in record_lower)\n            and (\"cls.predictions\" not in record_lower)\n            and (\"predictions.decoder\" not in record_lower)\n            and (os.environ.get(\"UNSLOTH_WARN_UNINITIALIZED\", \"1\") == \"1\")\n        ):\n            raise Exception(\n                f\"Unsloth: Critical error since some weights are not initialized.\\n\"\n                f\"Please try updating Unsloth, transformers and timm via:\\n\"\n                f\"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\\n\"\n                f\"{str(record)}\"\n            )\n\n\nclass RaiseUninitialized:\n    def __init__(self):\n        self.error_handler = _RaiseUninitialized()\n        transformers_logger.addHandler(self.error_handler)\n\n    def remove(self):\n        transformers_logger.removeHandler(self.error_handler)\n\n\ntry:\n    from transformers.trainer import logger as transformers_trainer_logger\n\n    transformers_trainer_logger.addFilter(\n        HideLoggingMessage(\"The model is 
already on multiple devices.\")\n    )\nexcept:\n    pass\n\n# Hide HF Hub unauthenticated request warnings\ntry:\n    from huggingface_hub.utils._http import logger as hf_http_logger\n\n    hf_http_logger.addFilter(\n        HideLoggingMessage(\"You are sending unauthenticated requests\")\n    )\n    del hf_http_logger\nexcept:\n    pass\n\n# Replace PEFT target_parameters warning with Unsloth branded message for MoE models\nReplaceWarningMessage.add_rule(\n    match_text = \"target_parameters\",\n    replacement = (\n        \"Unsloth: PEFT set target_parameters but found no matching parameters.\\n\"\n        \"This is expected for MoE models - Unsloth handles MoE expert LoRA targeting separately.\"\n    ),\n    category = RuntimeWarning,\n)\n\n# Patch get_model_param_count to record correct 4bit / 8bit\nfrom transformers.trainer_pt_utils import is_deepspeed_zero3_enabled\n\n\ndef extract_quant_model_param_count(model):\n    \"\"\"\n    Calculate quant model param count based on difference in param class. Returns int for param count.\n    \"\"\"\n    count: int = 0\n    for name, p in model.named_parameters():\n        if p.__class__.__name__ == \"Params4bit\":\n            count += 2 * p.numel()\n        else:\n            count += p.numel()\n    return count\n\n\ndef get_model_param_count(model, trainable_only = False):\n    \"\"\"\n    Calculate model's total param count. If trainable_only is True then count only those requiring grads\n    \"\"\"\n    if is_deepspeed_zero3_enabled():\n\n        def numel(p):\n            return p.ds_numel if hasattr(p, \"ds_numel\") else p.numel()\n    else:\n\n        def numel(p):\n            return p.numel()\n\n    s = sum(\n        numel(p) for p in model.parameters() if not trainable_only or p.requires_grad\n    )\n    if (\n        (not trainable_only)\n        and hasattr(model, \"config\")\n        and hasattr(model.config, \"quantization_config\")\n    ):\n        approx = extract_quant_model_param_count(model)\n        if approx is not None:\n            s = approx\n    return s\n\n\nimport transformers.trainer_pt_utils\n\ntransformers.trainer_pt_utils.get_model_param_count = get_model_param_count\nimport transformers.trainer\n\ntransformers.trainer.get_model_param_count = get_model_param_count\n# =============================================\n\n# =============================================\n# Edits all Config files to enable RoPE Scaling for all models\n\n\n# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.\ndef patch_mistral_nemo_config(config):\n    if \"head_dim (\" not in config:\n        add_head_dim = (\n            \"If it is not specified, will default to `8`.\\n\"\n            \"        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\\n\"\n            \"            The attention head dimension.\"\n        )\n        config = config.replace(\n            \"If it is not specified, will default to `8`.\", add_head_dim\n        )\n\n        add_head_dim = \"num_key_value_heads=8,\\n        head_dim=None,\"\n        config = config.replace(\"num_key_value_heads=8,\", add_head_dim)\n\n        add_head_dim = \"self.sliding_window = sliding_window\\n        self.head_dim = head_dim or hidden_size // num_attention_heads\\n\"\n        config = config.replace(\"self.sliding_window = sliding_window\", add_head_dim)\n    return config\n\n\ntry:\n    # Some Config files use layer_type_validation\n    # for eg Gemma-2, so we must import it to stop errors.\n    from 
transformers.configuration_utils import layer_type_validation\nexcept:\n    pass\n\ntry:\n    # Transformers 5.0+ uses RotaryEmbeddingConfigMixin as a base class for configs\n    from transformers.modeling_rope_utils import RotaryEmbeddingConfigMixin\nexcept:\n    pass\nfrom transformers import __version__ as transformers_version\n\ntry:\n    from transformers import PreTrainedConfig\nexcept:\n    from transformers import PretrainedConfig\n\nmodel_architectures = [\n    \"llama\",\n    \"mistral\",\n    \"gemma\",\n    \"gemma2\",\n    \"qwen2\",\n    \"granite\",\n    \"qwen3\",\n    \"qwen3_moe\",\n    \"falcon_h1\",\n]\n\nfor model_name in model_architectures:\n    config_filepath = f\"transformers.models.{model_name}.configuration_{model_name}\"\n    model_filepath = f\"transformers.models.{model_name}.modeling_{model_name}\"\n    config_filename = f\"{model_name.title().replace('_','')}Config\"  # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now\n    try:\n        exec(f\"from {config_filepath} import {config_filename}\", globals())\n    except:\n        continue\n\n    try:\n        config = inspect.getsource(eval(config_filename))\n    except:\n        continue\n    if \"RopeParameters\" in config:\n        try:\n            exec(f\"from {config_filepath} import RopeParameters\", globals())\n        except:\n            continue\n\n    if \"rope_scaling\" in config:\n        continue\n    config = re.sub(\n        r\"(\\*\\*kwargs)[\\s]{0,}\\,[\\s]{0,}\\)[\\s]{0,}\\:\",\n        r\"rope_scaling=None,\"\n        r\"\\n        **kwargs):\\n\"\n        r\"\\n        self.rope_scaling = rope_scaling\\n\",\n        config,\n    )\n\n    # Just for Mistral Nemo\n    if model_name == \"mistral\":\n        if Version(transformers_version) <= Version(\"4.42.4\"):\n            config = patch_mistral_nemo_config(config)\n\n    exec(config, globals())\n    exec(f\"import {config_filepath}\", globals())\n    exec(f\"{config_filepath}.{config_filename} = {config_filename}\", globals())\n# =============================================\n\n# =============================================\n# torch.cuda.amp.custom_fwd is deprecated >= 2.4\ntorch_version = torch.__version__\nif DEVICE_TYPE in (\"cuda\", \"hip\"):\n    if Version(torch_version) < Version(\"2.4.0\"):\n        torch_amp_custom_fwd = torch.cuda.amp.custom_fwd\n        torch_amp_custom_bwd = torch.cuda.amp.custom_bwd\n    else:\n        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = \"cuda\")\n        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = \"cuda\")\nelif DEVICE_TYPE == \"xpu\":\n    if Version(torch_version) < Version(\"2.6.0\"):\n        raise RuntimeError(\"torch.xpu currently only supports torch.version >= 2.6.0\")\n    else:\n        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = \"xpu\")\n        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = \"xpu\")\n# =============================================\n\n# =============================================\n# Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'\n# import transformers.cache_utils\n# if hasattr(transformers.cache_utils, \"DynamicCache\") and \\\n#     transformers.cache_utils.DynamicCache.__getitem__.__name__ != \"__cache_utils_getitem__\":\n\n#     source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)\n#     start = source.find(\"def\")\n#     spaces = start*\" \"\n#     source = source.split(\"\\n\")\n#     source = 
\"\\n\".join(x[start:] for x in source)\n#     where = source.find(\"raise KeyError\")\n#     source = source[:where] + \\\n#         f\"if len(self) == 0:\\n{spaces}{spaces}\"\\\n#         \"    raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\\n\" + \\\n#         f\"{spaces}{spaces}else:\\n{spaces}{spaces}{spaces}\" + source[where:]\n#     source = source.replace(\"__getitem__\", \"__cache_utils_getitem__\", 1)\n#     exec(source)\n#     transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__\n# pass\n# =============================================\n\n# =============================================\n# Weird Databricks errors\nfrom transformers.utils import is_openai_available\n\nif is_openai_available():\n    try:\n        from openai import OpenAI\n    except:\n        print(\"Unsloth: OpenAI failed to import - ignoring for now.\")\n        import transformers.utils\n\n        def _is_openai_available():\n            return False\n\n        transformers.utils.is_openai_available = _is_openai_available\n\n# =============================================\n# Get Flash Attention v2 if Ampere (RTX 30xx, A100)\nimport bitsandbytes as bnb\n\nfrom transformers import AutoTokenizer\nfrom transformers.utils.import_utils import _is_package_available\n\nSUPPORTS_BFLOAT16 = False\nHAS_FLASH_ATTENTION = False\nHAS_FLASH_ATTENTION_SOFTCAPPING = False\n\nif DEVICE_TYPE == \"cuda\":\n    major_version, minor_version = torch.cuda.get_device_capability()\n    torch.cuda.get_device_capability = functools.cache(torch.cuda.get_device_capability)\n\n    if major_version >= 8:\n        SUPPORTS_BFLOAT16 = True\n        if _is_package_available(\"flash_attn\"):\n            # Check for CUDA linking errors \"undefined symbol: _ZNK3c106SymIntltEl\"\n            try:\n                try:\n                    # See https://github.com/unslothai/unsloth/issues/1437\n                    from flash_attn.flash_attn_interface import flash_attn_gpu\n                except:\n                    from flash_attn.flash_attn_interface import flash_attn_cuda\n                HAS_FLASH_ATTENTION = True\n\n                # Also check for softcapping\n                from flash_attn import __version__ as flash_attn_version\n\n                HAS_FLASH_ATTENTION_SOFTCAPPING = Version(\n                    flash_attn_version\n                ) >= Version(\"2.6.3\")\n                if not HAS_FLASH_ATTENTION_SOFTCAPPING:\n                    print(\n                        \"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\\n\"\n                        \"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\\n\"\n                        \"To update flash-attn, do the below:\\n\"\n                        '\\npip install --no-deps --no-build-isolation --upgrade \"flash-attn>=2.6.3\"'\n                    )\n            except:\n                print(\n                    \"Unsloth: Your Flash Attention 2 installation seems to be broken. \"\n                    \"Using Xformers instead. 
No performance changes will be seen.\"\n                )\n\n                # Stop Flash Attention from importing!\n                import transformers.utils.import_utils\n\n                transformers.utils.import_utils.is_flash_attn_2_available = (\n                    lambda *args, **kwargs: False\n                )\n                import transformers.utils\n\n                transformers.utils.is_flash_attn_2_available = (\n                    lambda *args, **kwargs: False\n                )\n\n                HAS_FLASH_ATTENTION = False\n        else:\n            HAS_FLASH_ATTENTION = False\n    else:\n        # Tri Dao's benchmark shows xformers is faster for now.\n        HAS_FLASH_ATTENTION = False\nelif DEVICE_TYPE == \"hip\":\n    SUPPORTS_BFLOAT16 = True\n    if _is_package_available(\"flash_attn\"):\n        # Check for CUDA linking errors \"undefined symbol: _ZNK3c106SymIntltEl\"\n        try:\n            try:\n                # See https://github.com/unslothai/unsloth/issues/1437\n                from flash_attn.flash_attn_interface import flash_attn_gpu\n            except:\n                from flash_attn.flash_attn_interface import flash_attn_cuda\n            HAS_FLASH_ATTENTION = True\n\n            # Also check for softcapping\n            from flash_attn import __version__ as flash_attn_version\n\n            HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version(\n                \"2.6.3\"\n            )\n            if not HAS_FLASH_ATTENTION_SOFTCAPPING:\n                print(\n                    \"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\\n\"\n                    \"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\\n\"\n                    \"To update flash-attn, do the below:\\n\"\n                    '\\npip install --no-deps --no-build-isolation --upgrade \"flash-attn>=2.6.3\"'\n                )\n        except:\n            print(\n                \"Unsloth: Your Flash Attention 2 installation seems to be broken. \"\n                \"Using Xformers instead. No performance changes will be seen.\"\n            )\n\n            # Stop Flash Attention from importing!\n            import transformers.utils.import_utils\n\n            transformers.utils.import_utils.is_flash_attn_2_available = (\n                lambda *args, **kwargs: False\n            )\n            import transformers.utils\n\n            transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False\n\n            HAS_FLASH_ATTENTION = False\nelif DEVICE_TYPE == \"xpu\":\n    SUPPORTS_BFLOAT16 = True\n\n# =============================================\n# Get Xformers\n# Silence xformers CUDA mismatch warnings before import\ntry:\n    _xformers_logger = logging.getLogger(\"xformers\")\n    _xformers_logger.setLevel(logging.ERROR)\n    del _xformers_logger\nexcept:\n    pass\ntry:\n    from xformers import __version__ as xformers_version\n\n    # Xformers <= 0.0.32.post2 has a broken FA3 dispatch on Blackwell/RTX 50x GPUs.\n    # The FA3 check used `capability >= (9, 0)` which matches SM 10.0/11.0/12.0,\n    # causing sm_90a kernels to be attempted on non-Hopper GPUs (CUDA error in\n    # flash_fwd_launch_template.h:188). 
Fixed in 0.0.33 with `<= (9, 0)`.\n    # See https://github.com/facebookresearch/xformers/issues/1329\n    if DEVICE_TYPE == \"cuda\":\n        major_version, minor_version = torch.cuda.get_device_capability()\n        if (f\"{major_version}.{minor_version}\" in (\"10.0\", \"11.0\", \"12.0\")) and (\n            Version(xformers_version) <= Version(\"0.0.32.post2\")\n        ):\n            raise NotImplementedError(\n                f\"Unsloth: Xformers {xformers_version} has a broken FA3 dispatch on \"\n                f\"SM {major_version}.{minor_version} GPUs. Please upgrade to >= 0.0.33 or build from source via\\n\"\n                \"```\\n\"\n                \"pip install ninja\\n\"\n                \"pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\\n\"\n                \"```\\n\"\n            )\n\n    # Temporarily disable 0.0.27 and higher - inference issues\n    if False:  # Version(xformers_version) >= Version(\"0.0.27\"):\n        raise ImportError(\n            \"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below \"\n            \"then press Disconnect Runtime and then Restart it.\\n\"\n            \"\\n\"\n            \"%%capture\\n\"\n            \"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\\n\"\n            '!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\\n'\n            '!pip install --no-deps \"xformers<=0.0.27\" trl peft accelerate bitsandbytes\\n'\n            \"\\n\"\n            f\"Otherwise in local machines, your xformers version of {xformers_version} is too new.\\n\"\n            'Please downgrade xformers via `pip install --force-reinstall \"xformers<=0.0.27\"'\n        )\n\n    if Version(torch_version) < Version(\"2.2.0\") and Version(\n        xformers_version\n    ) >= Version(\"0.0.24\"):\n        raise ImportError(\n            f\"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\\n\"\n            f\"Please install xformers < 0.0.24 for torch = {torch_version}.\"\n        )\n    elif Version(torch_version) < Version(\"2.3.0\") and Version(\n        xformers_version\n    ) >= Version(\"0.0.26\"):\n        raise ImportError(\n            f\"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\\n\"\n            f\"Please install xformers < 0.0.26 for torch = {torch_version}.\"\n        )\n    elif Version(torch_version) < Version(\"2.4.0\") and Version(\n        xformers_version\n    ) > Version(\"0.0.27\"):\n        raise ImportError(\n            f\"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\\n\"\n            f\"Please install xformers <= 0.0.27 for torch = {torch_version}.\"\n        )\n\n    from xformers._cpp_lib import _register_extensions\n\n    try:\n        _register_extensions()  # Check if C++ modules are loaded correctly\n    except Exception as error:\n        raise ImportError(\n            \"Unsloth: Xformers was not installed correctly.\\n\"\n            \"Please install xformers separately first.\\n\"\n            \"Then confirm if it's correctly installed by running:\\n\"\n            \"python -m xformers.info\\n\\n\"\n            \"Longer error message:\\n\" + str(error)\n        )\n    import xformers.ops.fmha as xformers\n\n    xformers_attention = xformers.memory_efficient_attention\nexcept ModuleNotFoundError:\n    xformers = None\n    xformers_attention = None\n    
xformers_version = None\nexcept Exception as e:\n    if UNSLOTH_ENABLE_LOGGING:\n        print(\n            \"========\\nSwitching to PyTorch attention since your Xformers is broken.\\n========\\n\"\n        )\n        print(str(e))\n    xformers = None\n    xformers_attention = None\n    xformers_version = None\n\n# Check TRL version\nfrom trl import __version__ as trl_version\n\n# Unsloth now supports all TRL versions!\nif False:  # Version(trl_version) >= Version(\"0.9.0\"):\n    raise ImportError(\n        \"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below \"\n        \"then press Disconnect Runtime and then Restart it.\\n\"\n        \"\\n\"\n        \"%%capture\\n\"\n        \"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\\n\"\n        '!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\\n'\n        '!pip install --no-deps \"xformers<=0.0.27\" trl peft accelerate bitsandbytes\\n'\n        \"\\n\"\n        f\"Otherwise in local machines, your TRL version of {trl_version} is too new.\\n\"\n        \"Please downgrade TRL via `pip install --force-reinstall trl\"\n    )\n\n# =============================================\n# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'\n# accelerate_old_send_to_device = None\n# accelerate_new_send_to_device = None\n# if xformers_version is not None and Version(xformers_version) >= Version(\"0.0.27\"):\n#     import accelerate.utils.operations\n#     if hasattr(accelerate.utils.operations, \"send_to_device\") and \\\n#         accelerate.utils.operations.send_to_device.__name__ != \"_fixed_send_to_device\":\n#         accelerate_old_send_to_device = accelerate.utils.operations.send_to_device\n#         from accelerate.utils.operations import *\n#         send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)\n#         send_to_device = re.sub(\n#             r\"([ ]{4,})return tensor\\.to\\(device\\)\",\n#             r\"\\1try: return tensor.to(device)\\n\\1except: return tensor\",\n#             send_to_device,\n#         ).replace(\"def send_to_device\", \"def _fixed_send_to_device\")\n#         exec(send_to_device)\n#         # accelerate.utils.operations.send_to_device = _fixed_send_to_device\n#         accelerate_new_send_to_device = _fixed_send_to_device\n#     pass\n# pass\n\n# Transformers 4.46 breaks dynamic caching. 
This is a hack\nimport transformers.generation.configuration_utils\n\nif hasattr(transformers.generation.configuration_utils, \"ALL_CACHE_IMPLEMENTATIONS\"):\n    if (\n        type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS)\n        is list\n    ):\n        if (\n            \"dynamic\"\n            not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS\n        ):\n            transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append(\n                \"dynamic\"\n            )\n# =============================================\n\n# =============================================\n# Torch compile settings\nUNSLOTH_COMPILE_DEBUG = os.environ.get(\"UNSLOTH_COMPILE_DEBUG\", \"0\") == \"1\"\nUNSLOTH_COMPILE_MAXIMUM = os.environ.get(\"UNSLOTH_COMPILE_MAXIMUM\", \"0\") == \"1\"\nUNSLOTH_COMPILE_IGNORE_ERRORS = (\n    os.environ.get(\"UNSLOTH_COMPILE_IGNORE_ERRORS\", \"1\") == \"1\"\n)\n# Just remove max_autotune_gemm warning\nfrom torch._inductor.runtime.hints import DeviceProperties\n\n\n@functools.lru_cache(None)\ndef is_big_gpu(index) -> bool:\n    if DEVICE_TYPE == \"xpu\":\n        prop = DeviceProperties.create(\n            torch.device(\"xpu\", index) if type(index) is int else index\n        )\n        min_sms = 16\n    else:\n        prop = DeviceProperties.create(\n            torch.device(\"cuda\", index) if type(index) is int else index\n        )\n        min_sms = 80\n\n    avail_sms = prop.multi_processor_count\n    if avail_sms < min_sms:\n        return False\n    return True\n\n\nimport torch._inductor.utils\n\ntorch._inductor.utils.is_big_gpu = is_big_gpu\npatch_torch_compile(\n    debug = UNSLOTH_COMPILE_DEBUG,\n    O3 = UNSLOTH_COMPILE_MAXIMUM,\n    ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,\n)\n\ntorch_compile_options = {\n    \"epilogue_fusion\": True,\n    \"max_autotune\": True,\n    \"shape_padding\": True,\n    \"trace.enabled\": UNSLOTH_COMPILE_DEBUG,\n    \"triton.cudagraphs\": False,\n}\n\nimport accelerate\n\n\ndef torch_compile_kwargs(*args, **kwargs):\n    print(\"Unsloth: Enabled auto compiling\")\n    return {\n        \"dynamic\": True,\n        \"fullgraph\": False,\n        \"options\": torch_compile_options,\n    }\n\n\naccelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs\naccelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs\naccelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs\ndel accelerate\n\n\ndef patch_regional_compilation():\n    # Regional torch 2.5 Recompilation - weirdly very slow??\n    if torch.nn.ModuleList.__name__ == \"UnslothModuleList\":\n        return\n    # Only works for torch 2.5\n    if Version(torch.__version__) < Version(\"2.5.0\"):\n        return\n\n    old_module_list = torch.nn.ModuleList\n    os.environ[\"UNSLOTH_PATCHED\"] = \"1\"\n\n    def UnslothModuleList(*args, **kwargs):\n        if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list:\n            args = [\n                old_module_list(\n                    [\n                        torch.compile(\n                            x,\n                            dynamic = True,\n                            options = torch_compile_options,\n                            fullgraph = False,\n                        )\n                        for x in args[0]\n                    ]\n                )\n            ]\n        return old_module_list(*args, **kwargs)\n\n    UnslothModuleList.__doc__ = old_module_list.__doc__\n\n    
torch.nn.ModuleList = UnslothModuleList\n    return\n\n\n# =============================================\n\n\ndef prepare_model_for_kbit_training(\n    model: Any,\n    use_gradient_checkpointing: Optional = True,\n    use_reentrant: Optional[bool] = True,\n) -> Any:\n    return prepare_model_for_training(\n        model = model,\n        use_gradient_checkpointing = use_gradient_checkpointing,\n        use_reentrant = use_reentrant,\n        full_finetuning = False,\n        train_layernorms = False,\n        train_embedding = False,\n        train_lm_head = False,\n        float32_mixed_precision = True,\n    )\n\n\n# =============================================\n# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??\n# For mixed precision, we need it to be in float32 not float16.\nfrom peft import __version__ as peft_version\nfrom peft.utils.integrations import dequantize_module_weight\n\nif Version(peft_version) < Version(\"0.12.0\"):\n    from peft.tuners.lora.layer import LoraLayer\n\n    try:\n        source = inspect.getsource(LoraLayer.update_layer)\n        text = \"if weight is not None:\\n\"\n        start = source.find(text) + len(text)\n        end = source.find(\"self.to(weight.device)\", start)\n        spaces = re.findall(r\"^([ ]{1,})break\", source, flags = re.MULTILINE)[0]\n        source = source.replace(source[start:end], spaces)\n        spaces = len(re.match(r\"[\\s]{1,}\", source).group(0))\n        lines = source.split(\"\\n\")\n        source = \"\\n\".join(x[spaces:] for x in lines)\n        source = re.sub(r\"([^\\.])nn\\.\", r\"\\1torch.nn.\", source)\n        source = source.replace(\"def update_layer\", \"def LoraLayer_update_layer\")\n        exec(source, globals())\n\n        # Fix up incorrect downcasting of LoRA weights\n        from peft.tuners.lora.layer import LoraLayer\n\n        LoraLayer.update_layer = LoraLayer_update_layer\n        from peft.tuners.lora import LoraLayer\n\n        LoraLayer.update_layer = LoraLayer_update_layer\n    except:\n        logger.warning_once(\n            \"Unsloth unsuccessfully patched LoraLayer.update_layer. 
Please file a bug report.\\n\"\n            \"Luckily, your training run will still work in the meantime!\"\n        )\n\n# =============================================\nimport importlib\n\nglobal USE_MODELSCOPE\nUSE_MODELSCOPE = os.environ.get(\"UNSLOTH_USE_MODELSCOPE\", \"0\") == \"1\"\nif USE_MODELSCOPE:\n    if importlib.util.find_spec(\"modelscope\") is None:\n        raise ImportError(\n            f\"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`\"\n        )\n\nimport socket\n\n\n@functools.lru_cache(1)\ndef has_internet(host = \"8.8.8.8\", port = 53, timeout = 3):\n    if os.environ.get(\"TRANSFORMERS_OFFLINE\", \"0\") == \"1\":\n        return False\n    try:\n        socket.setdefaulttimeout(timeout)\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        try:\n            sock.connect((host, port))\n            return True\n        finally:\n            sock.close()\n    except socket.error as ex:\n        return False\n\n\nimport psutil\n\n\ndef _get_statistics(statistics = None, force_download = True):\n    # We log some basic stats about which environment is being used.\n    # We simply download a README.md file from HF - all data is made public.\n    # This is simply so we can check if some envs are broken or not.\n    # You can disable this by commenting the below out\n    n_cpus = psutil.cpu_count(logical = False)\n    keynames = \"\\n\" + \"\\n\".join(os.environ.keys())\n    # Check modelscope for down detection\n    global USE_MODELSCOPE\n    USE_MODELSCOPE = os.environ.get(\"UNSLOTH_USE_MODELSCOPE\", \"0\") == \"1\"\n\n    if statistics is None:\n        # Prefer filesystem markers (harder to misidentify) before env-key matching\n        try:\n            from pathlib import Path\n\n            if Path(\"/kaggle/working\").exists():\n                statistics = \"kaggle\"\n            elif Path(\"/content\").exists() and Path(\"/opt/colab\").exists():\n                statistics = \"colab\" if n_cpus == 1 else \"colabpro\"\n            elif Path(\"/runpod-volume\").exists():\n                statistics = \"runpod\"\n        except Exception:\n            pass\n\n        # Fallback to env-key detection\n        if statistics is None:\n            if \"\\nKAGGLE_\" in keynames:\n                statistics = \"kaggle\"\n            elif \"\\nCOLAB_\" in keynames and n_cpus == 1:\n                statistics = \"colab\"\n            elif \"\\nCOLAB_\" in keynames:\n                statistics = \"colabpro\"\n            elif \"\\nRUNPOD_\" in keynames:\n                statistics = \"runpod\"\n            elif \"\\nAWS_\" in keynames:\n                statistics = \"aws\"\n            elif \"\\nAZURE_\" in keynames:\n                statistics = \"azure\"\n            # elif \"\\nK_\" in keynames or \"\\nFUNCTION_\" in keynames: statistics = \"gcp\"\n            elif \"\\nINVOCATION_ID\" in keynames:\n                statistics = \"lambda\"\n            # else: statistics = \"other\"\n            else:\n\n                def try_vllm_check():\n                    vendor_files = (\n                        \"/sys/class/dmi/id/product_version\",\n                        \"/sys/class/dmi/id/bios_vendor\",\n                        \"/sys/class/dmi/id/product_name\",\n                        \"/sys/class/dmi/id/chassis_asset_tag\",\n                        \"/sys/class/dmi/id/sys_vendor\",\n                    )\n\n                    for vendor_file in vendor_files:\n                        path = 
Path(vendor_file)\n                        if path.is_file():\n                            file_content = path.read_text().lower()\n                            if \"amazon\" in file_content:\n                                return \"aws\"\n                            elif \"microsoft corporation\" in file_content:\n                                return \"azure\"\n                            elif \"google\" in file_content:\n                                return \"gcp\"\n                    return \"other\"\n\n                try:\n                    statistics = try_vllm_check()\n                except Exception:\n                    statistics = \"other\"\n\n    if statistics is not None:\n        import tempfile\n        from huggingface_hub import snapshot_download\n        from unsloth_zoo.rl_environments import execute_with_time_limit\n\n        if has_internet():\n\n            def stats_check():\n                with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:\n                    snapshot_download(\n                        f\"unslothai/{statistics}\",\n                        force_download = True,\n                        cache_dir = f,\n                        local_dir = f,\n                    )\n\n            time_limited_stats_check = execute_with_time_limit(120)(stats_check)\n            try:\n                time_limited_stats_check()\n            except TimeoutError:\n                raise TimeoutError(\n                    \"Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\\n\"\n                    \"Check https://status.huggingface.co/ for more details.\\n\"\n                    \"As a temporary measure, use modelscope with the same model name ie:\\n\"\n                    \"```\\n\"\n                    \"pip install modelscope\\n\"\n                    \"import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\\n\"\n                    \"from unsloth import FastLanguageModel\\n\"\n                    \"model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\\n\"\n                    \"```\"\n                )\n            except Exception:\n                # Try no time limit check\n                stats_check()\n\n\ndef get_statistics(local_files_only = False):\n    # We log some basic stats about which environment is being used.\n    # This is also to check if HuggingFace is down or not!\n    # We simply download a README.md file from HF - all data is made public.\n    # This is simply so we can check if some envs are broken or not.\n    # You can disable this by setting UNSLOTH_DISABLE_STATISTICS\n    import os\n\n    if (\n        \"UNSLOTH_DISABLE_STATISTICS\" in os.environ\n        or os.environ.get(\"UNSLOTH_USE_MODELSCOPE\", \"0\") == \"1\"\n    ):\n        return\n    if local_files_only:\n        return\n    from huggingface_hub.utils import (\n        disable_progress_bars,\n        enable_progress_bars,\n        are_progress_bars_disabled,\n    )\n\n    disabled = False\n    if not are_progress_bars_disabled():\n        disable_progress_bars()\n        disabled = True\n    _get_statistics(None)\n    _get_statistics(\"repeat\", force_download = False)\n    total_memory = (\n        torch.xpu.get_device_properties(0).total_memory\n        if DEVICE_TYPE == \"xpu\"\n        else torch.cuda.get_device_properties(0).total_memory\n    )\n    vram = total_memory / 1024 / 1024 / 1024\n    if vram <= 8:\n        vram = 8\n    elif vram <= 16:\n        vram = 16\n    elif vram <= 20:\n        vram = 20\n    elif vram 
<= 24:\n        vram = 24\n    elif vram <= 40:\n        vram = 40\n    elif vram <= 48:\n        vram = 48\n    elif vram <= 80:\n        vram = 80\n    else:\n        vram = 96\n    _get_statistics(f\"vram-{vram}\")\n    _get_statistics(f\"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}\")\n    if disabled:\n        enable_progress_bars()\n\n\n# =============================================\n# Fixes Bitsandbytes to remove missing warnings\nfrom transformers.utils.quantization_config import (\n    BitsAndBytesConfig,\n    QuantizationMethod,\n)\n\nBitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)\nBitsAndBytesConfig__init__ = re.sub(\n    r\"if[\\s]{1,}kwargs\\:[\\s]{1,}.+?\\n\",\n    \"\",\n    BitsAndBytesConfig__init__,\n    flags = re.MULTILINE,\n)\nBitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split(\"\\n\")\nlength_spaces = len(re.match(r\"[\\s]{1,}\", BitsAndBytesConfig__init__[0]).group(0))\nBitsAndBytesConfig__init__ = \"\\n\".join(\n    x[length_spaces:] for x in BitsAndBytesConfig__init__\n)\nBitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(\n    \"__init__\",\n    \"_BitsAndBytesConfig__init__\",\n)\nexec(BitsAndBytesConfig__init__, globals())\n\nif DEVICE_COUNT == 1 and int(os.environ.get(\"WORLD_SIZE\", \"1\")) <= 1:\n    from accelerate.utils.dataclasses import DistributedType\n\n    def _prepare_backend(self, *args, **kwargs):\n        return None, DistributedType.NO\n\n    import accelerate.state\n\n    accelerate.state.PartialState._prepare_backend = _prepare_backend\n    accelerate.accelerator.Accelerator.distributed_type = (\n        lambda *args, **kwargs: DistributedType.NO\n    )\n\n\n# to move multiple tensors to the same device\ndef move_to_device(target_device, *tensors):\n    \"\"\"\n    Move multiple tensors to target device if they're not already there.\n\n    Args:\n        target_device: The target device to move tensors to\n        *tensors: Variable number of tensors to potentially move\n\n    Returns:\n        tuple: The tensors on the target device (same objects if already on device, new if moved)\n    \"\"\"\n    if isinstance(target_device, int):\n        target_device = torch.device(target_device)\n    elif isinstance(target_device, str):\n        # if string we expect it to be a device name like \"cuda:0\"\n        target_device = torch.device(target_device)\n    elif isinstance(target_device, torch.device):\n        pass\n    else:\n        raise ValueError(f\"Invalid target device: {target_device}\")\n    moved_tensors = []\n    for tensor in tensors:\n        if tensor.device != target_device:\n            moved_tensors.append(tensor.to(target_device))\n        else:\n            moved_tensors.append(tensor)\n    return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]\n\n\nimport transformers.utils.quantization_config\n\ntransformers.utils.quantization_config.BitsAndBytesConfig.__init__ = (\n    _BitsAndBytesConfig__init__\n)\n# =============================================\n\n# Offloading to disk for modules (lm_head, embed_tokens)\nimport pickle\n\n\ndef offload_to_disk(\n    W, model, name, temporary_location: str = \"_unsloth_temporary_saved_buffers\"\n):\n    file_location = os.path.join(temporary_location, model.config._name_or_path)\n    if not os.path.exists(file_location):\n        os.makedirs(file_location)\n\n    filename = os.path.join(file_location, f\"{name}.pt\")\n    W = W.weight if hasattr(W, \"weight\") else W\n    torch.save(\n        W,\n        filename,\n       
 pickle_module = pickle,\n        pickle_protocol = pickle.HIGHEST_PROTOCOL,\n    )\n    # We must use weights_only = False due to pickling\n    offloaded_W = torch.load(\n        filename, map_location = \"cpu\", mmap = True, weights_only = False\n    )\n    offloaded_W._offloaded_file_location = filename\n    return offloaded_W\n\n\ndef offload_input_embeddings(\n    model, temporary_location: str = \"_unsloth_temporary_saved_buffers\"\n):\n    offloaded_W = offload_to_disk(\n        model.get_input_embeddings(), model, \"input_embeddings\", temporary_location\n    )\n    new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)\n    new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location\n    model.set_input_embeddings(new_input_embeddings)\n    return\n\n\ndef offload_output_embeddings(\n    model, temporary_location: str = \"_unsloth_temporary_saved_buffers\"\n):\n    offloaded_W = offload_to_disk(\n        model.get_output_embeddings(), model, \"output_embeddings\", temporary_location\n    )\n\n    new_output_embeddings = torch.nn.Linear(1, 1, bias = None)\n    del new_output_embeddings.weight\n    new_output_embeddings.weight = offloaded_W\n    new_output_embeddings.in_features = offloaded_W.shape[1]\n    new_output_embeddings.out_features = offloaded_W.shape[0]\n\n    new_output_embeddings._offloaded_file_location = (\n        offloaded_W._offloaded_file_location\n    )\n    model.set_output_embeddings(new_output_embeddings)\n    return\n\n\n# Fixes a weird Torch 2.3 bug which says T4s have bfloat16\ndef is_bfloat16_supported():\n    return SUPPORTS_BFLOAT16\n\n\ndef is_vLLM_available():\n    return _is_package_available(\"vllm\")\n\n\n# Patches models to add RoPE Scaling\ndef patch_linear_scaling(\n    model_name = \"gemma2\",\n    rope_module = None,\n    scaled_rope_module = None,\n    attention_module = None,\n):\n    assert rope_module is not None and scaled_rope_module is not None\n    assert attention_module is not None\n\n    rope_name = rope_module.__name__\n    scaled_rope_name = scaled_rope_module.__name__\n    model_filepath = f\"transformers.models.{model_name}.modeling_{model_name}\"\n    exec_code = (\n        f\"import torch.nn as nn\\n\"\n        f\"from typing import Union, Optional, List, Any, Callable, Tuple\\n\"\n        f\"from {model_filepath} import logger, \"\n        f\"{model_name.title()}Attention, {model_name.title()}Config\"\n    )\n\n    try:\n        function = inspect.getsource(attention_module.__init__)\n    except:\n        # Most likely already patched!\n        return None, None\n    where = function.find(\"def\")\n    function = function.split(\"\\n\")\n    function = \"\\n\".join(x[where:] for x in function)\n    init_name = f\"{model_name.title()}Attention__init__\"\n    function = function.replace(\"def __init__\", f\"def {init_name}\")\n    function = function.replace(\n        \"super().__init__()\",\n        f\"super({model_name.title()}Attention, self).__init__()\",\n    )\n    fix_rope_function = \"\"\"\n    if getattr(self.config, \"rope_scaling\", None) is None:\n        self.rotary_emb = {rope_function}(\n            dim = self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n    else:\n        scaling_type = self.config.rope_scaling[\"type\"]\n        scaling_factor = self.config.rope_scaling[\"factor\"]\n        if scaling_type == \"linear\":\n            self.rotary_emb = {scaled_rope_function}(\n                dim 
= self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                scaling_factor=scaling_factor,\n                base=self.rope_theta,\n            )\n        else:\n            raise ValueError(f\"Unknown RoPE scaling type {{scaling_type}}\")\n    pass\n    \"\"\"\n    fix_rope_function = fix_rope_function.format(\n        rope_function = rope_module.__name__,\n        scaled_rope_function = scaled_rope_module.__name__,\n    )\n    rotary_emb = re.findall(\n        r\"self\\.rotary\\_emb \\= .+?\\)\",\n        function,\n        flags = re.DOTALL | re.MULTILINE,\n    )\n    if len(rotary_emb) == 0:\n        return None, exec_code + \"\\n\\n\" + function\n\n    rotary_emb = rotary_emb[0]\n    function = function.replace(rotary_emb, fix_rope_function, 1)\n    function = exec_code + \"\\n\\n\" + function\n    return init_name, function\n\n\n# Patches for Llama-3 LlamaExtendedRotaryEmbedding\ndef patch_llama_rope_scaling(\n    model_name = \"llama\",\n    rope_module = None,\n    scaled_rope_module = None,\n    extended_rope_module = None,\n    attention_module = None,\n    longrope_module = None,\n):\n    assert (\n        rope_module is not None\n        and scaled_rope_module is not None\n        and extended_rope_module is not None\n    )\n    assert attention_module is not None\n\n    rope_name = rope_module.__name__\n    scaled_rope_name = scaled_rope_module.__name__\n    model_filepath = f\"transformers.models.{model_name}.modeling_{model_name}\"\n    exec_code = (\n        f\"import torch.nn as nn\\n\"\n        f\"from typing import Union, Optional, List, Any, Callable, Tuple\\n\"\n        f\"from {model_filepath} import logger, \"\n        f\"{model_name.title()}Attention, {model_name.title()}Config\"\n    )\n\n    try:\n        function = inspect.getsource(attention_module.__init__)\n    except:\n        # Most likely already patched!\n        return None, None\n    where = function.find(\"def\")\n    function = function.split(\"\\n\")\n    function = \"\\n\".join(x[where:] for x in function)\n    init_name = f\"{model_name.title()}Attention__init__\"\n    function = function.replace(\"def __init__\", f\"def {init_name}\")\n    function = function.replace(\n        \"super().__init__()\",\n        f\"super({model_name.title()}Attention, self).__init__()\",\n    )\n    fix_rope_function = \"\"\"\n    if getattr(self.config, \"rope_scaling\", None) is None:\n        self.rotary_emb = {rope_function}(\n            dim = self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n    else:\n        scaling_type1 = self.config.rope_scaling.get(\"type\", None)\n        scaling_type2 = self.config.rope_scaling.get(\"rope_type\", None)\n        scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2\n        scaling_factor = self.config.rope_scaling.get(\"factor\")\n\n        if scaling_type == \"linear\":\n            self.rotary_emb = {scaled_rope_function}(\n                dim = self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                scaling_factor=scaling_factor,\n                base=self.rope_theta,\n            )\n        elif scaling_type == \"llama3\":\n            self.rotary_emb = {extended_rope_function}(\n                dim = self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        elif scaling_type == 
\"longrope\":\n            self.rotary_emb = {longrope_rope_function}(\n                dim = self.head_dim,\n                max_position_embeddings = self.max_position_embeddings,\n                original_max_position_embeddings = self.config.original_max_position_embeddings,\n                base = self.rope_theta,\n                short_factor = self.config.rope_scaling['short_factor'],\n                long_factor  = self.config.rope_scaling['long_factor' ],\n            )\n        else:\n            raise ValueError(f\"Unknown RoPE scaling type {{scaling_type}}\")\n    pass\n    \"\"\"\n\n    fix_rope_function = fix_rope_function.format(\n        rope_function = rope_module.__name__,\n        scaled_rope_function = scaled_rope_module.__name__,\n        extended_rope_function = extended_rope_module.__name__,\n        longrope_rope_function = (\n            longrope_module if longrope_module is not None else rope_module\n        ).__name__,\n    )\n    rotary_emb = re.findall(\n        r\"self\\.rotary\\_emb \\= .+?\\)\",\n        function,\n        flags = re.DOTALL | re.MULTILINE,\n    )\n    if len(rotary_emb) == 0:\n        return None, function\n    rotary_emb = rotary_emb[0]\n    function = function.replace(rotary_emb, fix_rope_function, 1)\n    function = exec_code + \"\\n\\n\" + function\n    return init_name, function\n\n\ndef create_boolean_mask(n = 4096, sliding_window = 2048):\n    # Creates a boolean mask for attention\n    mask = torch.ones(n, n, dtype = torch.bool)\n    if sliding_window == 0:\n        return torch.triu(mask, diagonal = 1, out = mask)\n    torch.triu(mask, diagonal = 0, out = mask)\n    torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)\n    mask = mask.T\n    torch.logical_not(mask, out = mask)\n    return mask\n\n\ndef test_mask_creation():\n    from transformers.modeling_attn_mask_utils import AttentionMaskConverter\n\n    for n in range(2, 23):\n        for s in range(1, 23):\n            correct_mask = (\n                AttentionMaskConverter(\n                    is_causal = True,\n                    sliding_window = s,\n                )\n                .to_causal_4d(\n                    1,\n                    n,\n                    n,\n                    dtype = torch.float16,\n                )\n                .squeeze(0)\n                .squeeze(0)\n            )\n            correct_mask = correct_mask == correct_mask.min()\n            our_mask = create_boolean_mask(n = n, sliding_window = s)\n            assert torch.all(correct_mask == our_mask)\n        correct_mask = (\n            AttentionMaskConverter(\n                is_causal = True,\n                sliding_window = None,\n            )\n            .to_causal_4d(\n                1,\n                n,\n                n,\n                dtype = torch.float16,\n            )\n            .squeeze(0)\n            .squeeze(0)\n        )\n        correct_mask = correct_mask == correct_mask.min()\n        our_mask = create_boolean_mask(n = n, sliding_window = 0)\n        assert torch.all(correct_mask == our_mask)\n\n\ndef _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):\n    num_items_in_batch = None\n\n    if \"num_items_in_batch\" in kwargs:\n        num_items_in_batch = kwargs[\"num_items_in_batch\"]\n        if num_items_in_batch is None:\n            # Remove it since the model does not support it!\n            kwargs.pop(\"num_items_in_batch\")\n        elif \"num_items_in_batch\" not in inputs:\n            inputs[\"num_items_in_batch\"] = 
num_items_in_batch\n\n    # Get gradient accumulation steps if possible\n    if (\n        num_items_in_batch is None\n        and getattr(getattr(self, \"args\", self), \"gradient_accumulation_steps\", 1) != 1\n    ):\n        inner_model = model\n        if hasattr(inner_model, \"base_model\"):\n            inner_model = inner_model.base_model\n        if hasattr(inner_model, \"model\"):\n            inner_model = inner_model.model\n        name = inner_model.__class__.__name__\n\n        logger.warning_once(\n            f\"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\\n\"\n            \"Using gradient accumulation will be very slightly less accurate.\\n\"\n            \"Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient\"\n        )\n    # Gemma3 multimodal models in transformers 5.x require token_type_ids during training.\n    # For text-only SFT, token_type_ids should be all zeros (no image tokens).\n    if \"token_type_ids\" not in inputs and \"input_ids\" in inputs:\n        _inner = model\n        for _attr in (\"base_model\", \"model\", \"model\"):\n            _inner = getattr(_inner, _attr, _inner)\n        if getattr(getattr(_inner, \"config\", None), \"model_type\", \"\") in (\"gemma3\",):\n            import sys as _sys\n\n            _mod = _sys.modules.get(type(_inner).__module__)\n            _has_ccm = _mod is not None and hasattr(_mod, \"create_causal_mask_mapping\")\n            if _has_ccm and _inner.training:\n                inputs[\"token_type_ids\"] = torch.zeros_like(inputs[\"input_ids\"])\n\n    outputs = self._old_compute_loss(model, inputs, *args, **kwargs)\n    return outputs\n\n\ndef patch_gradient_accumulation_fix(Trainer):\n    # Fixes gradient accumulation\n    # Fixes Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.\n    import inspect\n\n    if hasattr(Trainer, \"get_batch_samples\"):\n        if Trainer.get_batch_samples.__name__ == \"_unsloth_get_batch_samples\":\n            return\n        if (\n            not inspect.getsource(Trainer.get_batch_samples)\n            .strip()\n            .endswith(\"return batch_samples, num_items_in_batch\")\n        ):\n            raise NotImplementedError(\n                \"Unsloth: Please make a Github issue immediately!!\"\n            )\n        else:\n            if Trainer.get_batch_samples.__name__ != \"_unsloth_get_batch_samples\":\n                Trainer.get_batch_samples = _unsloth_get_batch_samples\n\n            # Also fix passing in num_items_in_batch\n            if not hasattr(Trainer, \"_old_compute_loss\"):\n                # Fix transformers 4.57.0 causing `Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.`\n                function = inspect.getsource(Trainer.compute_loss)\n                if \"loss *=\" in function or \"loss*=\" in function:\n                    where = function.find(\"def\")\n                    function = function.split(\"\\n\")\n                    function = \"\\n\".join(x[where:] for x in function)\n\n                    # Import all variables that need importing\n                    import transformers.trainer\n\n                    items_in_trainer = dir(transformers.trainer)\n                    good_items = []\n                    for item in items_in_trainer:\n                        if item in function:\n                            good_items.append(item)\n                    exec(\n                        \"from transformers.trainer import (\"\n     
                   + \", \".join(x for x in good_items)\n                        + \")\",\n                        globals(),\n                    )\n\n                    # Replace loss*= with loss = loss *\n                    function = re.sub(\n                        r\"loss[\\s]{0,}\\*\\=\",\n                        \"loss = loss *\",\n                        function,\n                    )\n                    exec(function, globals())\n                    Trainer.compute_loss = compute_loss\n                Trainer._old_compute_loss = Trainer.compute_loss\n                Trainer.compute_loss = _unsloth_pre_compute_loss\n    else:\n        logger.warning_once(\n            \"Unsloth: We fixed a gradient accumulation bug, \"\n            \"but it seems like you don't have the latest transformers version!\\n\"\n            \"Please update transformers, TRL and unsloth via:\\n\"\n            \"`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`\"\n        )\n\n    # Also fix up loss scaling i.e. negate loss *= self.args.gradient_accumulation_steps\n    if not (\n        Trainer.training_step.__name__ == \"_unsloth_training_step\"\n        or \"num_items_in_batch\"\n        not in inspect.signature(Trainer.training_step).parameters\n    ):\n        function = inspect.getsource(Trainer.training_step)\n        where = function.find(\"def\")\n        function = function.split(\"\\n\")\n        function = \"\\n\".join(x[where:] for x in function)\n\n        # Import all variables that need importing\n        import transformers.trainer\n\n        items_in_trainer = dir(transformers.trainer)\n        good_items = []\n        for item in items_in_trainer:\n            if item in function:\n                good_items.append(item)\n        exec(\n            \"from transformers.trainer import (\"\n            + \", \".join(x for x in good_items)\n            + \")\",\n            globals(),\n        )\n\n        # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already\n        # summed it up and did the division beforehand, we have to negate it.\n        function = function.replace(\n            \"loss *= self.args.gradient_accumulation_steps\",\n            \"if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps\",\n        )\n        function = function.replace(\n            \"def training_step\", \"def _unsloth_training_step\", 1\n        )\n\n        # Fix 4.47.0 issue where num_items_in_batch was removed\n        # See https://github.com/huggingface/transformers/pull/35121\n        function = function.replace(\n            \"if self.model_accepts_loss_kwargs:\",\n            \"if False:\",\n        )\n\n        # Fix when num_items_in_batch is nothing\n        # https://github.com/huggingface/transformers/pull/35207\n        # Raw strings keep \\1, \\2, \\3 as regex backreferences instead of octal escapes\n        function = re.sub(\n            r\"else:\\n\"\n            r\"([\\s]{4,})self\\.accelerator\\.backward\\(loss, \\*\\*kwargs\\)\\n\"\n            r\"(.+?)if num_items_in_batch is None\\:\\n\"\n            r\"(.+?)return loss\\.detach\\(\\) \\/ self\\.args\\.gradient_accumulation_steps\",\n            \"else:\\n\"\n            r\"\\2if num_items_in_batch is None:\\n\"\n            r\"\\3loss = loss / self.args.gradient_accumulation_steps\\n\"\n            r\"\\1self.accelerator.backward(loss, **kwargs)\",\n            function,\n        )\n\n        exec(function, globals())\n        Trainer.training_step = _unsloth_training_step\n\n    # Prevent double 
scaling gradient accumulation\n    # https://github.com/huggingface/transformers/pull/37208\n    # Patch model_accepts_loss_kwargs detection in Trainer.__init__\n    if Trainer.__init__.__name__ != \"_unsloth___init__\":\n        try:\n            init_function = inspect.getsource(Trainer.__init__)\n        except Exception:\n            init_function = \"\"\n        if init_function is not None:\n            init_function = textwrap.dedent(init_function)\n\n            # Import all variables that need importing\n            import transformers.trainer\n\n            items_in_trainer = dir(transformers.trainer)\n            good_items = []\n            for item in items_in_trainer:\n                if item in init_function:\n                    good_items.append(item)\n            exec(\n                \"from transformers.trainer import (\"\n                + \", \".join(x for x in good_items)\n                + \")\",\n                globals(),\n            )\n\n            init_function = init_function.replace(\n                \"def __init__\", \"def _unsloth___init__\", 1\n            )\n\n            # Force else branch\n            init_function = re.sub(\n                r'if[\\s]+hasattr\\(\\s*unwrapped_model\\s*,\\s*\"accepts_loss_kwargs\"\\s*\\)\\s*:',\n                'if hasattr(unwrapped_model, \"accepts_loss_kwargs\") and False:',\n                init_function,\n            )\n            exec(init_function, globals())\n            Trainer.__init__ = _unsloth___init__\n\n\ndef patch_tokenizer(model, tokenizer):\n    model, tokenizer = _patch_tokenizer(model, tokenizer)\n    if model is not None:\n        model.config.update({\"unsloth_version\": __version__})\n    return model, tokenizer\n\n\ndef patch_fast_lora():\n    import peft.tuners.lora.bnb\n\n    peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward\n\n\ndef unsloth_compile_transformers(\n    dtype,\n    model_name,\n    model_types,\n    token = None,\n    revision = None,\n    trust_remote_code = False,\n    sdpa_dynamic_mask = True,\n    sdpa_bool_masks = True,\n    sdpa_gqa_replace = True,\n    sdpa_dynamic_compile = True,\n    compile_attention = True,\n    disable_causal_masks = True,\n    compile_torch_modules = True,\n    compile_custom_modules = True,\n    compile_function_calls = True,\n    fuse_lm_head = True,\n    gradient_checkpointing = True,\n    manual_replacements = True,\n    fast_lora_forwards = True,\n    fast_residual_stream = True,\n    accurate_accumulation = True,\n    epilogue_fusion = True,\n    max_autotune = False,\n    shape_padding = True,\n    cudagraphs = False,\n    debug = False,\n    fullgraph = True,\n    import_from_cache = False,\n    disable = False,\n    return_logits = False,\n    unsloth_force_compile = False,\n):\n    if Version(torch_version) < Version(\"2.4.0\"):\n        print(\n            \"=\"\n            * 30\n            + \"Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\\n\"\n            f\"You have Torch version {torch_version}. 
Please upgrade your Torch version by visiting https://pytorch.org/\\n\"\n            \"For now your models will not get optimized, but they will still work!\"\n        )\n        return\n    if trust_remote_code and unsloth_force_compile == False:\n        print(\n            \"Unsloth: We can't trace models if `trust_remote_code = True`, \"\n            \"so turning off some optimizations!\"\n        )\n        return model_types, False\n    model_types = list(dict().fromkeys(model_types).keys())\n    if disable:\n        return model_types, False\n\n    supports_sdpa = [True]\n\n    # Run patches BEFORE compiler so class replacements (e.g. GptOssTopKRouter,\n    # GptOssExperts) are in place before the compiler caches references to them.\n    _run_temporary_patches(\"pre_compile\")\n\n    for model_type in model_types:\n        _unsloth_compile_transformers(\n            model_type,\n            sdpa_dynamic_mask = sdpa_dynamic_mask,\n            sdpa_bool_masks = sdpa_bool_masks,\n            sdpa_gqa_replace = sdpa_gqa_replace,\n            sdpa_dynamic_compile = sdpa_dynamic_compile,\n            compile_attention = compile_attention,\n            disable_causal_masks = disable_causal_masks,\n            compile_torch_modules = compile_torch_modules,\n            compile_custom_modules = compile_custom_modules,\n            compile_function_calls = compile_function_calls,\n            fuse_lm_head = fuse_lm_head,\n            gradient_checkpointing = gradient_checkpointing,\n            manual_replacements = manual_replacements,\n            fast_lora_forwards = fast_lora_forwards,\n            fast_residual_stream = fast_residual_stream,\n            accurate_accumulation = accurate_accumulation,\n            epilogue_fusion = epilogue_fusion,\n            max_autotune = max_autotune,\n            shape_padding = shape_padding,\n            cudagraphs = cudagraphs,\n            debug = debug,\n            fullgraph = fullgraph,\n            import_from_cache = import_from_cache,\n            disable = disable,\n            return_logits = return_logits,\n            supports_sdpa = supports_sdpa,\n        )\n    # Redo patches which override compiler\n    _run_temporary_patches(\"post_compile\")\n    return model_types, supports_sdpa[0]\n\n\n# We need an empty logits flag to warn people logits will not be returned anymore unless asked, i.e.\n# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\nLOGITS_ERROR_STRING = (\n    \"Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please \"\n    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `\"1\"` BEFORE starting to train i.e. before `trainer.train()`. 
For example:\\n'\n    \"```\\nimport os\\n\"\n    \"os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n\"\n    \"trainer.train()\\n```\\n\"\n    \"No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!\"\n)\n\n\ndef raise_logits_error(*args, **kwargs):\n    raise NotImplementedError(LOGITS_ERROR_STRING)\n\n\ndef return_none(*args, **kwargs):\n    return None\n\n\nclass EmptyLogits:\n    def __init__(self):\n        return\n\n    def raise_getattr_error(self, attr):\n        return return_none if attr == \"to\" else raise_logits_error\n\n    __getitem__ = raise_logits_error\n    __getattr__ = raise_getattr_error\n\n    def __repr__(self):\n        return LOGITS_ERROR_STRING\n\n    def __str__(self):\n        return LOGITS_ERROR_STRING\n\n\nEMPTY_LOGITS = EmptyLogits()\nfunctions = dir(torch.Tensor)\nfor j, function in enumerate(functions):\n    if function.startswith(\"__\") and function.endswith(\"__\"):\n        exec(\n            f\"def raise_{j}(*args, **kwargs): print('{function}')\", globals(), locals()\n        )\n        try:\n            exec(f\"EMPTY_LOGITS.{function} = raise_{j}\", globals(), locals())\n        except:\n            continue\n\n\ndef validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model):\n    from peft import LoraConfig\n\n    if loftq_config is None:\n        loftq_config = {}\n\n    signature = str(inspect.signature(LoraConfig))\n    SUPPORTS_LOFTQ = \"loftq_config\" in signature\n\n    if lora_dropout != 0:\n        logger.warning_once(\n            f\"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\\n\"\n            f\"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.\"\n        )\n\n    if bias != \"none\":\n        logger.warning_once(\n            f\"Unsloth: bias = `none` is supported for fast patching. 
You are using bias = {bias}.\\n\"\n            f\"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.\"\n        )\n\n    if not (\n        type(init_lora_weights) is bool\n        or init_lora_weights == \"gaussian\"\n        or init_lora_weights == \"loftq\"\n        or init_lora_weights == \"corda\"\n    ):\n        raise ValueError(\n            'Unsloth: `init_lora_weights` must be either [True, False, \"gaussian\", \"loftq\", \"corda\"].'\n        )\n\n    if init_lora_weights == \"loftq\":\n        if not SUPPORTS_LOFTQ:\n            import peft\n\n            raise RuntimeError(\n                f\"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\\n\"\n                \"Please install PEFT 0.7.2 or higher.\\n\"\n                \"You can also install from source: `pip install git+https://github.com/huggingface/peft.git\"\n            )\n\n        if loftq_config == {}:\n            from peft import LoftQConfig\n\n            logger.warning_once(\n                \"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\\n\"\n                \"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`.\"\n            )\n            loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)\n\n        if hasattr(model.config, \"quantization_config\"):\n            raise ValueError(\n                \"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\\n\"\n                \"Reload your model without any quantization by setting `load_in_4bit = False`.\"\n            )\n\n    return loftq_config\n\n\ndef fast_inference_setup(model_name, model_config):\n    fast_inference = True\n    if not is_vLLM_available():\n        logger.warning_once(\n            \"Unsloth: vLLM is not installed! Will use Unsloth inference!\"\n        )\n        fast_inference = False\n    from unsloth_zoo.vllm_utils import (\n        patch_vllm,\n        vllm_dynamic_quant_supported,\n    )\n\n    patch_vllm()\n    if model_name.endswith(\"unsloth-bnb-4bit\"):\n        if not vllm_dynamic_quant_supported(model_name, model_config):\n            # Instead use -bnb-4bit variant\n            logger.warning_once(\n                f\"Unsloth: Switching from Unsloth dynamic quant to normal quant since\\n\"\n                f\"we do not yet support fast inference for {model_name}\"\n            )\n            model_name = model_name[: -len(\"unsloth-bnb-4bit\")] + \"bnb-4bit\"\n    return fast_inference, model_name\n\n\ndef patch_peft_fast_inference(model):\n    vllm_engine = getattr(model.model, \"vllm_engine\", None)\n    if vllm_engine is not None:\n        model.vllm_engine = model.model.vllm_engine\n        model.fast_generate = model.model.fast_generate\n        model.fast_generate_batches = model.model.fast_generate_batches\n\n        # Also saving and loading LoRA\n        from unsloth_zoo.vllm_utils import save_lora, load_lora\n\n        model.save_lora = functools.partial(save_lora, model)\n        model.load_lora = functools.partial(load_lora, model)\n\n\ndef error_out_no_vllm(*args, **kwargs):\n    raise NotImplementedError(\n        \"Unsloth: vLLM is not yet supported for fast inference for this model! 
Please use `.generate` instead\"\n    )\n\n\ntry:\n    from torchao.core.config import AOBaseConfig\n\n    try:\n        from torchao.quantization import Int4WeightOnlyConfig\n    except:\n        print(\"Unsloth: TorchAO changed `torchao.quantization.Int4WeightOnlyConfig`\")\n        Int4WeightOnlyConfig = None\nexcept:\n    AOBaseConfig = None\n    Int4WeightOnlyConfig = None\n\n\n@dataclass\nclass TorchAOConfig:\n    qat_scheme: Optional[str] = \"int4\"\n\n    # Each (config, filter_fn) pair defines a quantization rule\n    base_config_and_filter_fns: List[\n        Tuple[\"AOBaseConfig\", Optional[Callable[[torch.nn.Module, str], bool]]]\n    ] = field(\n        default_factory = lambda: [\n            (\n                Int4WeightOnlyConfig(group_size = 128),\n                lambda m, _: isinstance(m, torch.nn.Linear)\n                and getattr(m, \"in_features\", 0) >= 128,\n            ),\n        ]\n    )\n\n    # Optional transformation to apply before quantization setup\n    prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None\n\n\ndef _untie_input_output_embeddings(model: torch.nn.Module) -> None:\n    \"\"\"\n    Utility to untie input/output embeddings in a HuggingFace model.\n    This is useful if we want to quantize the input/output embeddings differently.\n    Model is modified in-place.\n    \"\"\"\n\n    # 1) Persist setting in config\n    if hasattr(model.config, \"tie_word_embeddings\"):\n        model.config.tie_word_embeddings = False\n\n    # 2) Find input and output embeddings\n    in_emb = model.get_input_embeddings()\n    out_proj = model.get_output_embeddings() or getattr(model, \"lm_head\", None)\n    if out_proj is None:\n        raise AttributeError(\"Couldn't locate output projection (lm_head).\")\n\n    # (Optional) sanity: shapes should match [vocab, hidden]\n    assert (\n        out_proj.weight.shape == in_emb.weight.shape\n    ), f\"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}\"\n\n    # 3) Only clone if they are actually tied (shared storage)\n    if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():\n        with torch.no_grad():\n            W = in_emb.weight.detach().clone()\n        out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device\n\n    # 4) Prevent future automatic re-tying\n    def _no_tie(self):\n        return\n\n    model.tie_weights = _no_tie.__get__(model, model.__class__)\n\n    # 5) Verify no shared storage\n    assert (\n        out_proj.weight.data_ptr() != in_emb.weight.data_ptr()\n    ), \"Embeddings still tied!\"\n\n\n
def _filter_fn_to_fqns(\n    model: torch.nn.Module,\n    filter_fn: Callable[[torch.nn.Module, str], bool],\n) -> Iterator[str]:\n    \"\"\"\n    Given a model and a filter function (m, fqn) -> bool,\n    yield fully qualified names (FQNs) of modules that match.\n    \"\"\"\n    for fqn, module in model.named_modules():\n        if filter_fn(module, fqn):\n            yield fqn\n\n\ndef _convert_torchao_model(model):\n    from transformers import TorchAoConfig\n    from torchao.quantization import quantize_, ModuleFqnToConfig\n    from torchao.quantization.qat import QATConfig\n    from torchao.utils import TorchAOBaseTensor\n\n    module_to_fqn_dict = {}\n    for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:\n        quantize_(model, QATConfig(base_config, step = \"convert\"), filter_fn = filter_fn)\n\n        # Default filter function used for quantize_\n        if filter_fn is None:\n            if \"_default\" in module_to_fqn_dict:\n                raise ValueError(\"Cannot use multiple default quantization configs\")\n            module_to_fqn_dict[\"_default\"] = base_config\n        else:\n            for fqn in _filter_fn_to_fqns(model, filter_fn):\n                if fqn in module_to_fqn_dict:\n                    raise ValueError(f\"Found multiple quantization configs for {fqn}\")\n                module_to_fqn_dict[fqn] = base_config\n\n    in_emb = model.get_input_embeddings()\n    out_proj = model.get_output_embeddings() or getattr(model, \"lm_head\", None)\n    kwargs = {}\n    if isinstance(in_emb.weight, TorchAOBaseTensor) or (\n        out_proj is not None and isinstance(out_proj.weight, TorchAOBaseTensor)\n    ):\n        kwargs[\"include_input_output_embeddings\"] = True\n        kwargs[\"modules_to_not_convert\"] = []\n\n    quant_config = ModuleFqnToConfig(module_to_fqn_dict)\n    quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)\n    model.config.quantization_config = quantization_config\n\n\ndef _prepare_model_for_qat(\n    model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]\n) -> torch.nn.Module:\n    \"\"\"\n    Transform a model for Quantization-Aware Training (QAT) during fine-tuning.\n\n    At a high level, this means fake quantizing the base (frozen) model during training.\n    Fake quantization refers to simulating quantization numerics in high precision (e.g. bf16).\n    This helps mitigate quantization degradation when the model is quantized after training.\n\n    QAT can optionally be combined with LoRA fine-tuning for additional throughput improvement.\n    For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700\n    \"\"\"\n    try:\n        from torchao.quantization import PerRow, quantize_\n        from torchao.quantization.granularity import PerGroup, PerAxis\n        from torchao.quantization.qat import QATConfig\n    except ImportError:\n        raise ImportError(TORCHAO_MSG)\n\n    # Gemma3 models have issues with int8 embedding quantization due to their\n    # large vocabulary size (262144). Auto-switch to int4 weight-only instead.\n    if qat_scheme == \"int8-int4\":\n        model_types = get_transformers_model_type(model.config)\n        is_gemma3 = any(\"gemma3\" in mt or \"gemma_3\" in mt for mt in model_types)\n        if is_gemma3:\n            print(\n                \"Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. 
\"\n                \"Switching to int4 weight-only QAT for training stability.\"\n            )\n            qat_scheme = \"int4\"\n\n    if not isinstance(qat_scheme, TorchAOConfig):\n        torchao_config: Optional[TorchAOConfig] = None\n        if qat_scheme == \"fp8-int4\":\n            try:\n                from torchao.quantization import Float8DynamicActivationInt4WeightConfig\n            except ImportError:\n                raise ImportError(TORCHAO_MSG)\n            group_size = 128\n            base_config = Float8DynamicActivationInt4WeightConfig()\n            filter_fn = (\n                lambda m, _: isinstance(m, torch.nn.Linear)\n                and m.in_features >= group_size\n            )\n            torchao_config = TorchAOConfig(\n                qat_scheme = qat_scheme,\n                base_config_and_filter_fns = [(base_config, filter_fn)],\n            )\n        elif qat_scheme == \"fp8-fp8\":\n            try:\n                from torchao.quantization import (\n                    Float8DynamicActivationFloat8WeightConfig,\n                )\n            except ImportError:\n                raise ImportError(TORCHAO_MSG)\n            base_config = Float8DynamicActivationFloat8WeightConfig(\n                granularity = PerRow()\n            )\n            torchao_config = TorchAOConfig(\n                qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]\n            )\n        elif qat_scheme == \"int8-int4\":\n            try:\n                from torchao.quantization import (\n                    Int8DynamicActivationIntxWeightConfig,\n                    IntxWeightOnlyConfig,\n                )\n            except ImportError:\n                raise ImportError(TORCHAO_MSG)\n            torchao_config = TorchAOConfig(\n                qat_scheme = qat_scheme,\n                base_config_and_filter_fns = [\n                    (\n                        IntxWeightOnlyConfig(\n                            weight_dtype = torch.int8, granularity = PerAxis(0)\n                        ),\n                        lambda m, fqn: isinstance(m, torch.nn.Embedding),\n                    ),\n                    (\n                        Int8DynamicActivationIntxWeightConfig(\n                            weight_dtype = torch.int4, weight_granularity = PerGroup(32)\n                        ),\n                        None,\n                    ),\n                ],\n                prequantization_transform = _untie_input_output_embeddings,\n            )\n        elif qat_scheme == \"int4\":\n            try:\n                from torchao.quantization import Int4WeightOnlyConfig\n            except ImportError:\n                raise ImportError(TORCHAO_MSG)\n            group_size = 128\n            base_config = Int4WeightOnlyConfig(group_size = group_size)\n            filter_fn = (\n                lambda m, _: isinstance(m, torch.nn.Linear)\n                and m.in_features >= group_size\n            )\n            torchao_config = TorchAOConfig(\n                qat_scheme = qat_scheme,\n                base_config_and_filter_fns = [(base_config, filter_fn)],\n            )\n        elif qat_scheme == \"int8\":\n            try:\n                from torchao.quantization import IntxWeightOnlyConfig\n                from torchao.quantization.granularity import PerAxis\n            except ImportError:\n                raise ImportError(TORCHAO_MSG)\n\n            base_config = IntxWeightOnlyConfig(\n                weight_dtype = 
torch.int8,\n                granularity = PerAxis(0),\n            )\n            filter_fn = lambda m, _: isinstance(m, torch.nn.Linear)\n            torchao_config = TorchAOConfig(\n                qat_scheme = qat_scheme,\n                base_config_and_filter_fns = [(base_config, filter_fn)],\n            )\n        else:\n            raise ValueError(f\"Unexpected QAT scheme {qat_scheme}\")\n        assert torchao_config is not None, f\"TorchAOConfig was not set for {qat_scheme}\"\n    else:\n        torchao_config = qat_scheme\n\n    # Save Torchao metadata everywhere\n    inner_model = model\n    while hasattr(inner_model, \"model\"):\n        inner_model._torchao_config = torchao_config\n        inner_model = inner_model.model\n    inner_model._torchao_config = torchao_config\n\n    if torchao_config.prequantization_transform is not None:\n        torchao_config.prequantization_transform(model)\n    for base_config, filter_fn in torchao_config.base_config_and_filter_fns:\n        quantize_(model, QATConfig(base_config, step = \"prepare\"), filter_fn = filter_fn)\n\n    return model\n\n\ndef patch_hf_quantizer():\n    # To tell hf trainer that the quantized model is trainable\n    def make_trainable(self):\n        return True\n\n    try:\n        from transformers.quantizers.quantizer_finegrained_fp8 import (\n            FineGrainedFP8HfQuantizer,\n        )\n\n        FineGrainedFP8HfQuantizer.is_trainable = property(make_trainable)\n        FineGrainedFP8HfQuantizer.is_qat_trainable = property(make_trainable)\n    except Exception as e:\n        logger.warning(f\"Failed to patch FineGrainedFP8HfQuantizer. Error {e}\")\n\n    try:\n        from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer\n\n        FbgemmFp8HfQuantizer.is_trainable = property(make_trainable)\n        FbgemmFp8HfQuantizer.is_qat_trainable = property(make_trainable)\n    except Exception as e:\n        logger.warning(f\"Failed to patch FbgemmFp8HfQuantizer. Error {e}\")\n\n    try:\n        from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer\n\n        TorchAoHfQuantizer.is_trainable = property(make_trainable)\n        TorchAoHfQuantizer.is_qat_trainable = property(make_trainable)\n    except Exception as e:\n        logger.warning(f\"Failed to patch TorchAoHfQuantizer. Error {e}\")\n\n\npatch_hf_quantizer()\n\n\ndef verify_fp8_support_if_applicable(model_config):\n    quant_method = get_quant_type(model_config)\n    if quant_method in [\"fbgemm_fp8\", \"fp8\"] and DEVICE_TYPE != \"cuda\":\n        raise ValueError(\n            f\"Unsloth: FP8 quantization is only supported on CUDA GPUs. You are using {DEVICE_TYPE}.\"\n        )\n\n    # [TODO] Need to add FP8 support for Intel XPUs\n    if DEVICE_TYPE == \"cuda\":\n        major_version, minor_version = torch.cuda.get_device_capability()\n        if quant_method == \"fbgemm_fp8\" and major_version < 9:\n            # While L4 does support FP8 as data type, it doesn't have fbgemm (package) support yet. So we restrict it.\n            raise ValueError(\n                f\"Unsloth: FBGEMM FP8 quantization is only supported on H100 and higher GPUs. L4 is not supported. You are using {torch.cuda.get_device_name()}. 
Refer to https://developer.nvidia.com/cuda-gpus for more details.\"\n            )\n        if quant_method == \"fp8\" and major_version * 10 + minor_version < 89:\n            # In case of block quantized, we allow L4 because we fall back to torchao kernels.\n            raise ValueError(\n                f\"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details.\"\n            )\n\n\ndef _get_inference_mode_context_manager(model: torch.nn.Module):\n    \"\"\"\n    If the state dict was quantized using torchao, we will run into\n    the following error when calling ops like aten.t() in inference mode.\n    This is a bug in PyTorch that affects all tensor subclasses.\n\n        Cannot set version_counter for inference tensor\n\n    For now, we work around this issue by using `torch.no_grad()` in this case.\n    See https://github.com/pytorch/pytorch/issues/164872 for more details.\n    Otherwise, just return `torch.inference_mode()`.\n    \"\"\"\n    torchao_config = getattr(model, \"torchao_config\", None)\n    if torchao_config is not None and torchao_config.qat_scheme is None:\n        return torch.no_grad()\n    else:\n        return torch.inference_mode()\n\n\ndef hf_login(token: Optional[str] = None) -> Optional[str]:\n    if token is None:\n        try:\n            from huggingface_hub import get_token\n\n            token = get_token()\n            if token is None:\n                return None\n        except:\n            return None\n    try:\n        from huggingface_hub import login\n\n        login(token = token)\n        return token\n    except Exception as e:\n        logger.info(f\"Failed to login to huggingface using token with error: {e}\")\n    return token\n\n\n# =============================================\n# MoE (Mixture of Experts) Detection and LoRA Utilities\n\n\ndef is_moe_model(model) -> bool:\n    \"\"\"\n    Detect if a model is a Mixture of Experts (MoE) model.\n\n    Args:\n        model: The model to check (can be HF model or config)\n\n    Returns:\n        True if the model is an MoE model, False otherwise\n    \"\"\"\n    config = getattr(model, \"config\", model)\n\n    # Different MoE models use different config attribute names:\n    # - Qwen3-MoE: num_experts\n    # - GLM4-MoE: n_routed_experts, num_local_experts\n    # - Mixtral: num_local_experts\n    num_experts = None\n    for attr in (\"num_experts\", \"n_routed_experts\", \"num_local_experts\"):\n        num_experts = getattr(config, attr, None)\n        if num_experts is not None:\n            break\n\n    # Check text_config for VL models\n    if num_experts is None and hasattr(config, \"text_config\"):\n        for attr in (\"num_experts\", \"n_routed_experts\", \"num_local_experts\"):\n            num_experts = getattr(config.text_config, attr, None)\n            if num_experts is not None:\n                break\n\n    return num_experts is not None and num_experts > 0\n\n\ndef get_moe_target_parameters(model, target_modules = None) -> Optional[List[str]]:\n    \"\"\"\n    Get the target_parameters for MoE expert layers if applicable.\n\n    For MoE models, returns the parameter paths for expert weights\n    (gate_up_proj, down_proj) that should be targeted by PEFT's\n    target_parameters for LoRA on nn.Parameter.\n\n    Only includes MoE parameters that match what's in target_modules:\n    - If \"down_proj\" is in 
target_modules -> includes \"mlp.experts.down_proj\"\n    - If \"gate_proj\" or \"up_proj\" is in target_modules -> includes \"mlp.experts.gate_up_proj\"\n\n    Args:\n        model: The model to get target parameters for\n        target_modules: List/tuple of target module names to match against\n\n    Returns:\n        List of parameter paths for MoE experts, or None if not an MoE model\n    \"\"\"\n    if not is_moe_model(model):\n        return None\n\n    config = getattr(model, \"config\", model)\n    # Get num_experts from various possible config attributes\n    num_experts = None\n    for attr in (\"num_experts\", \"n_routed_experts\", \"num_local_experts\"):\n        num_experts = getattr(config, attr, None)\n        if num_experts is not None:\n            break\n    if num_experts is None and hasattr(config, \"text_config\"):\n        for attr in (\"num_experts\", \"n_routed_experts\", \"num_local_experts\"):\n            num_experts = getattr(config.text_config, attr, None)\n            if num_experts is not None:\n                break\n    if num_experts is None:\n        num_experts = 0\n\n    # Determine which MoE parameters to include based on target_modules\n    moe_params = []\n\n    # Normalize target_modules to a set for efficient lookup\n    if target_modules is None:\n        # If no target_modules specified, include all MoE params\n        target_set = {\"gate_proj\", \"up_proj\", \"down_proj\", \"gate_up_proj\"}\n    elif isinstance(target_modules, str):\n        target_set = {target_modules}\n        # Heuristic for regex matching MLPs\n        if \"proj\" in target_modules and (\n            \"mlp\" in target_modules or \"ffn\" in target_modules\n        ):\n            target_set.update({\"gate_proj\", \"up_proj\", \"down_proj\", \"gate_up_proj\"})\n    else:\n        target_set = set(target_modules) if target_modules else set()\n\n    # gate_up_proj combines both gate_proj and up_proj in MoE\n    # Also match \"gate_up_proj\" directly since users may specify the fused name\n    if (\n        \"gate_proj\" in target_set\n        or \"up_proj\" in target_set\n        or \"gate_up_proj\" in target_set\n    ):\n        moe_params.append(\"mlp.experts.gate_up_proj\")\n\n    if \"down_proj\" in target_set:\n        moe_params.append(\"mlp.experts.down_proj\")\n\n    if moe_params:\n        print(\n            f\"Unsloth: Detected MoE model with {num_experts = } and {target_modules = }. Enabling LoRA on MoE parameters: {moe_params}\"\n        )\n        return moe_params\n\n    return None\n\n\ndef make_fast_generate_wrapper(original_generate):\n    \"\"\"\n    Creates a wrapper around model.generate that checks for incorrect\n    vLLM-style usage when fast_inference=False.\n    \"\"\"\n\n    @functools.wraps(original_generate)\n    def _fast_generate_wrapper(*args, **kwargs):\n        # Check for vLLM-specific arguments\n        if \"sampling_params\" in kwargs:\n            raise ValueError(\n                \"Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). \"\n                \"Since `fast_inference=False`, use HuggingFace generate arguments instead:\\n\"\n                \"  model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)\"\n            )\n\n        if \"lora_request\" in kwargs:\n            raise ValueError(\n                \"Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). 
\"\n                \"Since `fast_inference=False`, LoRA weights are already merged into the model.\"\n            )\n\n        # Check if first positional argument is a string or list of strings\n        if len(args) > 0:\n            first_arg = args[0]\n            is_string_input = False\n\n            if isinstance(first_arg, str):\n                is_string_input = True\n            elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0:\n                if isinstance(first_arg[0], str):\n                    is_string_input = True\n\n            if is_string_input:\n                raise ValueError(\n                    \"Unsloth: Passing text strings to `fast_generate` is only supported \"\n                    \"when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must \"\n                    \"tokenize the input first:\\n\\n\"\n                    \"  messages = tokenizer.apply_chat_template(\\n\"\n                    '      [{\"role\": \"user\", \"content\": \"Your prompt here\"}],\\n'\n                    \"      tokenize=True, add_generation_prompt=True,\\n\"\n                    '      return_tensors=\"pt\", return_dict=True\\n'\n                    \"  )\\n\"\n                    \"  output = model.fast_generate(\\n\"\n                    \"      **messages.to('cuda'),\\n\"\n                    \"      max_new_tokens=64,\\n\"\n                    \"      temperature=1.0,\\n\"\n                    \"  )\"\n                )\n\n        # Call original generate\n        return original_generate(*args, **kwargs)\n\n    return _fast_generate_wrapper\n\n\n# Fix llm_int8_skip_modules not being respected for VLMs with dynamic quantization.\n# Dynamic quant checkpoints (eg gemma-3-4b-it-unsloth-bnb-4bit) encode skip paths as\n# \"language_model.model.layers.*\", but the live module tree surfaces them as\n# \"model.language_model.layers.*\". This prefix mismatch causes should_convert_module\n# to miss the skip list, so modules meant to stay in 16-bit get wrapped in Linear4bit\n# without a quant_state, producing \"Skipping ... 
no quant_state found\" warnings.\n# We patch should_convert_module to expand both the module name and the skip patterns\n# into all equivalent alias forms before delegating to the original matcher.\n# Ref: https://github.com/unslothai/unsloth/issues/4208\nimport transformers.quantizers.quantizers_utils as _quantizers_utils\n\nif (\n    hasattr(_quantizers_utils, \"should_convert_module\")\n    and getattr(_quantizers_utils.should_convert_module, \"__name__\", \"\")\n    != \"patched_should_convert_module\"\n):\n    _original_should_convert_module = _quantizers_utils.should_convert_module\n\n    def _get_full_name_aliases(full_name):\n        aliases = {full_name}\n        if not isinstance(full_name, str):\n            return aliases\n\n        if full_name.startswith(\"model.language_model.\"):\n            aliases.add(full_name[len(\"model.\") :])\n        if \"language_model.model.\" in full_name:\n            aliases.add(full_name.replace(\"language_model.model.\", \"language_model.\"))\n        if full_name.startswith(\"model.language_model.model.\"):\n            aliases.add(\n                full_name[len(\"model.\") :].replace(\n                    \"language_model.model.\", \"language_model.\"\n                )\n            )\n        return aliases\n\n    def _get_pattern_aliases(pattern):\n        aliases = {pattern}\n        if not isinstance(pattern, str):\n            return aliases\n\n        if \"language_model.model.\" in pattern:\n            aliases.add(pattern.replace(\"language_model.model.\", \"language_model.\"))\n        return aliases\n\n    def _expand_patterns(patterns):\n        expanded = set()\n        for pattern in patterns:\n            expanded.update(_get_pattern_aliases(pattern))\n        return expanded\n\n    def patched_should_convert_module(full_name, patterns = None):\n        if patterns is None:\n            return _original_should_convert_module(full_name, patterns)\n\n        expanded_patterns = _expand_patterns(patterns)\n        return all(\n            _original_should_convert_module(candidate, expanded_patterns)\n            for candidate in _get_full_name_aliases(full_name)\n        )\n\n    patched_should_convert_module._original_should_convert_module = (\n        _original_should_convert_module\n    )\n    _quantizers_utils.should_convert_module = patched_should_convert_module\n\n    try:\n        import transformers.integrations.bitsandbytes\n\n        transformers.integrations.bitsandbytes.should_convert_module = (\n            patched_should_convert_module\n        )\n    except Exception:\n        pass\n"
  },
  {
    "path": "unsloth/models/cohere.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nfrom ._utils import __version__\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom unsloth_zoo.utils import _get_dtype, Version\nfrom ..utils.packing import get_packed_info_from_kwargs\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    select_attention_backend,\n)\n\ntry:\n    from transformers.models.cohere.modeling_cohere import (\n        CohereAttention,\n        CohereDecoderLayer,\n        CohereModel,\n        CohereForCausalLM,\n        CohereRotaryEmbedding,\n        apply_rotary_pos_emb,\n        repeat_kv,\n    )\nexcept:\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\"4.42\"):\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\\n\"\n            f\"The minimum required version is 4.42.3.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.42.3\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\n\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.cohere.modeling_cohere import (\n        CohereSdpaAttention,\n        CohereFlashAttention2,\n    )\nexcept:\n    CohereSdpaAttention = CohereAttention\n    CohereFlashAttention2 = CohereAttention\n\n\ndef fast_layernorm_inference(self, X, out_weight = None):\n    XX = X.to(torch.float32, copy = True)\n    XX -= X.mean(-1, keepdim = True)\n    variance = XX.square().mean(-1, keepdim = True)\n    variance += self.variance_epsilon\n    XX *= variance.rsqrt_()\n    out_weight[:] = self.weight\n    XX *= out_weight\n    return XX.to(X.dtype)\n\n\n# QK norm in Cohere\ndef CohereAttention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n        del self.q_norm_out_weight\n        del self.k_norm_out_weight\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = 
self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, Q.device)\n    if self.use_qk_norm:\n        Q = fast_layernorm_compiled(self.q_norm, Q)\n        K = fast_layernorm_compiled(self.k_norm, K)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    # Extend RoPE dynamically to fit in VRAM\n    if position_embeddings:\n        cos, sin = position_embeddings\n    else:\n        cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)\n\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    # Useful for LongRoPE\n    Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    use_varlen = seq_info is not None and past_key_value is None\n    backend = select_attention_backend(use_varlen)\n    attention_config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\"causal\": True},\n        flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"causal\": True,\n            \"softmax_scale\": getattr(self, \"softmax_scale\", None),\n        },\n    )\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590\ndef CohereDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n):\n    if use_cache and hasattr(\n        self, \"_flag_for_generation\"\n    ):  # past_key_value is not None:\n        out_weight = torch.empty(\n            self.input_layernorm.weight.shape,\n            dtype = torch.float32,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n\n        # Self Attention\n        residual = hidden_states\n        
hidden_states = fast_layernorm_inference(\n            self.input_layernorm, hidden_states, out_weight\n        )\n        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            **kwargs,\n        )\n\n        # Fully Connected\n        hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)\n        residual += hidden_states_attention\n        residual += hidden_states_mlp\n        hidden_states = residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)\n        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            **kwargs,\n        )\n\n        # Fully Connected\n        hidden_states_mlp = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states_attention + hidden_states_mlp\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\nfrom math import sqrt as math_sqrt\n\nKV_CACHE_INCREMENT = 256  # KV Cache update size\ntorch_nn_functional_softmax = torch.nn.functional.softmax\ntorch_matmul = torch.matmul\n\n\ndef CohereAttention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    **kwargs,\n):\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, attention_size), dtype = dtype, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim),\n            dtype = dtype,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n        self.RH_Q = 
torch.empty(\n            (bsz, n_heads, 1, head_dim), dtype = dtype, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n\n        # Mistral Nemo 12b has weird dimensions\n        if attention_size != hidden_size:\n            self.temp_O = torch.empty(\n                (bsz, 1, hidden_size), dtype = dtype, device = f\"{DEVICE_TYPE_TORCH}:0\"\n            )\n        else:\n            self.temp_O = self.temp_QA[1][:, :, :hidden_size]\n\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),\n            dtype = dtype,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n        self.scalar = 1.0 / math_sqrt(self.head_dim)\n        self.half_head_dim = head_dim // 2\n        # Cohere has QK layernorms\n        if self.use_qk_norm:\n            self.q_norm_out_weight = torch.empty(\n                self.q_norm.weight.shape,\n                dtype = torch.float32,\n                device = f\"{DEVICE_TYPE_TORCH}:0\",\n            )\n            self.k_norm_out_weight = torch.empty(\n                self.k_norm.weight.shape,\n                dtype = torch.float32,\n                device = f\"{DEVICE_TYPE_TORCH}:0\",\n            )\n        else:\n            self.q_norm_out_weight = None\n            self.k_norm_out_weight = None\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)\n    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n    if self.use_qk_norm:\n        Qn = fast_layernorm_inference(self.q_norm, Qn, self.q_norm_out_weight)\n        Kn = fast_layernorm_inference(self.k_norm, Kn, self.k_norm_out_weight)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n    cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)\n    cos = cos[position_ids].unsqueeze(1)\n    sin = sin[position_ids].unsqueeze(1)\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = 
self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Handle sliding windows\n    sliding_window = getattr(self.config, \"sliding_window\", None)\n    if sliding_window is not None and kv_seq_len > sliding_window:\n        start = kv_seq_len - sliding_window\n        Knn = Kn[:, :, start:, :]  # .contiguous()\n        Vnn = Vn[:, :, start:, :]  # .contiguous()\n        if attention_mask is not None:\n            attention_mask = attention_mask[..., start:]\n    else:\n        Knn, Vnn = Kn, Vn\n\n    # Grouped query attention\n    _, _, cached_len, _ = Knn.shape\n    if n_groups != 1:\n        Knn = Knn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Vnn = Vnn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # Attention\n    if bsz == 1:\n        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963\n        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows\n        A = torch_matmul(\n            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]\n        )\n        A[:] = torch_nn_functional_softmax(\n            A, dim = -1, dtype = torch.float32\n        )  # .to(A.dtype)\n        A = torch_matmul(A, Vnn, out = Qn)\n    else:\n        A = scaled_dot_product_attention(\n            Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False\n        )\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\n# @torch.inference_mode\ndef CohereModel_fast_forward_inference(\n    self,\n    input_ids,\n    past_key_values,\n    position_ids,\n    attention_mask = None,\n):\n    out_weights = tuple(\n        torch.empty_like(\n            self.model.layers[0].input_layernorm.weight,\n            dtype = torch.float32,\n            device = torch.device(x),\n        )\n        for x in range(DEVICE_COUNT)\n    )\n    input_ids = input_ids[:, : self.max_seq_length]\n    hidden_states = self.model.embed_tokens(input_ids)\n    hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))\n    bsz, q_len, hd = hidden_states.shape\n    seq_len = past_key_values[0][0].shape[-2]\n    if bsz != 1:\n        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n            attention_mask,\n            (bsz, q_len),\n            hidden_states,\n            seq_len,\n            sliding_window = getattr(self.config, \"sliding_window\", None),\n        )\n        # Pre-convert to bool once for all layers (avoids per-layer .eq(0))\n        if attention_mask is not None and attention_mask.dtype != torch.bool:\n            attention_mask = attention_mask.eq(0)\n    else:\n        attention_mask = None\n\n    next_decoder_cache = []\n    for idx, decoder_layer in enumerate(self.model.layers):\n        device_index = getattr(decoder_layer, \"_per_layer_device_index\", 0)\n        hidden_states, position_ids = move_to_device(\n            device_index, hidden_states, position_ids\n        )\n        residual = hidden_states\n        hidden_states = fast_layernorm_inference(\n            decoder_layer.input_layernorm, 
hidden_states, out_weights[device_index]\n        )\n        hidden_states_attention, present_key_value = (\n            CohereAttention_fast_forward_inference(\n                decoder_layer.self_attn,\n                hidden_states = hidden_states,\n                past_key_value = past_key_values[idx],\n                position_ids = position_ids,\n                attention_mask = attention_mask,\n                do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n            )\n        )\n\n        hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)\n        residual += hidden_states_attention\n        residual += hidden_states_mlp\n        hidden_states = residual\n\n        next_decoder_cache.append(present_key_value)\n    hidden_states = fast_layernorm_inference(\n        self.model.norm, hidden_states, out_weights[device_index]\n    )\n\n    return BaseModelOutputWithPast(\n        last_hidden_state = hidden_states,\n        past_key_values = next_decoder_cache,\n        hidden_states = [],\n        attentions = [],\n    )\n\n\nclass FastCohereModel(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"cohere\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = CohereAttention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            CohereAttention.__init__ = eval(init_name)\n        CohereAttention.forward = CohereAttention_fast_forward\n        CohereSdpaAttention.forward = CohereAttention_fast_forward\n        CohereFlashAttention2.forward = CohereAttention_fast_forward\n        CohereDecoderLayer.forward = CohereDecoderLayer_fast_forward\n        CohereModel.forward = LlamaModel_fast_forward\n        CohereForCausalLM.forward = CausalLM_fast_forward(\n            CohereModel_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(CohereForCausalLM)\n\n        import transformers.models.cohere.modeling_cohere\n\n        transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n"
  },
  {
    "path": "unsloth/models/dpo.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"PatchDPOTrainer\",\n    \"PatchKTOTrainer\",\n]\n\n\ndef PatchDPOTrainer():\n    return\n\n\ndef PatchKTOTrainer():\n    return\n"
  },
  {
    "path": "unsloth/models/falcon_h1.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import Version, _get_dtype\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom ..utils.packing import get_packed_info_from_kwargs\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    select_attention_backend,\n    SDPA,\n)\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n    _LlamaModel_fast_forward_inference,\n)\n\ntry:\n    from transformers.models.falcon_h1.modeling_falcon_h1 import (\n        FalconH1Attention,\n        FalconH1DecoderLayer,\n        FalconH1Model,\n        FalconH1ForCausalLM,\n        FalconHybridMambaAttentionDynamicCache,\n    )\nexcept:\n    from transformers import __version__ as transformers_version\n\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\n        \"4.53.0\"\n    ):  # TODO: Update when transformers is updated\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\\n\"\n            f\"The minimum required version is 4.53.0.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.53.0\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\nfrom transformers.utils import (\n    is_torchdynamo_compiling,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.falcon_h1.modeling_falcon_h1 import (\n        FalconH1Attention,\n    )\nexcept ModuleNotFoundError:\n    # if we are on an old version of transformers technically it should fail in the try except above\n    # but if somehow we make it here, we need to raise an error since FalconH1Attention is not available\n    # or renamed\n    raise ImportError(\n        \"Unsloth: Could not import FalconH1Attention from transformers.models.falcon_h1.modeling_falcon_h1.\"\n    )\n\n\ndef FalconH1Attention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del 
self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, head_dim)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, hidden_states.device)\n\n    # Falcon H1 multiplies key states by a multiplier\n    K = K * self.config.key_multiplier\n\n    Q = Q.transpose(1, 2)\n    K = K.transpose(1, 2)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    # Extend RoPE dynamically to fit in VRAM\n    if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:\n        cos, sin = position_embeddings\n    else:\n        rotary_emb = self.rotary_emb\n        rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)\n        cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)\n\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    # Useful for LongRoPE\n    Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    window = (-1, -1)\n    use_varlen = (\n        attention_mask is None\n        and seq_info is not None\n        and past_key_value is None\n        and window == (-1, -1)\n    )\n\n    backend = (\n        SDPA if attention_mask is not None else select_attention_backend(use_varlen)\n    )\n    attention_config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\n            \"causal\": True,\n            \"window_size\": (kv_seq_len, kv_seq_len),\n        },\n        flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"softmax_scale\": None,\n            \"causal\": True,\n        },\n        sdpa_kwargs = {} if attention_mask is None else {\"attn_mask\": attention_mask},\n    )\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\ntorch_matmul = torch.matmul\n\n\ndef FalconH1Attention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    **kwargs,\n):\n    \"\"\"\n    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406\n    Fast inference 
using KV cache.\n    QK^T can be computed in 4 chunks\n\n    [Q, q] @ [K, k].T where q, k are the new tokens.\n    [QK^T, Qk^T]\n    [qK^T, qk^T]\n\n    Since the attention mask wipes Qk^T, we just get\n    [QK^T,    0]\n    [qK^T, qk^T]\n\n    Since softmax is row-wise, we get\n    softmax([QK^T,    0])\n    softmax([qK^T, qk^T])\n\n    We then multiply by   [V]\n                          [v]\n    softmax([QK^T,    0]) [softmax(QK^T)V] *\n    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]\n\n    But notice * [softmax(QK^T)V] is just the last attention.\n    We just need to compute the last final row.\n\n    This means we can pass in a row of Q, but we need to\n    remember K and V, which are called the KV cache.\n    \"\"\"\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    device = hidden_states.device\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = device,\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, attention_size), dtype = dtype, device = device\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device\n        )\n        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)\n\n        # Mistral Nemo 12b has weird dimensions\n        if attention_size != hidden_size:\n            self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)\n        else:\n            self.temp_O = self.temp_QA[1][:, :, :hidden_size]\n\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device\n        )\n        self.scalar = 1.0 / math_sqrt(self.head_dim)\n        self.half_head_dim = head_dim // 2\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Kn.mul_(self.config.key_multiplier)\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(\n        bsz, 1, n_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after 
normalisation\n    Kn = Kn.view(\n        bsz, 1, n_kv_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after normalisation\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n\n    Qn = Qn.transpose(1, 2)\n    Kn = Kn.transpose(1, 2)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n\n    # Need to do it prior 2 steps before hitting full on short KV cache\n    # or else error\n    self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)\n    cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)\n    cos = cos[position_ids].unsqueeze(1)\n    sin = sin[position_ids].unsqueeze(1)\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Handle sliding windows\n    sliding_window = getattr(self.config, \"sliding_window\", None)\n    if sliding_window is not None and kv_seq_len > sliding_window:\n        start = kv_seq_len - sliding_window\n        Knn = Kn[:, :, start:, :]  # .contiguous()\n        Vnn = Vn[:, :, start:, :]  # .contiguous()\n        if attention_mask is not None:\n            attention_mask = attention_mask[..., start:]\n    else:\n        Knn, Vnn = Kn, Vn\n\n    # Grouped query attention\n    _, _, cached_len, _ = Knn.shape\n    if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:\n        Knn = Knn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Vnn = Vnn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # Attention\n    if bsz == 1:\n        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963\n        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows\n        A = torch_matmul(\n            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]\n        )\n        A[:] = torch_nn_functional_softmax(\n            A, dim = -1, dtype = torch.float32\n        )  # .to(A.dtype)\n        A = torch_matmul(A, Vnn, out = Qn)\n    else:\n        if SDPA_HAS_GQA:\n            A = scaled_dot_product_attention(\n                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True\n            )\n        else:\n            A = scaled_dot_product_attention(\n                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False\n            )\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n  
  A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon_h1/modeling_falcon_h1.py\ndef FalconH1DecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    mamba_attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n    \"\"\"\n    Args:\n        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n            returned tensors for more detail.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n            (see `past_key_values`).\n        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n    \"\"\"\n    if use_cache and hasattr(self, \"_flag_for_generation\"):\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.input_layernorm, hidden_states\n        )\n        attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        attention_hidden_states = attention_hidden_states * self.attn_out_multiplier\n\n        mamba_hidden_states = self.mamba(\n            hidden_states = hidden_states,\n            cache_params = past_key_value,\n            cache_position = cache_position,\n            attention_mask = mamba_attention_mask,\n        )\n        mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier\n\n        hidden_states = mamba_hidden_states + attention_hidden_states\n\n        hidden_states += residual\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.pre_ff_layernorm, hidden_states\n        )\n        hidden_states = fast_swiglu_inference(self.feed_forward, hidden_states)\n        hidden_states += residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)\n\n        mamba_hidden_states = self.mamba(\n            hidden_states = hidden_states,\n            
cache_params = past_key_value,\n            cache_position = cache_position,\n            attention_mask = mamba_attention_mask,\n        )\n        mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier\n\n        attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        attention_hidden_states = attention_hidden_states * self.attn_out_multiplier\n\n        hidden_states = mamba_hidden_states + attention_hidden_states\n\n        # residual connection after attention + Mamba\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.pre_ff_layernorm, hidden_states)\n        hidden_states = self.feed_forward(hidden_states)\n        hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\ndef _FalconH1_fast_forward_inference(\n    attention_fast_forward_inference = FalconH1Attention_fast_forward_inference,\n    mlp_fast_forward_inference = fast_swiglu_inference,\n):\n    # This makes the attention and MLP customisable.\n    # Now for models like qwen3 or cohere which use custom attention operations, we can use this function\n    def FalconH1Model_fast_forward_inference_custom(\n        self,\n        input_ids,\n        past_key_values,\n        position_ids,\n        cache_position = None,\n        attention_mask = None,\n        mamba_attention_mask = None,\n    ):\n        input_ids = input_ids[:, : self.max_seq_length]\n        bsz, q_len = input_ids.shape\n        hd = self.config.hidden_size\n        mlp_size = self.config.intermediate_size\n        gate_multiplier, down_multiplier = self.config.mlp_multipliers\n\n        X = self.model.embed_tokens(input_ids)\n        X = X * self.config.embedding_multiplier\n\n        X = X.to(_get_dtype(dtype_from_config(self.config)))\n        bsz, q_len, hd = X.shape\n        assert q_len == 1\n        # Get saved buffers to reduce memory movement\n        residual = torch.empty(\n            (bsz, q_len, hd), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        _XX = torch.empty(\n            (2, bsz, q_len, hd), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        XX, XX2 = _XX[0], _XX[1]\n        variance = torch.empty(\n            (bsz, q_len, 1), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        temp_mlp = torch.empty(\n            (2, bsz, 1, mlp_size), dtype = X.dtype, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        temp_gate, temp_up = temp_mlp[0], temp_mlp[1]\n        seq_len = past_key_values[0][0].shape[-2]\n        if bsz != 1:\n            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (bsz, q_len),\n                X,\n                seq_len,\n                sliding_window = getattr(self.config, \"sliding_window\", None),\n            )\n  
      else:\n            attention_mask = None\n\n        next_decoder_cache = []\n\n        for idx, decoder_layer in enumerate(self.model.layers):\n            residual.copy_(X)  # residual = X\n            X = fast_rms_layernorm_inference(\n                decoder_layer.input_layernorm,\n                X,\n                XX = XX,\n                XX2 = XX2,\n                variance = variance,\n            )\n            attention_hidden_states, present_key_value = (\n                attention_fast_forward_inference(\n                    decoder_layer.self_attn,\n                    hidden_states = X * decoder_layer.attention_in_multiplier,\n                    past_key_value = past_key_values[idx],\n                    position_ids = position_ids,\n                    attention_mask = attention_mask,\n                    do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n                )\n            )\n            attention_hidden_states = (\n                attention_hidden_states * decoder_layer.attn_out_multiplier\n            )\n            mamba_hidden_states = decoder_layer.mamba(\n                hidden_states = X,\n                cache_params = present_key_value,\n                cache_position = cache_position,\n                attention_mask = mamba_attention_mask,\n            )\n            mamba_hidden_states = mamba_hidden_states * decoder_layer.ssm_out_multiplier\n            X = mamba_hidden_states + attention_hidden_states\n\n            X += residual\n\n            residual.copy_(X)  # residual = X\n            X = fast_rms_layernorm_inference(\n                decoder_layer.pre_ff_layernorm,\n                X,\n                XX = XX,\n                XX2 = XX2,\n                variance = variance,\n            )\n            X = mlp_fast_forward_inference(\n                decoder_layer.feed_forward,\n                X,\n                temp_gate = temp_gate,\n                temp_up = temp_up,\n                gate_multiplier = gate_multiplier,\n                down_multiplier = down_multiplier,\n            )\n            X += residual\n\n            next_decoder_cache.append(present_key_value)\n        X = fast_rms_layernorm_inference(\n            self.model.final_layernorm,\n            X,\n            XX = XX,\n            XX2 = XX2,\n            variance = variance,\n        )\n\n        return BaseModelOutputWithPast(\n            last_hidden_state = X,\n            past_key_values = next_decoder_cache,\n            hidden_states = [],\n            attentions = [],\n        )\n\n    return FalconH1Model_fast_forward_inference_custom\n\n\n# Separate prepare_inputs_for_generation for Hybrid FalconH1\ndef _fast_prepare_inputs_for_generation(\n    self,\n    input_ids,\n    past_key_values = None,\n    attention_mask = None,\n    inputs_embeds = None,\n    cache_position = None,\n    position_ids = None,\n    use_cache = True,\n    **kwargs,\n):\n    # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`\n    empty_past_kv = past_key_values is None\n\n    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens\n    # Exception 1: when passing input_embeds, input_ids may be missing entries\n    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here\n    # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.\n    #              (we can't check exception 
3 while compiling)\n    if not empty_past_kv:\n        if (\n            inputs_embeds is not None  # Exception 1\n            or (\n                is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]\n            )  # Exception 3\n        ):\n            input_ids = input_ids[:, -cache_position.shape[0] :]\n        elif (\n            input_ids.shape[1] != cache_position.shape[0]\n        ):  # Default case (the \"else\", a no op, is Exception 2)\n            input_ids = input_ids[:, cache_position]\n    # TODO: Wire up Cache to work for inference.\n    # else:\n    #     past_key_values = FalconHybridMambaAttentionDynamicCache(\n    #         self.config,\n    #         input_ids.shape[0],\n    #         self.dtype,\n    #         devices=[\n    #             self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers)\n    #         ],\n    #     )\n\n    if attention_mask is not None and position_ids is None:\n        # create position_ids on the fly for batch generation\n        position_ids = attention_mask.long().cumsum(-1) - 1\n        position_ids.masked_fill_(attention_mask == 0, 1)\n        if not empty_past_kv:\n            position_ids = position_ids[:, -input_ids.shape[1] :]\n\n    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n    if inputs_embeds is not None and empty_past_kv:\n        model_inputs = {\"inputs_embeds\": inputs_embeds}\n    else:\n        model_inputs = {\n            \"input_ids\": input_ids.contiguous()\n        }  # `contiguous()` needed for compilation use cases\n\n    model_inputs.update(\n        {\n            \"position_ids\": position_ids,\n            \"past_key_values\": past_key_values,\n            \"use_cache\": use_cache,\n            \"attention_mask\": attention_mask,\n            \"logits_to_keep\": self.config.num_logits_to_keep,\n            \"cache_position\": cache_position,\n        }\n    )\n    return model_inputs\n\n\ndef fix_prepare_inputs_for_generation(module):\n    # Fix prepare_inputs_for_generation\n    if hasattr(module, \"prepare_inputs_for_generation\"):\n        module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation\n\n\nclass FastFalconH1Model(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"FalconH1\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = FalconH1Attention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            FalconH1Attention.__init__ = eval(init_name)\n        FalconH1Attention.forward = FalconH1Attention_fast_forward\n        FalconH1DecoderLayer.forward = FalconH1DecoderLayer_fast_forward\n        FalconH1Model.forward = LlamaModel_fast_forward\n        FalconH1ForCausalLM.forward = CausalLM_fast_forward(\n            _FalconH1_fast_forward_inference(FalconH1Attention_fast_forward_inference)\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(FalconH1ForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # 
https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.falcon_h1.modeling_falcon_h1\n\n        transformers.models.falcon_h1.modeling_falcon_h1.FalconH1RotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(  # TODO: Change after release\n        model_name = \"Qwen/FalconH1-7B\",\n        max_seq_length = 4096,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastFalconH1Model,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
  },
  {
    "path": "unsloth/models/gemma.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nfrom .llama import _get_rope_theta\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import _get_dtype, Version\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom ..utils.packing import (\n    build_sdpa_packed_attention_mask,\n    build_xformers_block_causal_mask,\n    get_packed_info_from_kwargs,\n)\nimport math\n\ntry:\n    from transformers.models.gemma.modeling_gemma import (\n        GemmaAttention,\n        GemmaDecoderLayer,\n        GemmaModel,\n        GemmaForCausalLM,\n        GemmaRotaryEmbedding,\n        apply_rotary_pos_emb,\n        repeat_kv,\n    )\nexcept:\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\"4.38\"):\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\\n\"\n            f\"The minimum required version is 4.38.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.38\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\n\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.gemma.modeling_gemma import (\n        GemmaSdpaAttention,\n        GemmaFlashAttention2,\n    )\nexcept:\n    GemmaSdpaAttention = GemmaAttention\n    GemmaFlashAttention2 = GemmaAttention\n\n\ntorch_nn_functional_gelu = torch.nn.functional.gelu\n\n\ndef fast_geglu_inference(self, X):\n    # gate = self.gate_proj(X)\n    # up   = self.up_proj(X)\n    bsz, _, hd = X.shape\n    # mlp_size = self.config.intermediate_size\n    # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = \"cuda:0\")\n\n    gate = fast_linear_forward(self.gate_proj, X)  # , out = temp[0])\n    up = fast_linear_forward(self.up_proj, X)  # , out = temp[1])\n    gate = torch_nn_functional_gelu(gate, approximate = \"tanh\")\n    gate *= up\n\n    # X = self.down_proj(gate)\n    down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])\n    return down\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590\ndef GemmaDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    *args,\n    **kwargs,\n):\n    if use_cache and hasattr(\n        self, \"_flag_for_generation\"\n    ):  # past_key_value is not None:\n        out_weight = torch.empty(\n            
self.input_layernorm.weight.shape,\n            dtype = torch.float32,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n\n        # Self Attention\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.input_layernorm, hidden_states, out_weight\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            **kwargs,\n        )\n        hidden_states += residual\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.post_attention_layernorm, hidden_states, out_weight\n        )\n        hidden_states = fast_geglu_inference(self.mlp, hidden_states)\n        hidden_states += residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(\n            self.input_layernorm, hidden_states, gemma = True\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(\n            self.post_attention_layernorm, hidden_states, gemma = True\n        )\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\nfrom math import sqrt as math_sqrt\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\n# @torch.inference_mode\ndef GemmaModel_fast_forward_inference(\n    self,\n    input_ids,\n    past_key_values,\n    position_ids,\n    attention_mask = None,\n    **kwargs,\n):\n    out_weights = tuple(\n        torch.empty_like(\n            self.model.layers[0].input_layernorm.weight,\n            dtype = torch.float32,\n            device = torch.device(x),\n        )\n        for x in range(DEVICE_COUNT)\n    )\n    input_ids = input_ids[:, : self.max_seq_length]\n    hidden_states = self.model.embed_tokens(input_ids)\n    hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))\n    # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32\n    # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32\n    hidden_states *= torch.tensor(\n        math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype\n    )\n\n    bsz, q_len, hd = hidden_states.shape\n    seq_len = past_key_values[0][0].shape[-2]\n    kv_seq_len = seq_len + 1\n    if bsz != 1:\n        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n            attention_mask,\n            (bsz, 
q_len),\n            hidden_states,\n            seq_len,\n        )\n        # Pre-convert to bool once for all layers (avoids per-layer .eq(0))\n        if attention_mask is not None and attention_mask.dtype != torch.bool:\n            attention_mask = attention_mask.eq(0)\n\n    # Compute rotary_seq_len once to avoid per-layer GPU-CPU sync from .item()\n    rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)\n\n    next_decoder_cache = []\n    for idx, decoder_layer in enumerate(self.model.layers):\n        device_index = getattr(decoder_layer, \"_per_layer_device_index\", 0)\n        hidden_states, position_ids = move_to_device(\n            device_index, hidden_states, position_ids\n        )\n\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            decoder_layer.input_layernorm, hidden_states, out_weights[device_index]\n        )\n        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(\n            decoder_layer.self_attn,\n            hidden_states = hidden_states,\n            past_key_value = past_key_values[idx],\n            position_ids = position_ids,\n            attention_mask = attention_mask,\n            do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n            rotary_seq_len = rotary_seq_len,\n        )\n        hidden_states += residual\n\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            decoder_layer.post_attention_layernorm,\n            hidden_states,\n            out_weights[device_index],\n        )\n        hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)\n        hidden_states += residual\n\n        next_decoder_cache.append(present_key_value)\n    hidden_states = fast_rms_layernorm_inference_gemma(\n        self.model.norm, hidden_states, out_weights[device_index]\n    )\n\n    return BaseModelOutputWithPast(\n        last_hidden_state = hidden_states,\n        past_key_values = next_decoder_cache,\n        hidden_states = [],\n        attentions = [],\n    )\n\n\n# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45\n# Formulates cos and sin differently from Llama!\nclass GemmaFixedRotaryEmbedding(torch.nn.Module):\n    # Fixes https://github.com/huggingface/transformers/pull/28837\n    # https://github.com/microsoft/DeepSpeed/issues/4932\n    # The precision of RoPE buffers is not correct, so we cast to int64.\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 2048,\n        base = 10000,\n        device = None,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        super().__init__()\n        # In transformers 5.0+, RotaryEmbedding(config) passes config as first positional arg (dim)\n        if (\n            config is None\n            and dim is not None\n            and hasattr(dim, \"max_position_embeddings\")\n        ):\n            config = dim\n            dim = None\n        if config is not None:\n            # [TODO] Hack to pass in config - need to remove later\n            base = _get_rope_theta(config, default = base)\n            partial_rotary_factor = (\n                config.partial_rotary_factor\n                if hasattr(config, \"partial_rotary_factor\")\n                else 1.0\n            )\n            dim = getattr(config, \"head_dim\", None)\n            if dim is None:\n                dim = 
int((config.hidden_size // config.num_attention_heads))\n            device = \"cuda\"\n            max_position_embeddings = config.max_position_embeddings\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this\n        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)\n        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT\n        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT\n\n        # Build here to make `torch.jit.trace` work.\n        for device in range(DEVICE_COUNT):\n            self._set_cos_sin_cache(\n                seq_len = self.current_rope_size,\n                device = torch.device(device),\n                dtype = torch.get_default_dtype(),\n            )\n\n        # dummy so that patch_utils doesn't fail for now\n        self.cos_cached = torch.empty(\n            1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()\n        )\n        self.sin_cached = torch.empty(\n            1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and\n        # in FP32. They are applied (multiplied) in FP32 as well.\n        self.current_rope_size = seq_len\n\n        # The difference is we do division explicitly instead of t * (1/x) ie we do t/x.\n        freq_exponents = (2.0 / self.dim) * (\n            torch.arange(self.dim // 2, dtype = torch.int64, device = \"cpu\").float()\n        )\n        timescale = self.base**freq_exponents\n        positions = torch.arange(\n            self.current_rope_size, device = \"cpu\", dtype = torch.int64\n        ).float()\n        radians_new = positions[..., None] / timescale[None, None, :]\n        radians_new = radians_new.squeeze(0)\n\n        emb = torch.cat((radians_new, radians_new), dim = -1)\n        # We must do RoPE in float32!\n        cos = emb.cos().to(device = device, non_blocking = True)  # , dtype = dtype)\n        sin = emb.sin().to(device = device, non_blocking = True)  # , dtype = dtype)\n        self.multi_gpu_cos_cached[device.index] = cos\n        self.multi_gpu_sin_cached[device.index] = sin\n        return cos, sin\n\n    def forward(self, x, position_ids = None, seq_len = None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len is not None and seq_len > self.current_rope_size:\n            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)\n\n        device_index = x.device.index\n\n        return (\n            self.multi_gpu_cos_cached[device_index][:seq_len],\n            self.multi_gpu_sin_cached[device_index][:seq_len],\n        )\n\n    def get_cached(self, seq_len = None, device_index = None):\n        if device_index is None:\n            device_index = torch.cuda.current_device()\n        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[\n            device_index\n        ]\n\n    def extend_rope_embedding(self, x, seq_len):\n        if seq_len <= self.current_rope_size:\n            return\n        # Iteratively grow by increments of 8192\n        self.current_rope_size = math.ceil(seq_len / 8192) * 8192\n        for device in range(DEVICE_COUNT):\n            self._set_cos_sin_cache(\n                
self.current_rope_size, device = torch.device(device), dtype = x.dtype\n            )\n\n\nclass GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    # Fixes https://github.com/huggingface/transformers/pull/28837\n    # https://github.com/microsoft/DeepSpeed/issues/4932\n    # The precision of RoPE buffers is not correct, so we cast to int64.\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 2048,\n        base = 10000,\n        device = None,\n        scaling_factor = 1.0,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(\n            dim = dim,\n            max_position_embeddings = max_position_embeddings,\n            base = base,\n            device = device,\n            config = config,\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and\n        # in FP32. They are applied (multiplied) in FP32 as well.\n        self.current_rope_size = seq_len\n\n        # The difference is we do division explicitly instead of t * (1/x) ie we do t/x.\n        freq_exponents = (2.0 / self.dim) * (\n            torch.arange(self.dim // 2, dtype = torch.int64, device = \"cpu\").float()\n        )\n        timescale = self.base**freq_exponents\n        positions = torch.arange(\n            self.current_rope_size, device = \"cpu\", dtype = torch.int64\n        ).float()\n        positions = positions / self.scaling_factor\n        radians_new = positions[..., None] / timescale[None, None, :]\n        radians_new = radians_new.squeeze(0)\n\n        emb = torch.cat((radians_new, radians_new), dim = -1)\n        # We must do RoPE in float32!\n        cos = emb.cos().to(device = device, non_blocking = True)  # , dtype = dtype)\n        sin = emb.sin().to(device = device, non_blocking = True)  # , dtype = dtype)\n        self.multi_gpu_cos_cached[device.index] = cos\n        self.multi_gpu_sin_cached[device.index] = sin\n        return cos, sin\n\n\nclass FastGemmaModel(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"gemma\",\n            rope_module = GemmaFixedRotaryEmbedding,\n            scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,\n            attention_module = GemmaAttention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            GemmaAttention.__init__ = eval(init_name)\n        GemmaAttention.forward = LlamaAttention_fast_forward\n        GemmaSdpaAttention.forward = LlamaAttention_fast_forward\n        GemmaFlashAttention2.forward = LlamaAttention_fast_forward\n        GemmaDecoderLayer.forward = GemmaDecoderLayer_fast_forward\n        GemmaModel.forward = LlamaModel_fast_forward\n        GemmaForCausalLM.forward = CausalLM_fast_forward(\n            GemmaModel_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(GemmaForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the 
old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.gemma.modeling_gemma\n\n        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = (\n            GemmaFixedRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def post_patch(model, tokenizer, correct_dtype = None):\n        # Gemma does not downcast RoPE\n        model, tokenizer = patch_model_and_tokenizer(\n            model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype\n        )\n\n        # Add 1 to weight\n        # return output * (1 + self.weight)\n        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89\n        from transformers.models.gemma.modeling_gemma import GemmaRMSNorm\n\n        # Freeze all parameters except LoRA\n        # We do this first since += 1 seems to not be liked by requires_grad = True\n        for name, param in model.named_parameters():\n            if \".lora_A.\" in name or \".lora_B.\" in name:\n                param.requires_grad_(True)\n            else:\n                param.requires_grad_(False)\n\n        # Patch RMS Layernorm\n        for name, module in model.named_modules():\n            if isinstance(module, GemmaRMSNorm):\n                # Must be in float32\n                # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36\n                # module = module.to(torch.float32)\n                # Leave + 1 to Triton kernel itself\n                # module.weight += 1.0 # return output * (1 + self.weight)\n                if not hasattr(module, \"variance_epsilon\"):\n                    module.variance_epsilon = (\n                        module.eps\n                    )  # Gemma doesn't use variance_epsilon\n\n        # Clear deleted GPU items\n        import gc\n\n        for _ in range(3):\n            gc.collect()\n            torch.cuda.empty_cache()\n        return model, tokenizer\n"
  },
  {
    "path": "unsloth/models/gemma2.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import _get_dtype, Version\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom ..utils.packing import get_packed_info_from_kwargs\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    select_attention_backend,\n    SDPA,\n)\nfrom .gemma import (\n    GemmaFixedRotaryEmbedding,\n    GemmaFixedLinearScalingRotaryEmbedding,\n    fast_geglu_inference,\n)\n\ntry:\n    from transformers.models.gemma2.modeling_gemma2 import (\n        Gemma2Attention,\n        Gemma2DecoderLayer,\n        Gemma2Model,\n        Gemma2ForCausalLM,\n        Gemma2RotaryEmbedding,\n        apply_rotary_pos_emb,\n        repeat_kv,\n    )\nexcept:\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\"4.42\"):\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\\n\"\n            f\"The minimum required version is 4.42.3.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.42.3\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\n\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.gemma2.modeling_gemma2 import (\n        Gemma2SdpaAttention,\n        Gemma2FlashAttention2,\n    )\nexcept:\n    Gemma2SdpaAttention = Gemma2Attention\n    Gemma2FlashAttention2 = Gemma2Attention\n\nif HAS_FLASH_ATTENTION_SOFTCAPPING:\n    from flash_attn import flash_attn_func\n\n\n# Logit softcapping\ndef Gemma2Attention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, 
head_dim).transpose(1, 2)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, Q.device)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    device_index = Q.device.index\n    cos = self.rotary_emb.multi_gpu_cos_cached[device_index]\n    sin = self.rotary_emb.multi_gpu_sin_cached[device_index]\n\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    if rope_position_ids is not None:\n        # Useful for LongRoPE\n        cos_var, sin_var = self.rotary_emb.get_cached(kv_seq_len, device_index)\n        Q, K = fast_rope_embedding(Q, K, cos_var, sin_var, rope_position_ids)\n    else:\n        Q, K = fast_rope_embedding(Q, K, cos, sin)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Only enable if the attention_mask is True\n    use_sliding_window = kwargs.get(\"use_sliding_window\")\n    has_sliding_window = (\n        use_sliding_window\n        if use_sliding_window is not None\n        else isinstance(causal_mask, bool) and causal_mask is True\n    )\n\n    use_flash = HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None\n\n    if use_flash:\n        window = (-1, -1)\n        sliding_window = getattr(self.config, \"sliding_window\", None)\n        if has_sliding_window:\n            sliding_window = (\n                sliding_window if sliding_window is not None else kv_seq_len\n            )\n            window = (\n                (-1, -1)\n                if kv_seq_len <= sliding_window\n                else (sliding_window, sliding_window)\n            )\n\n        if not hasattr(self, \"_flash_attention_softmax_scale\"):\n            self._flash_attention_softmax_scale = 1.0 / (\n                self.config.query_pre_attn_scalar**0.5\n            )\n\n        use_varlen = seq_info is not None and past_key_value is None\n\n        attention_config = AttentionConfig(\n            backend = select_attention_backend(use_varlen),\n            n_kv_heads = n_kv_heads,\n            n_groups = n_groups,\n            flash_dense_kwargs = {\n                \"causal\": True,\n                \"softcap\": self.config.attn_logit_softcapping,\n                \"softmax_scale\": self._flash_attention_softmax_scale,\n                \"window_size\": window,\n            },\n            flash_varlen_kwargs = {\n                \"dropout_p\": 0.0,\n                \"softmax_scale\": self._flash_attention_softmax_scale,\n                \"causal\": True,\n                \"softcap\": self.config.attn_logit_softcapping,\n                \"window_size\": window,\n            },\n        )\n\n        context = AttentionContext(\n            bsz = bsz,\n            q_len = q_len,\n            kv_seq_len = kv_seq_len,\n            n_heads = n_heads,\n            head_dim = head_dim,\n            requires_grad = hidden_states.requires_grad,\n            seq_info = seq_info,\n            attention_mask = attention_mask,\n            causal_mask = causal_mask,\n            sliding_window = sliding_window,\n        )\n\n        A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n        A = A.reshape(bsz, q_len, n_heads * head_dim)\n    
else:\n        fx = (\n            slow_inference_attention_softcapping\n            if \"_flag_for_generation\" in kwargs\n            else slow_attention_softcapping\n        )\n        A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)\n    A = self.apply_o(self, A)\n    return A, None, past_key_value\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590\ndef Gemma2DecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    *args,\n    **kwargs,\n):\n    if use_cache and hasattr(\n        self, \"_flag_for_generation\"\n    ):  # past_key_value is not None:\n        out_weight = torch.empty(\n            self.input_layernorm.weight.shape,\n            dtype = torch.float32,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n\n        # Self Attention\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.input_layernorm, hidden_states, out_weight\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            _flag_for_generation = self._flag_for_generation,\n            **kwargs,\n        )\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.post_attention_layernorm, hidden_states, out_weight\n        )\n        hidden_states += residual\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.pre_feedforward_layernorm, hidden_states, out_weight\n        )\n        hidden_states = fast_geglu_inference(self.mlp, hidden_states)\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            self.post_feedforward_layernorm, hidden_states, out_weight\n        )\n        hidden_states += residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(\n            self.input_layernorm, hidden_states, gemma = True\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            **kwargs,\n        )\n        hidden_states = fast_rms_layernorm(\n            self.post_attention_layernorm, hidden_states, gemma = True\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(\n            self.pre_feedforward_layernorm, hidden_states, gemma = True\n        )\n        
hidden_states = self.mlp(hidden_states)\n        hidden_states = fast_rms_layernorm(\n            self.post_feedforward_layernorm, hidden_states, gemma = True\n        )\n        hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\nfrom math import sqrt as math_sqrt\n\nKV_CACHE_INCREMENT = 256  # KV Cache update size\ntorch_nn_functional_softmax = torch.nn.functional.softmax\ntorch_matmul = torch.matmul\ntorch_tanh = torch.tanh\n\n\ndef Gemma2Attention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    use_sliding_window = False,\n    **kwargs,\n):\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n    device = hidden_states.device\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = device,\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, attention_size), dtype = dtype, device = device\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device\n        )\n        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)\n        # Only for Gemma2\n        self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device\n        )\n\n        # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e\n        # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below\n        # We default to using the config file itself\n        # s = self.config.hidden_size // self.config.num_attention_heads\n        self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)\n        # self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)\n        self.half_head_dim = head_dim // 2\n        self.t = self.config.attn_logit_softcapping\n        self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)\n    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n    cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)\n    cos = cos[position_ids].unsqueeze(1)\n    sin = sin[position_ids].unsqueeze(1)\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Handle sliding windows\n    sliding_window = self.config.sliding_window\n    if use_sliding_window and kv_seq_len > sliding_window:\n        start = kv_seq_len - sliding_window\n        Knn = Kn[:, :, start:, :]  # .contiguous()\n        Vnn = Vn[:, :, start:, :]  # .contiguous()\n    else:\n        Knn, Vnn = Kn, Vn\n\n    # Grouped query attention\n    _, _, cached_len, _ = Knn.shape\n    if n_groups != 1:\n        Knn = Knn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Vnn = Vnn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # Attention\n    # [TODO] Gemma2 uses manual matmul for all batch sizes because SDPA does\n    # not support softcapping (tanh logit scaling). 
If a future PyTorch adds\n    # a softcap param to scaled_dot_product_attention, consider using SDPA\n    # for bsz > 1 to match the llama/qwen3 pattern.\n    Qn *= (\n        self.scalar\n    )  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963\n    # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows\n    A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])\n\n    # Softcapping must happen BEFORE the mask is applied.\n    # Reference: google-deepmind/gemma _modules.py and transformers gemma2 eager_attention_forward\n    A *= self.reciprocal_t\n    A.tanh_()\n    A *= self.t  # Logit softcapping\n\n    if attention_mask is not None and isinstance(attention_mask, torch.Tensor):\n        # Slice mask to match K/V when sliding window is active\n        if attention_mask.shape[-1] != A.shape[-1]:\n            attention_mask = attention_mask[:, :, :, -A.shape[-1] :]\n        A += attention_mask\n\n    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)  # .to(A.dtype)\n    A = torch_matmul(A, Vnn, out = Qn)\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\n# @torch.inference_mode\ndef Gemma2Model_fast_forward_inference(\n    self,\n    input_ids,\n    past_key_values,\n    position_ids,\n    attention_mask = None,\n    **kwargs,\n):\n    out_weights = tuple(\n        torch.empty_like(\n            self.model.layers[0].input_layernorm.weight,\n            dtype = torch.float32,\n            device = torch.device(x),\n        )\n        for x in range(DEVICE_COUNT)\n    )\n    input_ids = input_ids[:, : self.max_seq_length]\n    hidden_states = self.model.embed_tokens(input_ids)\n    hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))\n    # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32\n    # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32\n    hidden_states *= torch.tensor(\n        math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype\n    )\n\n    bsz, q_len, hd = hidden_states.shape\n    seq_len = past_key_values[0][0].shape[-2]\n    if bsz != 1:\n        if HAS_FLASH_ATTENTION_SOFTCAPPING:\n            SWA = True\n            GA = False\n        else:\n            SWA = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (bsz, q_len),\n                hidden_states,\n                seq_len,\n                sliding_window = self.config.sliding_window,\n            )\n            GA = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (bsz, q_len),\n                hidden_states,\n                seq_len,\n            )\n    else:\n        SWA = attention_mask\n        GA = attention_mask\n    next_decoder_cache = []\n    for idx, decoder_layer in enumerate(self.model.layers):\n        # For pipeline parallelism, we need to move all tensors to the same device\n        # note that this movement is once per GPU in PP\n        device_index = getattr(decoder_layer, \"_per_layer_device_index\", 0)\n        hidden_states, position_ids = move_to_device(\n            device_index, hidden_states, position_ids\n        )\n\n        use_sliding_window = idx % 2 == 0\n\n        residual = hidden_states\n        hidden_states 
= fast_rms_layernorm_inference_gemma(\n            decoder_layer.input_layernorm, hidden_states, out_weights[device_index]\n        )\n        hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(\n            decoder_layer.self_attn,\n            hidden_states = hidden_states,\n            past_key_value = past_key_values[idx],\n            position_ids = position_ids,\n            attention_mask = SWA if use_sliding_window else GA,\n            do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n            use_sliding_window = use_sliding_window,\n        )\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            decoder_layer.post_attention_layernorm,\n            hidden_states,\n            out_weights[device_index],\n        )\n        hidden_states += residual\n\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            decoder_layer.pre_feedforward_layernorm,\n            hidden_states,\n            out_weights[device_index],\n        )\n        hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)\n        hidden_states = fast_rms_layernorm_inference_gemma(\n            decoder_layer.post_feedforward_layernorm,\n            hidden_states,\n            out_weights[device_index],\n        )\n        hidden_states += residual\n\n        next_decoder_cache.append(present_key_value)\n    hidden_states = fast_rms_layernorm_inference_gemma(\n        self.model.norm, hidden_states, out_weights[device_index]\n    )\n\n    return BaseModelOutputWithPast(\n        last_hidden_state = hidden_states,\n        past_key_values = next_decoder_cache,\n        hidden_states = [],\n        attentions = [],\n    )\n\n\nclass FastGemma2Model(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"gemma2\",\n            rope_module = GemmaFixedRotaryEmbedding,\n            scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,\n            attention_module = Gemma2Attention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            Gemma2Attention.__init__ = eval(init_name)\n        Gemma2Attention.forward = Gemma2Attention_fast_forward\n        Gemma2SdpaAttention.forward = Gemma2Attention_fast_forward\n        Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward\n        Gemma2DecoderLayer.forward = Gemma2DecoderLayer_fast_forward\n        Gemma2Model.forward = LlamaModel_fast_forward\n        Gemma2ForCausalLM.forward = CausalLM_fast_forward(\n            Gemma2Model_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(Gemma2ForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.gemma2.modeling_gemma2\n\n        transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = (\n            GemmaFixedRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def post_patch(model, tokenizer, correct_dtype = None):\n        # 
Gemma does not downcast RoPE\n        model, tokenizer = patch_model_and_tokenizer(\n            model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype\n        )\n\n        # Add 1 to weight\n        # return output * (1 + self.weight)\n        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89\n        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm\n\n        # Freeze all parameters except LoRA\n        # We do this first since += 1 seems to not be liked by requires_grad = True\n        for name, param in model.named_parameters():\n            if \".lora_A.\" in name or \".lora_B.\" in name:\n                param.requires_grad_(True)\n            else:\n                param.requires_grad_(False)\n\n        # Patch RMS Layernorm\n        for name, module in model.named_modules():\n            if isinstance(module, Gemma2RMSNorm):\n                # Must be in float32\n                # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36\n                # module = module.to(torch.float32)\n                # Leave + 1 to Triton kernel itself\n                # module.weight += 1.0 # return output * (1 + self.weight)\n                if not hasattr(module, \"variance_epsilon\"):\n                    module.variance_epsilon = (\n                        module.eps\n                    )  # Gemma doesn't use variance_epsilon\n\n        # Clear deleted GPU items\n        import gc\n\n        for _ in range(3):\n            gc.collect()\n            torch.cuda.empty_cache()\n        return model, tokenizer\n"
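The attention comments in gemma2.py above pin down two ordering details: the scale factor is folded into Q before the Q @ K^T matmul (to limit overflow), and the tanh logit softcap is applied before the additive attention mask, since applying tanh after the mask would squash -inf entries to -softcap and leak probability to masked positions. A minimal sketch of that pipeline, assuming toy shapes and an illustrative softcap constant (the real scale and softcap come from the model config):

```python
import torch

def softcapped_attention_probs(Q, K, attn_mask=None, softcap=50.0, scale=None):
    # Order mirrors the Gemma-2 code above:
    # 1) scale Q first: (Q * scale) @ K^T instead of (Q @ K^T) * scale
    # 2) softcap the logits with tanh BEFORE adding the mask
    # 3) add the additive mask, then softmax in float32
    if scale is None:
        scale = Q.shape[-1] ** -0.5   # the real layer may use a config-defined scalar instead
    A = (Q * scale) @ K.transpose(-2, -1)
    A = softcap * torch.tanh(A / softcap)     # logits bounded to (-softcap, softcap)
    if attn_mask is not None:
        A = A + attn_mask                     # masked entries stay at -inf, untouched by tanh
    return torch.softmax(A, dim=-1, dtype=torch.float32)

# toy example: 1 query token attending to a 4-token cache, 2 heads, head_dim 8
Q = torch.randn(1, 2, 1, 8)
K = torch.randn(1, 2, 4, 8)
mask = torch.zeros(1, 1, 1, 4)
mask[..., -1] = float("-inf")                 # forbid attending to the last cached position
P = softcapped_attention_probs(Q, K, mask)
assert P[..., -1].abs().max() == 0            # masked position receives zero probability
```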
  },
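The sqrt(hidden_size) comment in Gemma2Model_fast_forward_inference above is worth making concrete: the embedding scale has to be materialized in the model's own dtype, because bfloat16 rounds the constant differently from float32 and Gemma expects the bfloat16-rounded value. A quick check with hidden_size = 3072, as in the comment:

```python
import math
import torch

scale = math.sqrt(3072)                                   # 55.42562...
print(torch.tensor(scale, dtype=torch.bfloat16).item())   # 55.5   (bfloat16 rounding)
print(torch.tensor(scale, dtype=torch.float32).item())    # ~55.4256
```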
  {
    "path": "unsloth/models/glm4_moe.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nGLM-4.7 Flash (GLM4 MoE Lite) optimized implementation using grouped GEMM.\n\nKey architecture differences from Qwen3 MoE:\n- Router uses sigmoid activation (not softmax)\n- Has routed_scaling_factor of 1.8\n- Has 1 shared expert that processes all tokens\n- Uses group-based selection before topk\n- Uses MLA (Multi-head Latent Attention)\n\"\"\"\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n    fix_prepare_inputs_for_generation,\n    fast_rms_layernorm_inference,\n    fast_swiglu_inference,\n    LlamaModel_fast_forward,\n    LlamaModel_fast_forward_inference,\n    CausalLM_fast_forward,\n    PeftModel_fast_forward,\n)\nimport torch\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nfrom ..kernels import fast_rms_layernorm\n\n# Import the grouped gemm utilities from unsloth kernels\n# The grouped_gemm module expects its parent directory to be in sys.path\nHAS_GROUPED_GEMM = False\ntry:\n    import sys\n    import os\n\n    # Add the moe directory (parent of grouped_gemm) to sys.path\n    _moe_path = os.path.join(\n        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), \"kernels\", \"moe\"\n    )\n    if _moe_path not in sys.path:\n        sys.path.insert(0, _moe_path)\n\n    # Import grouped_gemm package first to apply TMA compatibility shim\n    # This patches triton.language to support both old and new TMA API names\n    import grouped_gemm  # noqa: F401 - triggers TMA compatibility shim\n\n    from grouped_gemm.interface import grouped_gemm\n    from grouped_gemm.reference.moe_ops import (\n        get_routing_indices,\n        permute,\n        unpermute,\n    )\n\n    HAS_GROUPED_GEMM = True\nexcept ImportError as e:\n    import warnings\n\n    warnings.warn(\n        f\"Grouped GEMM not available: {e}. 
MoE will use fallback implementation.\"\n    )\n\n\n# Import transformers GLM4 MoE Lite classes\ntry:\n    from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (\n        Glm4MoeLiteAttention,\n        Glm4MoeLiteMoE,\n        Glm4MoeLiteMLP,\n        Glm4MoeLiteNaiveMoe,\n        Glm4MoeLiteTopkRouter,\n        Glm4MoeLiteDecoderLayer,\n        Glm4MoeLiteModel,\n        Glm4MoeLiteForCausalLM,\n        Glm4MoeLiteRMSNorm,\n    )\n\n    HAS_GLM4_MOE = True\nexcept ImportError:\n    HAS_GLM4_MOE = False\n\n    # Create dummy classes for type checking\n    class Glm4MoeLiteAttention:\n        pass\n\n    class Glm4MoeLiteMoE:\n        pass\n\n    class Glm4MoeLiteMLP:\n        pass\n\n    class Glm4MoeLiteNaiveMoe:\n        pass\n\n    class Glm4MoeLiteTopkRouter:\n        pass\n\n    class Glm4MoeLiteDecoderLayer:\n        pass\n\n    class Glm4MoeLiteModel:\n        pass\n\n    class Glm4MoeLiteForCausalLM:\n        pass\n\n\ntorch_nn_functional_silu = torch.nn.functional.silu\n\n\ndef Glm4MoeLiteMoE_fast_forward(self, hidden_states):\n    \"\"\"\n    Optimized MoE forward pass using grouped GEMM.\n\n    GLM4 MoE specifics:\n    - Uses sigmoid router activation (not softmax)\n    - Has routed_scaling_factor of 1.8\n    - Has 1 shared expert that always processes all tokens\n    - Uses group-based selection with topk_group\n    \"\"\"\n    residuals = hidden_states\n    orig_shape = hidden_states.shape\n    batch_size, seq_len, hidden_dim = orig_shape\n    num_tokens = batch_size * seq_len\n\n    # Flatten hidden states for routing\n    hidden_states = hidden_states.view(-1, hidden_dim)\n\n    # Router computation\n    router_logits = self.gate(hidden_states)  # [num_tokens, n_routed_experts]\n    topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)\n    # Cast routing weights to match hidden_states dtype (Qwen3 pattern)\n    # Sigmoid router returns fp32, but hidden_states may be bf16\n    topk_weights = topk_weights.to(hidden_states.dtype)\n\n    # Get routing indices for grouped GEMM\n    with torch.no_grad():\n        token_counts_by_expert, gather_indices = get_routing_indices(\n            topk_indices, self.n_routed_experts\n        )\n\n    # Use grouped GEMM for expert computation\n    if HAS_GROUPED_GEMM:\n        # Cast hidden_states to match expert weights dtype\n        # Under autocast, hidden_states may be fp32 while weights are bf16\n        hidden_states = hidden_states.to(self.experts.gate_up_proj.dtype)\n\n        # First grouped GEMM: gate_up_proj with permute_x\n        # Input: [num_tokens, hidden_dim] -> Output: [total_tokens, 2*intermediate_dim]\n        intermediate = grouped_gemm(\n            X = hidden_states,\n            W = self.experts.gate_up_proj,\n            m_sizes = token_counts_by_expert.int(),\n            topk = self.top_k,\n            gather_indices = gather_indices,\n            permute_x = True,\n            permute_y = False,\n            autotune = True,\n            is_first_gemm = True,\n        )\n\n        # Activation: SiLU(gate) * up\n        gate, up = intermediate.chunk(2, dim = -1)\n        intermediate = torch_nn_functional_silu(gate) * up\n\n        # Second grouped GEMM: down_proj with permute_y\n        # Input: [total_tokens, intermediate_dim] -> Output: [total_tokens, hidden_dim]\n        expert_output = grouped_gemm(\n            X = intermediate,\n            W = self.experts.down_proj,\n            m_sizes = token_counts_by_expert.int(),\n            topk = self.top_k,\n            
gather_indices = gather_indices,\n            permute_x = False,\n            permute_y = True,\n            autotune = True,\n            is_first_gemm = False,\n        )\n\n        # Merge topk weights: [num_tokens, top_k, hidden_dim] -> [num_tokens, hidden_dim]\n        hidden_states = (\n            expert_output.view(num_tokens, self.top_k, hidden_dim)\n            * topk_weights.unsqueeze(-1)\n        ).sum(dim = 1)\n    else:\n        # Fallback to naive implementation\n        hidden_states = self.experts(hidden_states, topk_indices, topk_weights)\n\n    # Add shared expert output\n    hidden_states = hidden_states + self.shared_experts(residuals.view(-1, hidden_dim))\n\n    return hidden_states.view(*orig_shape)\n\n\ndef Glm4MoeLiteNaiveMoe_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    top_k_index: torch.Tensor,\n    top_k_weights: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Optimized expert forward using grouped GEMM.\n\n    Args:\n        hidden_states: [num_tokens, hidden_dim]\n        top_k_index: [num_tokens, top_k] indices of selected experts\n        top_k_weights: [num_tokens, top_k] weights for selected experts\n\n    Returns:\n        [num_tokens, hidden_dim] output after weighted sum of expert outputs\n    \"\"\"\n    num_tokens, hidden_dim = hidden_states.shape\n    top_k = top_k_index.shape[1]\n    # Cast routing weights to match hidden_states dtype (Qwen3 pattern)\n    top_k_weights = top_k_weights.to(hidden_states.dtype)\n\n    if not HAS_GROUPED_GEMM:\n        # Fallback to original naive implementation\n        final_hidden_states = torch.zeros_like(hidden_states)\n        with torch.no_grad():\n            expert_mask = torch.nn.functional.one_hot(\n                top_k_index, num_classes = self.num_experts\n            )\n            expert_mask = expert_mask.permute(2, 1, 0)\n            expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()\n\n        for expert_idx in expert_hit:\n            expert_idx = expert_idx[0]\n            if expert_idx == self.num_experts:\n                continue\n            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])\n            current_state = hidden_states[token_idx]\n            gate, up = torch.nn.functional.linear(\n                current_state, self.gate_up_proj[expert_idx]\n            ).chunk(2, dim = -1)\n            current_hidden_states = self.act_fn(gate) * up\n            current_hidden_states = torch.nn.functional.linear(\n                current_hidden_states, self.down_proj[expert_idx]\n            )\n            current_hidden_states = (\n                current_hidden_states * top_k_weights[token_idx, top_k_pos, None]\n            )\n            final_hidden_states.index_add_(\n                0, token_idx, current_hidden_states.to(final_hidden_states.dtype)\n            )\n\n        return final_hidden_states\n\n    # Get routing indices for grouped GEMM\n    with torch.no_grad():\n        token_counts_by_expert, gather_indices = get_routing_indices(\n            top_k_index, self.num_experts\n        )\n\n    # Cast hidden_states to match expert weights dtype\n    # Under autocast, hidden_states may be fp32 while weights are bf16\n    hidden_states = hidden_states.to(self.gate_up_proj.dtype)\n\n    # First grouped GEMM: gate_up_proj\n    intermediate = grouped_gemm(\n        X = hidden_states,\n        W = self.gate_up_proj,\n        m_sizes = token_counts_by_expert.int(),\n        topk = top_k,\n        gather_indices = gather_indices,\n        
permute_x = True,\n        permute_y = False,\n        autotune = True,\n        is_first_gemm = True,\n    )\n\n    # Activation: SiLU(gate) * up\n    gate, up = intermediate.chunk(2, dim = -1)\n    intermediate = self.act_fn(gate) * up\n\n    # Second grouped GEMM: down_proj\n    expert_output = grouped_gemm(\n        X = intermediate,\n        W = self.down_proj,\n        m_sizes = token_counts_by_expert.int(),\n        topk = top_k,\n        gather_indices = gather_indices,\n        permute_x = False,\n        permute_y = True,\n        autotune = True,\n        is_first_gemm = False,\n    )\n\n    # Merge topk weights\n    final_hidden_states = (\n        expert_output.view(num_tokens, top_k, hidden_dim) * top_k_weights.unsqueeze(-1)\n    ).sum(dim = 1)\n\n    return final_hidden_states\n\n\ndef Glm4MoeLiteDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values = None,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    **kwargs,\n) -> torch.Tensor:\n    \"\"\"\n    Optimized decoder layer forward with fast RMS layernorm.\n    \"\"\"\n    # Check if we're in inference mode\n    is_inference = use_cache and hasattr(self, \"_flag_for_generation\")\n\n    if is_inference:\n        # Self-attention with fast inference path\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.input_layernorm, hidden_states\n        )\n        hidden_states, _ = self.self_attn(\n            hidden_states = hidden_states,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_values = past_key_values,\n            use_cache = use_cache,\n            cache_position = cache_position,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # MLP/MoE\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.post_attention_layernorm, hidden_states\n        )\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n    else:\n        # Training path\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)\n        hidden_states, _ = self.self_attn(\n            hidden_states = hidden_states,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_values = past_key_values,\n            use_cache = use_cache,\n            cache_position = cache_position,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # MLP/MoE\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n    return hidden_states\n\n\ndef Glm4MoeLiteMLP_fast_forward(self, x):\n    \"\"\"\n    Optimized MLP forward using fused SwiGLU.\n    \"\"\"\n    return fast_swiglu_inference(self, x)\n\n\nclass FastGLM47Model(FastLlamaModel):\n    \"\"\"\n    Fast GLM-4.7 Flash (GLM4 MoE Lite) model with grouped GEMM optimization.\n\n    This 
provides 2-3x throughput improvement for MoE layers by:\n    - Replacing sequential expert loops with grouped GEMM operations\n    - Fusing permutation operations into the GEMM kernels\n    - Using optimized RMS LayerNorm and SwiGLU implementations\n    \"\"\"\n\n    @staticmethod\n    def pre_patch():\n        if not HAS_GLM4_MOE:\n            raise ImportError(\n                \"Unsloth: GLM4 MoE Lite support requires transformers >= 5.0.0. \"\n                \"Please upgrade with: pip install --upgrade transformers\"\n            )\n\n        # Patch MoE forward with grouped GEMM optimization\n        # TMA compatibility is handled by grouped_gemm/__init__.py which patches\n        # triton.language to support both old (_experimental_make_tensor_descriptor)\n        # and new (make_tensor_descriptor) API names\n        if HAS_GROUPED_GEMM:\n            Glm4MoeLiteNaiveMoe.forward = Glm4MoeLiteNaiveMoe_fast_forward\n            Glm4MoeLiteMoE.forward = Glm4MoeLiteMoE_fast_forward\n\n        # Note: We don't patch the following for GLM4 MoE because:\n        # - GLM4 uses MLA (Multi-head Latent Attention) which has different projection names\n        # - Glm4MoeLiteRotaryEmbedding doesn't have extend_rope_embedding method\n        # - The decoder layer and model forward functions assume Llama-compatible infrastructure\n\n        return\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/GLM-4.7-Flash\",\n        max_seq_length = 4096,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        # Pop kwargs that are used by loader but not passed to model\n        kwargs.pop(\"unsloth_force_compile\", None)\n\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastGLM47Model,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
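The glm4_moe.py docstring above lists what makes this router different from a softmax router: sigmoid scores, a routed_scaling_factor, a shared expert, and group-based pre-selection. A simplified sketch of the sigmoid top-k routing step follows; the group-based selection and the exact normalization done by Glm4MoeLiteTopkRouter.route_tokens_to_experts are omitted, and top_k, the renormalization, and the scaling factor here are illustrative assumptions:

```python
import torch

def sigmoid_topk_router(router_logits, top_k=8, routed_scaling_factor=1.8):
    # Sigmoid scores per expert (not a softmax over experts), then pick top-k,
    # renormalize the selected weights per token, and apply the routed scaling factor.
    scores = router_logits.sigmoid()                              # [tokens, n_experts]
    topk_weights, topk_indices = scores.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights * routed_scaling_factor
    return topk_indices, topk_weights

def combine_expert_outputs(expert_output, topk_weights):
    # expert_output: [tokens, top_k, hidden]; topk_weights: [tokens, top_k]
    return (expert_output * topk_weights.unsqueeze(-1)).sum(dim=1)

# toy example: 4 tokens, 16 experts, top-2 routing
logits = torch.randn(4, 16)
idx, w = sigmoid_topk_router(logits, top_k=2)
out = combine_expert_outputs(torch.randn(4, 2, 32), w)
print(idx.shape, w.shape, out.shape)  # [4, 2], [4, 2], [4, 32]
```

As in Glm4MoeLiteMoE_fast_forward above, the shared expert's output would then be added on top of the routed result.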
  },
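For intuition about what the two grouped GEMM calls above replace, here is a plain-PyTorch reference of the permute → per-expert GEMM → unpermute flow, looping over experts instead of fusing everything into one kernel. The names, shapes, and use of argsort/bincount are illustrative; the real get_routing_indices, permute, and unpermute helpers live in unsloth's grouped_gemm package:

```python
import torch
import torch.nn.functional as F

def grouped_expert_mlp_reference(X, topk_idx, W_gate_up, W_down):
    # X: [tokens, hidden], topk_idx: [tokens, top_k]
    # W_gate_up: [experts, 2*inter, hidden], W_down: [experts, hidden, inter]
    tokens, hidden = X.shape
    top_k = topk_idx.shape[1]
    num_experts = W_gate_up.shape[0]
    flat_expert = topk_idx.reshape(-1)                 # expert id per (token, slot) row
    order = flat_expert.argsort()                      # group rows so each expert's tokens are contiguous
    counts = torch.bincount(flat_expert, minlength=num_experts)
    Xe = X.repeat_interleave(top_k, dim=0)[order]      # "permute_x": gather rows into expert order
    out = torch.empty(tokens * top_k, hidden, dtype=X.dtype)
    start = 0
    for e, n in enumerate(counts.tolist()):            # one dense GEMM pair per expert group
        if n == 0:
            continue
        gate, up = (Xe[start:start + n] @ W_gate_up[e].T).chunk(2, dim=-1)
        out[start:start + n] = (F.silu(gate) * up) @ W_down[e].T
        start += n
    unpermuted = torch.empty_like(out)                 # "permute_y": scatter back to (token, slot) order
    unpermuted[order] = out
    return unpermuted.view(tokens, top_k, hidden)      # caller applies topk_weights and sums over slots

X = torch.randn(6, 16)
topk_idx = torch.randint(0, 4, (6, 2))
Y = grouped_expert_mlp_reference(X, topk_idx, torch.randn(4, 2 * 32, 16), torch.randn(4, 16, 32))
print(Y.shape)  # torch.Size([6, 2, 16])
```

As in the fast forward above, the caller would then multiply by topk_weights.unsqueeze(-1), sum over the top_k dimension, and add the shared-expert output.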
  {
    "path": "unsloth/models/granite.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import _get_dtype, Version\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom ..utils.packing import get_packed_info_from_kwargs\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    select_attention_backend,\n    SDPA,\n)\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n)\nfrom .mistral import *\nfrom bitsandbytes.nn import Linear4bit as Bnb_Linear4bit\nfrom peft.tuners.lora import Linear4bit as Peft_Linear4bit\n\ntry:\n    from transformers.models.granite.modeling_granite import (\n        GraniteAttention,\n        GraniteDecoderLayer,\n        GraniteModel,\n        GraniteForCausalLM,\n    )\nexcept:\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\"4.45.0\"):\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support Granite.\\n\"\n            f\"The minimum required version is 4.45.0.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.45.0\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\n\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.granite.modeling_granite import (\n        GraniteSdpaAttention,\n        GraniteFlashAttention2,\n    )\nexcept:\n    GraniteSdpaAttention = GraniteAttention\n    GraniteFlashAttention2 = GraniteAttention\n\n\ndef GraniteAttention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    dropout_p = self.config.attention_dropout if self.training else 0\n    assert n_kv_heads * n_groups == n_heads\n\n  
  Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, Q.device)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    assert position_embeddings is not None\n    cos, sin = position_embeddings\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    if rope_position_ids is not None:\n        # Useful for LongRoPE\n        Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n    else:\n        Q, K = fast_rope_embedding(Q, K, cos, sin)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    use_varlen = (\n        attention_mask is None and seq_info is not None and past_key_value is None\n    )\n\n    backend = (\n        SDPA if attention_mask is not None else select_attention_backend(use_varlen)\n    )\n\n    window = (kv_seq_len, kv_seq_len)\n    softmax_scale = getattr(self, \"scaling\", None)\n    attention_config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\n            \"causal\": True,\n            \"softmax_scale\": softmax_scale,\n            \"dropout_p\": dropout_p,\n            \"window_size\": window,\n        },\n        flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"softmax_scale\": softmax_scale,\n            \"causal\": True,\n        },\n        sdpa_kwargs = {\n            k: v\n            for k, v in {\n                \"attn_mask\": attention_mask,\n                \"scale\": softmax_scale,\n                \"dropout_p\": dropout_p,\n            }.items()\n            if v is not None\n        },\n        xformers_kwargs = {\n            \"scale\": softmax_scale,\n            \"p\": dropout_p,\n        },\n    )\n\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\ndef GraniteDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n):\n    residual_multiplier = (\n        self.residual_multiplier\n        if hasattr(self, \"residual_multiplier\")\n  
      else self.config.residual_multiplier\n    )\n\n    if use_cache and hasattr(\n        self, \"_flag_for_generation\"\n    ):  # past_key_value is not None:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.input_layernorm, hidden_states\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            _flag_for_generation = self._flag_for_generation,\n            **kwargs,\n        )\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.post_attention_layernorm, hidden_states\n        )\n        hidden_states = fast_swiglu_inference(self.mlp, hidden_states)\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\nfrom math import sqrt as math_sqrt\n\nKV_CACHE_INCREMENT = 256  # KV Cache update size\ntorch_nn_functional_softmax = torch.nn.functional.softmax\ntorch_matmul = torch.matmul\ntorch_tanh = torch.tanh\n\n\ndef GraniteAttention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    use_sliding_window = False,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n):\n    assert (\n        position_embeddings is not None\n    ), f\"Granite model requires position embeddings to be specified\"\n\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = 
self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n    device = hidden_states.device\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = device,\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, attention_size), dtype = dtype, device = device\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device\n        )\n        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)\n        self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device\n        )\n\n        self.half_head_dim = head_dim // 2\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)\n    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n    cos, sin = position_embeddings\n    cos, sin = cos[position_ids], sin[position_ids]\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Grouped query attention\n    _, _, cached_len, _ = Kn.shape\n    if bsz == 1 or ((not SDPA_HAS_GQA) and n_groups != 1):\n        Kn = Kn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, 
n_groups, cached_len, head_dim\n        )\n        Vn = Vn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # Attention\n    if bsz == 1:\n        Qn *= self.scaling\n        A = torch_matmul(\n            Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]\n        )\n        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)\n        A = torch_matmul(A, Vn, out = Qn)\n    else:\n        if (\n            attention_mask is not None\n            and attention_mask.dim() == 4\n            and attention_mask.dtype != torch.bool\n        ):\n            attention_mask = attention_mask.eq(0)\n        if SDPA_HAS_GQA:\n            A = scaled_dot_product_attention(\n                Qn,\n                Kn,\n                Vn,\n                attn_mask = attention_mask,\n                scale = self.scaling,\n                enable_gqa = True,\n            )\n        else:\n            A = scaled_dot_product_attention(\n                Qn,\n                Kn,\n                Vn,\n                attn_mask = attention_mask,\n                scale = self.scaling,\n            )\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\n# @torch.inference_mode\ndef GraniteModel_fast_forward_inference(\n    self,\n    input_ids,\n    past_key_values,\n    position_ids,\n    attention_mask = None,\n):\n    input_ids = input_ids[:, : self.max_seq_length]\n    hidden_states = self.model.embed_tokens(input_ids)\n    hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))\n    hidden_states *= self.model.embedding_multiplier\n    residual_multiplier = (\n        self.residual_multiplier\n        if hasattr(self, \"residual_multiplier\")\n        else self.config.residual_multiplier\n    )\n\n    bsz, q_len, hd = hidden_states.shape\n    seq_len = past_key_values[0][0].shape[-2]\n    if bsz != 1:\n        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n            attention_mask,\n            (bsz, q_len),\n            hidden_states,\n            seq_len,\n        )\n        # Pre-convert to bool once for all layers (avoids per-layer .eq(0))\n        if attention_mask is not None and attention_mask.dtype != torch.bool:\n            attention_mask = attention_mask.eq(0)\n    else:\n        attention_mask = None\n\n    position_embeddings = self.model.rotary_emb.get_cached(\n        self.max_seq_length, hidden_states.device.index\n    )\n\n    next_decoder_cache = []\n    for idx, decoder_layer in enumerate(self.model.layers):\n        device_index = getattr(decoder_layer, \"_per_layer_device_index\", 0)\n        hidden_states, position_ids = move_to_device(\n            device_index, hidden_states, position_ids\n        )\n\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            decoder_layer.input_layernorm, hidden_states\n        )\n        hidden_states, present_key_value = GraniteAttention_fast_forward_inference(\n            decoder_layer.self_attn,\n            hidden_states = hidden_states,\n            past_key_value = past_key_values[idx],\n            position_ids = position_ids,\n            
attention_mask = attention_mask,\n            do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n            position_embeddings = position_embeddings,\n        )\n\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            decoder_layer.post_attention_layernorm, hidden_states\n        )\n        hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)\n        hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)\n\n        next_decoder_cache.append(present_key_value)\n    hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)\n\n    return BaseModelOutputWithPast(\n        last_hidden_state = hidden_states,\n        past_key_values = next_decoder_cache,\n        hidden_states = [],\n        attentions = [],\n    )\n\n\nclass GraniteRotaryEmbedding(LlamaRotaryEmbedding):\n    def __init__(self, config):\n        super().__init__(config = config)\n\n\ndef patched_init(original_init):\n    def new_init(self, *args, **kwargs):\n        # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here\n        # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243\n        # The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference\n        # So we need a way to pass this value around. It is probably better to pass on entire config just in case we need it later\n        config = kwargs.get(\"config\", args[0] if args else None)\n        if config is not None:\n            self.config = config\n        original_init(self, *args, **kwargs)\n\n    return new_init\n\n\nclass FastGraniteModel(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"granite\",\n            rope_module = GraniteRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = GraniteAttention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            GraniteAttention.__init__ = eval(init_name)\n        GraniteAttention.forward = GraniteAttention_fast_forward\n        GraniteSdpaAttention.forward = GraniteAttention_fast_forward\n        GraniteFlashAttention2.forward = GraniteAttention_fast_forward\n        GraniteDecoderLayer.forward = GraniteDecoderLayer_fast_forward\n        GraniteModel.forward = LlamaModel_fast_forward\n        GraniteForCausalLM.forward = CausalLM_fast_forward(\n            GraniteModel_fast_forward_inference\n        )\n        GraniteForCausalLM.__init__ = patched_init(GraniteForCausalLM.__init__)\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(GraniteForCausalLM)\n\n        import transformers.models.granite.modeling_granite\n\n        transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = (\n            GraniteRotaryEmbedding\n        )\n\n        return\n\n    @staticmethod\n    def post_patch(model, tokenizer, correct_dtype = None):\n        # Torch.compile fails on embedding matrix??\n        # Workaround randomnly fixes it for torch versions < 2.2\n        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(\n            
model.model.embed_tokens.weight\n        )\n        model.config.update({\"unsloth_version\": __version__})\n\n        # We also do this for the lm_head\n        lm_head = torch.nn.Linear(1, 1, bias = None)\n        del lm_head.weight\n        lm_head.weight = model.lm_head.weight\n        lm_head.in_features = lm_head.weight.shape[1]\n        lm_head.out_features = lm_head.weight.shape[0]\n        model.lm_head = lm_head\n\n        # Granite has tied weights! This means lm_head == embed_tokens\n        if (\n            model.model.embed_tokens.weight.data_ptr()\n            != model.lm_head.weight.data_ptr()\n        ):\n            lm_head = torch.nn.Linear(1, 1, bias = None)\n            del lm_head.weight\n            lm_head.weight = model.model.embed_tokens.weight\n            lm_head.in_features = lm_head.weight.shape[1]\n            lm_head.out_features = lm_head.weight.shape[0]\n            model.lm_head = lm_head\n\n        # Also patch all dtypes - BnB seems to not allocate the correct type?\n        # BnB default dtype seems to be float16!\n        correct_dtype = lm_head.weight.dtype\n\n        for name, module in model.named_modules():\n            if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):\n                weight = module.weight\n                quant_state = weight.quant_state\n\n                if type(quant_state) is list:\n                    # BnB seems to have float16 as default!\n                    module.weight.quant_state[2] = (\n                        correct_dtype  # Cast to correct dtype\n                    )\n                else:\n                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n                    quant_state.dtype = correct_dtype\n            # Downcast RoPE embedding to correct data type\n            if name.endswith(\"rotary_emb\") or hasattr(module, \"cos_cached\"):\n                if hasattr(module, \"cos_cached\") and (\n                    module.cos_cached.dtype != correct_dtype\n                ):\n                    module.cos_cached = module.cos_cached.to(correct_dtype)\n                    module.sin_cached = module.sin_cached.to(correct_dtype)\n\n                elif hasattr(module, \"short_cos_cached\") and (\n                    module.short_cos_cached.dtype != correct_dtype\n                ):\n                    module.short_cos_cached = module.short_cos_cached.to(correct_dtype)\n                    module.short_sin_cached = module.short_sin_cached.to(correct_dtype)\n\n        # Clear deleted GPU items\n        import gc\n\n        for _ in range(3):\n            gc.collect()\n            torch.cuda.empty_cache()\n        return model, tokenizer\n"
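A recurring Granite-specific detail in granite.py above is that residual connections are scaled: both the training and inference paths use torch.add(residual, hidden_states, alpha = residual_multiplier) rather than a plain addition (and, analogously, embeddings are scaled by embedding_multiplier before the first layer). A minimal sketch of that pattern, with an illustrative multiplier value and the preceding layernorm omitted:

```python
import torch

def scaled_residual_block(hidden_states, sublayer, residual_multiplier=0.22):
    # out = residual + residual_multiplier * sublayer(hidden_states)
    # torch.add(residual, x, alpha=m) fuses the scale into the addition.
    residual = hidden_states
    x = sublayer(hidden_states)          # stand-in for self-attention or the MLP
    return torch.add(residual, x, alpha=residual_multiplier)

h = torch.randn(2, 5, 64)
mlp = torch.nn.Linear(64, 64)
out = scaled_residual_block(h, mlp)
assert torch.allclose(out, h + 0.22 * mlp(h))   # same as the unfused expression
print(out.shape)  # torch.Size([2, 5, 64])
```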
  },
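GraniteModel_fast_forward_inference above also converts the 4D additive mask to boolean form once, via attention_mask.eq(0), so the per-layer SDPA calls can skip the conversion. A small sketch of what that conversion means, with illustrative shapes:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# An *additive* mask uses 0 for "attend" and a large negative value for "masked".
# .eq(0) turns it into the boolean form SDPA accepts, where True means "attend".
additive = torch.zeros(1, 1, 1, 4)
additive[..., -1] = torch.finfo(torch.float32).min
bool_mask = additive.eq(0)                 # [[[[True, True, True, False]]]]

q = torch.randn(1, 2, 1, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
out = scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
print(out.shape)  # torch.Size([1, 2, 1, 8])
```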
  {
    "path": "unsloth/models/llama.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport gc\nimport math\nimport functools\nfrom typing import Optional, Tuple, List, Union\n\nfrom ._utils import *\nfrom ._utils import apply_unsloth_gradient_checkpointing\nfrom ._utils import __version__, importlib_version\nfrom ._utils import move_to_device\nfrom ._utils import (\n    _get_inference_mode_context_manager,\n    _prepare_model_for_qat,\n    is_bfloat16_supported,\n    get_quant_type,\n)\nfrom .loader_utils import _get_fp8_mode_and_check_settings\nfrom ..utils.packing import (\n    get_packed_info_from_kwargs,\n    mask_packed_sequence_boundaries,\n)\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    SDPA,\n    select_attention_backend,\n)\nfrom torch.nn.functional import scaled_dot_product_attention\nfrom transformers import __version__ as transformers_version\nfrom unsloth_zoo.utils import Version, _get_dtype\nfrom unsloth_zoo.hf_utils import (\n    dtype_from_config,\n    add_dtype_kwargs,\n    fix_lora_auto_mapping,\n)\nfrom unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n)\n\ntransformers_version = Version(transformers_version)\n# Transformers moved rotary embeddings out of all attention layers\nIS_ATTENTION_REFACTOR = transformers_version > Version(\"4.47.1\")\ntry:\n    from transformers.modeling_layers import GradientCheckpointingLayer\nexcept:\n    GradientCheckpointingLayer = type(None)\n\nfrom transformers.models.llama.modeling_llama import (\n    logger,\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\nfrom ..kernels import *\nfrom ..tokenizer_utils import *\nfrom .vision import FastBaseModel\n\n# Final patching code\nfrom transformers.models.llama.modeling_llama import (\n    LlamaAttention,\n    LlamaDecoderLayer,\n    LlamaModel,\n    LlamaForCausalLM,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.llama.modeling_llama import (\n        LlamaSdpaAttention,\n        LlamaFlashAttention2,\n    )\nexcept:\n    LlamaSdpaAttention = LlamaAttention\n    LlamaFlashAttention2 = LlamaAttention\n\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForCausalLM,\n    AutoModelForSequenceClassification,\n    BitsAndBytesConfig,\n    AutoConfig,\n)\nfrom transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING\nfrom transformers import set_seed as transformers_set_seed\nfrom peft import LoraConfig, TaskType, get_peft_model as _get_peft_model\nfrom peft import PeftModelForCausalLM, PeftModelForSequenceClassification\nfrom ..save import patch_saving_functions\nimport re, os, inspect, math, sys\nimport types\n\ntry:\n    from 
huggingface_hub.utils import get_token\nexcept:\n    # Old HF Hub versions <= 0.0.25\n    from huggingface_hub.utils._token import get_token\nfrom triton import __version__ as triton_version\n\nHAS_XFORMERS = xformers is not None\nBlockDiagonalCausalMask = (\n    xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None\n)\n\nif DEVICE_TYPE == \"xpu\":\n    clean_gpu_cache = torch.xpu.empty_cache\n    get_current_device = torch.xpu.current_device\nelse:\n    clean_gpu_cache = torch.cuda.empty_cache\n    get_current_device = torch.cuda.current_device\n\n\ndef original_apply_qkv(self, X):\n    Q = self.q_proj(X)\n    K = self.k_proj(X)\n    V = self.v_proj(X)\n    return Q, K, V\n\n\ndef original_apply_o(self, X):\n    O = self.o_proj(X)\n    return O\n\n\nfrom math import sqrt as math_sqrt\n\nKV_CACHE_INCREMENT = 512  # KV Cache update size\ntorch_nn_functional_softmax = torch.nn.functional.softmax\n# SDPA has GQA internally\nSDPA_HAS_GQA = \"enable_gqa\" in scaled_dot_product_attention.__doc__\n\nfrom peft.utils.other import ModulesToSaveWrapper\n\n\ndef _offload_frozen_module_for_training(\n    module: ModulesToSaveWrapper,\n    device_type: str,\n    offload_device: Optional[str] = \"cpu\",\n) -> None:\n    \"\"\"\n    Offload frozen module to CPU and configure trainable copy for mixed precision training.\n\n    This function optimizes memory usage by:\n    1. Moving the trainable copy to the target device with appropriate precision\n    2. Optionally offloading the original frozen module to CPU/disk to free VRAM\n    3. Converting float16 to float32 for compatibility with certain GPUs (e.g., Tesla T4)\n\n    Args:\n        module: The module to configure. Must be a ModulesToSaveWrapper with a\n            `modules_to_save` attribute containing trainable and original modules.\n        device_type: Target device string for training (e.g., \"cuda:0\", \"xpu:0\")\n        offload_device: Device to offload frozen parameters (default: \"cpu\").\n            If None, the original frozen module remains on its current device.\n            Note: Currently only \"cpu\" is supported; disk offloading is planned.\n\n    Returns:\n        None (modifies module in-place)\n\n    Note:\n        - Float16 weights are automatically promoted to float32 for GPU compatibility\n        - When offload_device is specified, frozen parameters are moved to free VRAM\n        - Future versions will support disk-based offloading for even larger models\n\n    See Also:\n        - https://github.com/unslothai/unsloth/pull/1200 (Tesla T4 float32 requirement)\n    \"\"\"\n    # Early return with explicit None if module doesn't support mixed precision training\n    if not hasattr(module, \"modules_to_save\"):\n        return None\n\n    new_dtype = module.modules_to_save.default.weight.dtype\n    if new_dtype == torch.float16:\n        # See https://github.com/unslothai/unsloth/pull/1200\n        # Tesla T4 must use float32 and not float16\n        new_dtype = torch.float32\n\n    module.modules_to_save.default.to(\n        device = device_type, dtype = new_dtype, non_blocking = True\n    )\n    module.modules_to_save.default.requires_grad_(True)\n\n    # [TODO] Move old module to CPU - should be disk!\n    if offload_device is not None:\n        module.original_module.to(device = offload_device, non_blocking = True)\n    module.original_module.requires_grad_(False)\n\n\n# Fix new HF's inference code\ndef _fast_prepare_inputs_for_generation(\n    self,\n    input_ids,\n    attention_mask = None,\n    
inputs_embeds = None,\n    **kwargs,\n):\n    past_key_values = kwargs.get(\"past_key_values\", None)\n    original_attention_mask = attention_mask\n\n    # Handle inputs_embeds - only use on FIRST generation step (no cache)\n    # This fixes GitHub issue #3798: inputs_embeds was ignored\n    use_inputs_embeds = inputs_embeds is not None and past_key_values is None\n\n    if input_ids is not None and input_ids.numel() > 0:\n        bs, seq_length = input_ids.shape\n        device = input_ids.device\n    elif inputs_embeds is not None:\n        bs, seq_length, _ = inputs_embeds.shape\n        device = inputs_embeds.device\n    else:\n        bs, seq_length = 1, 0\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    if past_key_values is not None:\n        # Check for uninitialized DynamicCache\n        if len(past_key_values) == 0:\n            past_key_values = None\n            kwargs[\"past_key_values\"] = None\n            use_inputs_embeds = inputs_embeds is not None\n        # New since 4.56\n        elif (\n            hasattr(past_key_values, \"get_seq_length\")\n            and past_key_values.get_seq_length() == 0\n        ):\n            past_key_values = None\n            kwargs[\"past_key_values\"] = None\n            use_inputs_embeds = inputs_embeds is not None\n        else:\n            if input_ids is not None and input_ids.numel() > 0:\n                bs = input_ids.shape[0]\n                input_ids = input_ids[:, [-1]]\n                device = input_ids.device\n                seq_length = 1\n            elif inputs_embeds is not None:\n                bs, seq_length, _ = inputs_embeds.shape\n                device = inputs_embeds.device\n            else:\n                bs, seq_length = 1, 0\n                device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n            if hasattr(past_key_values, \"get_seq_length\"):\n                past_len = int(past_key_values.get_seq_length())\n            else:\n                # legacy tuple cache: (layer, (K,V))\n                past_len = int(past_key_values[0][0].shape[-2])\n\n            max_cache_len = None\n            if hasattr(past_key_values, \"get_max_cache_shape\"):\n                m = past_key_values.get_max_cache_shape()\n                max_cache_len = int(m) if m is not None and m > 0 else None\n            elif hasattr(past_key_values, \"get_max_length\"):\n                m = past_key_values.get_max_length()\n                max_cache_len = int(m) if m is not None else None\n\n            # ensure cache_position\n            cache_position = kwargs.get(\"cache_position\", None)\n            if cache_position is None:\n                kwargs[\"cache_position\"] = torch.arange(\n                    past_len,\n                    past_len + seq_length,\n                    device = device,\n                    dtype = torch.long,\n                )\n            else:\n                if (\n                    hasattr(cache_position, \"device\")\n                    and cache_position.device != device\n                ):\n                    kwargs[\"cache_position\"] = cache_position.to(device)\n\n            # Get to the base model\n            base_model = self\n            if hasattr(base_model, \"base_model_prefix\"):\n                base_model = getattr(base_model, base_model.base_model_prefix)\n\n            if hasattr(\n                base_model, \"_prepare_4d_causal_attention_mask_with_cache_position\"\n            ):\n                if not hasattr(base_model, 
\"_unsloth_mask_needs_device\"):\n\n                    def _check_needs_device(fn) -> bool:\n                        try:\n                            sig = inspect.signature(inspect.unwrap(fn))\n                            return \"device\" in sig.parameters\n                        except:\n                            # transformers <= 4.51.3 includes device arg but > 4.51.3 does not\n                            return transformers_version < Version(\"4.52.0\")\n\n                    base_model._unsloth_mask_needs_device = _check_needs_device(\n                        base_model._prepare_4d_causal_attention_mask_with_cache_position\n                    )\n\n                if max_cache_len is not None:\n                    target_length = max_cache_len\n                elif (\n                    original_attention_mask is not None\n                    and original_attention_mask.dim() == 2\n                ):\n                    target_length = original_attention_mask.shape[-1]\n                else:\n                    target_length = past_len + seq_length\n\n                mask_kwargs = {\n                    \"sequence_length\": seq_length,\n                    \"target_length\": target_length,\n                    \"dtype\": self.dtype,\n                    \"cache_position\": kwargs[\"cache_position\"],\n                    \"batch_size\": bs,\n                    \"config\": self.config,\n                    \"past_key_values\": past_key_values,\n                }\n                if base_model._unsloth_mask_needs_device:\n                    mask_kwargs[\"device\"] = device\n\n                attention_mask = (\n                    base_model._prepare_4d_causal_attention_mask_with_cache_position(\n                        attention_mask,\n                        **mask_kwargs,\n                    )\n                )\n            else:\n                if transformers_version <= Version(\"4.52.4\"):\n                    logger.warning_once(\n                        f\"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method \"\n                        \"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're \"\n                        \"writing code, see Llama for an example implementation. 
If you're a user, please report this \"\n                        \"issue on GitHub.\"\n                    )\n\n    if kwargs.get(\"position_ids\", None) is None:\n        if original_attention_mask is not None and original_attention_mask.dim() == 2:\n            position_ids = original_attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(original_attention_mask == 0, 1)\n            position_ids = position_ids[:, -seq_length:]\n            kwargs[\"position_ids\"] = position_ids\n        elif kwargs.get(\"cache_position\", None) is not None:\n            cp = kwargs[\"cache_position\"]\n            if cp.dim() == 1:\n                cp = cp.unsqueeze(0).expand(bs, -1)\n            kwargs[\"position_ids\"] = cp\n\n    result = {\n        \"attention_mask\": attention_mask,\n        **kwargs,\n    }\n    if use_inputs_embeds:\n        result[\"inputs_embeds\"] = inputs_embeds\n        result[\"input_ids\"] = None\n    else:\n        result[\"input_ids\"] = input_ids\n    return result\n\n\ndef fix_prepare_inputs_for_generation(module):\n    # Fix prepare_inputs_for_generation\n    if hasattr(module, \"prepare_inputs_for_generation\"):\n        module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation\n\n\ntorch_matmul = torch.matmul\n\n\ndef LlamaAttention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    rotary_seq_len = None,\n):\n    \"\"\"\n    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406\n    Fast inference using KV cache.\n    QK^T can be computed in 4 chunks\n\n    [Q, q] @ [K, k].T where q, k are the new tokens.\n    [QK^T, Qk^T]\n    [qK^T, qk^T]\n\n    Since the attention mask wipes Qk^T, we just get\n    [QK^T,    0]\n    [qK^T, qk^T]\n\n    Since softmax is row-wise, we get\n    softmax([QK^T,    0])\n    softmax([qK^T, qk^T])\n\n    We then multiply by   [V]\n                          [v]\n    softmax([QK^T,    0]) [softmax(QK^T)V] *\n    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]\n\n    But notice * [softmax(QK^T)V] is just the last attention.\n    We just need to compute the last final row.\n\n    This means we can pass in a row of Q, but we need to\n    remember K and V, which are called the KV cache.\n    \"\"\"\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    device = hidden_states.device\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = device,\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, 
attention_size), dtype = dtype, device = device\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device\n        )\n        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)\n\n        # Mistral Nemo 12b has weird dimensions\n        if attention_size != hidden_size:\n            self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)\n        else:\n            self.temp_O = self.temp_QA[1][:, :, :hidden_size]\n\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device\n        )\n        self.scalar = 1.0 / math_sqrt(self.head_dim)\n        self.half_head_dim = head_dim // 2\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)\n    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n\n    # Need to do it prior 2 steps before hitting full on short KV cache\n    # or else error\n    # ensure correct shape\n    if position_ids.dim() == 1:\n        position_ids = position_ids[:, None]\n    position_ids = position_ids.to(Qn.device)\n\n    if rotary_seq_len is None:\n        rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)\n    self.rotary_emb.extend_rope_embedding(Vn, rotary_seq_len + 1)  # +1 slack\n    cos, sin = self.rotary_emb.get_cached(rotary_seq_len, Qn.device.index or 0)\n\n    cos = cos[position_ids].unsqueeze(1).to(device = Qn.device, dtype = Qn.dtype)\n    sin = sin[position_ids].unsqueeze(1).to(device = Qn.device, dtype = Qn.dtype)\n\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Handle sliding windows\n 
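  # e.g. with sliding_window = 4096 and kv_seq_len = 4100, only the last 4096 cached\n    # K / V positions are kept for this decode step (start = 4 below).\n 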
   sliding_window = getattr(self.config, \"sliding_window\", None)\n    if sliding_window is not None and kv_seq_len > sliding_window:\n        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193\n        start = kv_seq_len - sliding_window\n        Knn = Kn[:, :, start:, :]  # .contiguous()\n        Vnn = Vn[:, :, start:, :]  # .contiguous()\n        if attention_mask is not None:\n            attention_mask = attention_mask[..., start:]\n    else:\n        Knn, Vnn = Kn, Vn\n\n    # Grouped query attention\n    _, _, cached_len, _ = Knn.shape\n    if bsz == 1 or ((not SDPA_HAS_GQA) and n_groups != 1):\n        Knn = Knn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Vnn = Vnn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # when qlen==vlen and attn_mask is None, we should use causal attention\n    Q_len = Qn.shape[-2]\n    K_len = Knn.shape[-2]\n    if attention_mask is None and Q_len == K_len:\n        is_causal = True\n    else:\n        is_causal = False\n    # Attention\n    if bsz == 1:\n        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963\n        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows\n        A = torch_matmul(\n            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]\n        )\n        A[:] = torch_nn_functional_softmax(\n            A, dim = -1, dtype = torch.float32\n        )  # .to(A.dtype)\n        A = torch_matmul(A, Vnn, out = Qn)\n    # --- attention_mask fixup for SDPA if user passes 2D padding mask\n    else:\n        if attention_mask is not None and attention_mask.dim() == 2:\n            attention_mask = attention_mask[:, None, None, :].to(torch.bool)\n            # is it more appropriate to use _prepare_4d_causal_attention_mask_for_sdpa?\n        elif (\n            attention_mask is not None\n            and attention_mask.dim() == 4\n            and attention_mask.dtype != torch.bool\n        ):\n            # Decode is more stable with boolean keep masks than additive bf16 masks.\n            attention_mask = attention_mask.eq(0)\n\n        if SDPA_HAS_GQA:\n            A = scaled_dot_product_attention(\n                Qn,\n                Knn,\n                Vnn,\n                attn_mask = attention_mask,\n                is_causal = is_causal,\n                enable_gqa = True,\n            )\n        else:\n            A = scaled_dot_product_attention(\n                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal\n            )\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\ntorch_nn_functional_silu = torch.nn.functional.silu\n\n\ndef fast_swiglu_inference(\n    self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None\n):\n    # gate = self.gate_proj(X)\n    # up   = self.up_proj(X)\n    bsz, _, hd = X.shape\n    # mlp_size = self.config.intermediate_size\n    # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = \"cuda:0\")\n\n    gate = fast_linear_forward(self.gate_proj, X, out = temp_gate)\n\n    if gate_multiplier is not None:\n        
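# Optional per-model scaling of the gate projection (e.g. Falcon-H1 style MLP multipliers).\n        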
gate *= gate_multiplier\n\n    up = fast_linear_forward(self.up_proj, X, out = temp_up)\n\n    gate = torch_nn_functional_silu(gate, inplace = True)\n    gate *= up\n\n    # X = self.down_proj(gate)\n    down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])\n\n    if down_multiplier is not None:\n        down *= down_multiplier\n\n    return down\n\n\ntorch_square = torch.square\ntorch_mean = torch.mean\n\n\ndef fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None):\n    old_dtype = X.dtype\n    use_buffers = XX is not None\n    if not use_buffers:\n        XX = X.to(torch.float32)\n        variance = XX.square().mean(-1, keepdim = True)\n    else:\n        XX.copy_(X)\n        torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance)\n    variance += self.variance_epsilon\n    XX *= variance.rsqrt_()\n\n    if not use_buffers:\n        # No scratch buffers were provided: return a fresh tensor so the caller's input (and any residual alias) stays untouched.\n        X = XX.to(old_dtype)\n    else:\n        X.copy_(XX)\n\n    X *= self.weight\n    return X\n\n\ndef fast_rms_layernorm_inference_gemma(self, X, out_weight = None):\n    XX = X.to(torch.float32)\n    variance = XX.square().mean(-1, keepdim = True)\n    variance += self.variance_epsilon\n    XX *= variance.rsqrt_()\n\n    if out_weight is None:\n        out_weight = self.weight + 1.0\n    else:\n        out_weight[:] = self.weight\n        out_weight += 1.0\n\n    XX *= out_weight\n    return XX.to(X.dtype)\n\n\n# Normal layernorm with mean removal\n@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)\ndef fast_layernorm_compiled(layernorm, X):\n    old_dtype = X.dtype\n    X = X.float()\n    mean = X.mean(-1, keepdim = True)\n    Xbar = X - mean\n    X = (\n        Xbar\n        * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)\n        * layernorm.weight.float()\n    )\n    return X.to(old_dtype)\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320\ndef LlamaAttention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, Q.device)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    if 
position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:\n        cos, sin = position_embeddings\n    else:\n        rotary_emb = self.rotary_emb\n        rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)\n        cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)\n        cos = cos.to(device = Q.device, dtype = Q.dtype)\n        sin = sin.to(device = Q.device, dtype = Q.dtype)\n\n    rope_position_ids = position_ids\n    if rope_position_ids is None and seq_info is not None:\n        rope_position_ids = kwargs.get(\"position_ids\")\n\n    # Q, K = (\n    #     fast_rope_embedding(Q, K, cos, sin)\n    #     if rope_position_ids is None\n    #     else inplace_rope_embedding(Q, K, cos, sin, rope_position_ids)\n    # )\n    Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    use_varlen = seq_info is not None and past_key_value is None\n    backend = (\n        SDPA if attention_mask is not None else select_attention_backend(use_varlen)\n    )\n\n    # should dropout be hardcoded to 0.0?\n    config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\"causal\": True},\n        flash_varlen_kwargs = {\"dropout_p\": 0.0, \"causal\": True},\n    )\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = config, context = context, Q = Q, K = K, V = V)\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590\ndef LlamaDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n    \"\"\"\n    Args:\n        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n            returned tensors for more detail.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n            (see `past_key_values`).\n        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n    \"\"\"\n    if use_cache and hasattr(self, \"_flag_for_generation\"):\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.input_layernorm, hidden_states\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        hidden_states += residual\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.post_attention_layernorm, hidden_states\n        )\n        hidden_states = fast_swiglu_inference(self.mlp, hidden_states)\n        hidden_states += residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += (self_attn_weights,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\n# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452\n__DTYPE_MAP = {\n    \"float32\": torch.float32,\n    torch.float32: torch.float32,\n    \"float16\": torch.float16,\n    torch.float16: torch.float16,\n    \"bfloat16\": torch.bfloat16,\n    torch.bfloat16: torch.bfloat16,\n}\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\ndef LlamaModel_fast_forward(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    
*args,\n    **kwargs,\n) -> Union[Tuple, BaseModelOutputWithPast]:\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    assert output_attentions is False\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n    use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # retrieve input_ids and inputs_embeds\n    if input_ids is not None and inputs_embeds is not None:\n        raise ValueError(\n            \"Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\"\n        )\n    elif input_ids is not None:\n        batch_size, seq_length = input_ids.shape\n    elif inputs_embeds is not None:\n        batch_size, seq_length, _ = inputs_embeds.shape\n    else:\n        raise ValueError(\n            \"Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds\"\n        )\n\n    seq_length_with_past = seq_length\n\n    # Fix out of bounds tokenization unless we were given packed metadata\n    allow_overlength = getattr(self, \"_unsloth_allow_packed_overlength\", False) or (\n        \"packed_seq_lengths\" in kwargs\n    )\n    if hasattr(self, \"max_seq_length\") and not allow_overlength:\n        if seq_length > self.max_seq_length:\n            shape = input_ids.shape if input_ids is not None else inputs_embeds.shape\n            logger.warning_once(\n                f\"Unsloth: Input IDs of shape {shape} with length {seq_length} > the model's max sequence length of {self.max_seq_length}.\\n\"\n                \"We shall truncate it ourselves. 
It's imperative that you correct this issue first.\"\n            )\n        if input_ids is not None:\n            input_ids = input_ids[:, : self.max_seq_length]\n        elif inputs_embeds is not None:\n            inputs_embeds = inputs_embeds[:, : self.max_seq_length, :]\n        if (\n            attention_mask is not None\n            and attention_mask.shape[-1] > self.max_seq_length\n        ):\n            attention_mask = attention_mask[:, : self.max_seq_length]\n\n    past_key_values_length = 0\n\n    if past_key_values is not None:\n        past_key_values_length = past_key_values[0][0].shape[2]\n        seq_length_with_past = seq_length_with_past + past_key_values_length\n\n    # We already handle KV cache position_ids ourselves.\n    if False:  # (past_key_values_length != 0):\n        position_ids = torch.arange(\n            past_key_values_length,\n            seq_length + past_key_values_length,\n            dtype = torch.int32,\n            device = f\"{DEVICE_TYPE_TORCH}:0\",\n        )\n        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n    elif position_ids is not None:\n        position_ids = position_ids.view(-1, seq_length).to(torch.int32)  # .long()\n    else:\n        position_ids = None\n\n    if position_ids is not None:\n        if position_ids.shape[0] != batch_size:\n            position_ids = position_ids.repeat((batch_size, 1))\n\n    # Embed positions\n    if inputs_embeds is None:\n        inputs_embeds = self.embed_tokens(input_ids)\n\n    inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))\n\n    # Normalized from Gemma\n    IS_GEMMA = self.config.model_type.startswith(\"gemma\")\n    IS_GEMMA2 = self.config.model_type.startswith(\"gemma2\")\n    IS_COHERE = self.config.model_type.startswith(\"cohere\")\n    IS_GRANITE = self.config.model_type.startswith(\"granite\")\n    IS_FALCON_H1 = self.config.model_type.startswith(\"falcon_h1\")\n\n    train_embed_tokens = self.embed_tokens.weight.requires_grad\n\n    if IS_GEMMA:\n        # Match Gemma exactly by casting to bfloat16 / float16\n        # inputs_embeds *= math_sqrt(self.config.hidden_size)\n        # Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32\n        # &  2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32\n        normalizer = torch.tensor(\n            math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype\n        )\n\n        if train_embed_tokens:\n            # Careful we must not do an inplace op!\n            inputs_embeds = inputs_embeds * normalizer\n        else:\n            inputs_requires_grad = inputs_embeds.requires_grad\n            if not inputs_embeds.is_leaf:\n                inputs_embeds = inputs_embeds.detach()\n                inputs_requires_grad = True\n            elif inputs_requires_grad:\n                inputs_embeds.requires_grad_(False)\n            inputs_embeds *= normalizer\n            # inputs_embeds *= math_sqrt(self.config.hidden_size)\n            if inputs_requires_grad:\n                inputs_embeds.requires_grad_(True)\n\n    # Fix up attention mask by setting elements to 0\n    # Specifically for DPO\n    if (\n        getattr(self, \"_has_no_labels\", False) is True\n        and (attention_mask is not None)\n        and (past_key_values is None)\n        and (not train_embed_tokens)\n        and self.training\n    ):\n        # Careful for inference the attention_mask is size (1, kv_seq_len)\n        # Whilst the input_embeds is size (1, 1, 4096)\n        
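# The (bsz, seq_len) mask is reshaped to (bsz, seq_len, 1) below so it broadcasts over the hidden dim and zeroes padded embeddings.\n        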
inputs_requires_grad = inputs_embeds.requires_grad\n        if not inputs_embeds.is_leaf:\n            inputs_embeds = inputs_embeds.detach()\n            inputs_requires_grad = True\n        elif inputs_requires_grad:\n            inputs_embeds.requires_grad_(False)\n        attention_mask = attention_mask[:, : self.max_seq_length]  # Must resize!\n        inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)\n        if inputs_requires_grad:\n            inputs_embeds.requires_grad_(True)\n\n    # Ignore attention_mask\n    if attention_mask is None:\n        padding_mask = None\n    elif self.training:\n        attention_mask = None\n        padding_mask = None\n    else:\n        # if 0 in attention_mask:\n        #     padding_mask = attention_mask\n        # else:\n        padding_mask = None\n\n        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n            attention_mask,\n            (batch_size, seq_length),\n            inputs_embeds,\n            past_key_values_length,\n            sliding_window = getattr(self.config, \"sliding_window\", None),\n        )\n        # Must NOT convert to bool - weirdly this causes stuff to error out!\n        # if attention_mask is not None:\n        #     attention_mask = attention_mask.to(torch.bool)\n\n    hidden_states = inputs_embeds\n    if IS_GRANITE or IS_FALCON_H1:  # granite has embedding multiplier\n        hidden_states = self.config.embedding_multiplier * hidden_states\n\n    if past_key_values is None and self.training:\n        use_cache = False\n        # if use_cache:\n        #     logger.warning_once(\n        #         \"Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`\"\n        #     )\n        #     use_cache = False\n\n    # decoder layers\n    all_hidden_states = () if output_hidden_states else None\n    all_self_attns = () if output_attentions else None\n    next_decoder_cache = () if use_cache else None\n\n    # Gradient checkpointing methods (ie sqrt)\n    if hasattr(self, \"_gradient_checkpointing_boundaries\"):\n        boundaries = self._gradient_checkpointing_boundaries\n    else:\n        boundaries = None\n\n    # Check checkpointing method\n    gradient_checkpointing = False\n\n    if self.gradient_checkpointing and self.training and not use_cache:\n        gradient_checkpointing = True\n\n    # Gemma2 has alternating SWA and global attn\n    use_static_mask = True\n    dynamic_SWA_mask = None\n    dynamic_GA_mask = None\n    if IS_GEMMA2:\n        if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:\n            self.SWA_mask = True\n            self.GA_mask = False\n        elif attention_mask is not None:\n            # Fixes https://github.com/unslothai/unsloth/issues/853\n            # Unsloth needs a 2D mask, not a [2, 1, n, n] mask!\n\n            # https://github.com/pytorch/pytorch/issues/103749\n            # Need to convert to float and not using bool\n            # attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min\n            dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (batch_size, seq_length),\n                inputs_embeds,\n                past_key_values_length,\n                sliding_window = self.config.sliding_window,\n            )\n            dynamic_GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (batch_size, seq_length),\n      
          inputs_embeds,\n                past_key_values_length,\n                sliding_window = None,\n            )\n            use_static_mask = False\n\n        elif not hasattr(self, \"SWA_mask\"):\n            if HAS_FLEX_ATTENTION:\n                # Use Flex Attention instead!\n                self.SWA_mask = create_flex_attention_sliding_window_mask(\n                    self.max_seq_length, self.config.sliding_window\n                )\n                self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)\n            else:\n                n = self.max_seq_length  # self.config.max_position_embeddings\n                # masked_fill is making stuff slower!\n                # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)\n                # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)\n                from transformers.modeling_attn_mask_utils import AttentionMaskConverter\n\n                self.SWA_mask = (\n                    AttentionMaskConverter(\n                        is_causal = True,\n                        sliding_window = self.config.sliding_window,\n                    )\n                    .to_causal_4d(\n                        1,\n                        n,\n                        n,\n                        dtype = inputs_embeds.dtype,\n                        device = DEVICE_TYPE_TORCH,\n                    )\n                    .squeeze(0)\n                    .squeeze(0)\n                )\n\n                self.GA_mask = (\n                    AttentionMaskConverter(\n                        is_causal = True,\n                    )\n                    .to_causal_4d(\n                        1,\n                        n,\n                        n,\n                        dtype = inputs_embeds.dtype,\n                        device = DEVICE_TYPE_TORCH,\n                    )\n                    .squeeze(0)\n                    .squeeze(0)\n                )\n            pass\n\n    if (\n        IS_ATTENTION_REFACTOR\n        and (\n            hasattr(self, \"rotary_emb\")\n            or not hasattr(self.layers[0].self_attn, \"rotary_emb\")\n        )\n    ) or IS_GRANITE:\n        # Transformers main has made it mandatory to pass position_embeddings\n        # https://github.com/huggingface/transformers/pull/34858\n        # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor)\n        # unsloth's check for granite too has \"version >= 4.45.0 (rightly so)\".\n        # so let granite always use the attention refactor implementation.\n\n        self.rotary_emb.extend_rope_embedding(\n            hidden_states, self.config.max_position_embeddings\n        )\n        position_embeddings = self.rotary_emb.get_cached(\n            self.config.max_position_embeddings, hidden_states.device.index\n        )\n    else:\n        position_embeddings = None\n\n    # Go through every layer!\n    for idx, decoder_layer in enumerate(self.layers):\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n        past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n        mask = causal_mask\n        if IS_GEMMA2:\n            use_sliding_window = idx % 2 == 0\n            if use_sliding_window:\n                mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask\n            else:\n                mask = self.GA_mask if use_static_mask else dynamic_GA_mask\n            
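# Gemma-2 alternates layer types: even-indexed layers use the sliding-window mask, odd layers the global causal mask.\n            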
kwargs[\"use_sliding_window\"] = use_sliding_window\n\n        if gradient_checkpointing and not isinstance(\n            decoder_layer, GradientCheckpointingLayer\n        ):\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    return module(\n                        *inputs,\n                        past_key_value,\n                        output_attentions,\n                        padding_mask = padding_mask,\n                        position_embeddings = position_embeddings,\n                        **kwargs,\n                    )\n\n                return custom_forward\n\n            layer_outputs = torch.utils.checkpoint.checkpoint(\n                create_custom_forward(decoder_layer),\n                hidden_states,\n                mask,\n                attention_mask,\n                position_ids,\n                use_reentrant = True,\n                preserve_rng_state = False,\n            )\n            hidden_states = layer_outputs[0]\n\n        else:\n            layer_outputs = decoder_layer(\n                hidden_states,\n                causal_mask = mask,\n                attention_mask = attention_mask,\n                position_ids = position_ids,\n                past_key_value = past_key_value,\n                output_attentions = output_attentions,\n                use_cache = use_cache,\n                padding_mask = padding_mask,\n                position_embeddings = position_embeddings,\n                **kwargs,\n            )\n            hidden_states = layer_outputs[0]\n\n        if use_cache:\n            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n        if output_attentions:\n            all_self_attns += (layer_outputs[1],)\n\n    # Final layernorm\n    if use_cache:\n        if IS_FALCON_H1:\n            hidden_states = fast_rms_layernorm_inference(\n                self.final_layernorm, hidden_states\n            )\n        else:\n            hidden_states = (\n                fast_rms_layernorm_inference_gemma\n                if IS_GEMMA\n                else fast_rms_layernorm_inference\n            )(self.norm, hidden_states)\n    elif IS_COHERE:\n        hidden_states = self.norm(hidden_states)\n    elif IS_FALCON_H1:\n        hidden_states = fast_rms_layernorm(\n            self.final_layernorm, hidden_states, gemma = IS_GEMMA\n        )\n    else:\n        hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)\n\n    if output_hidden_states:\n        all_hidden_states += (hidden_states,)\n    next_cache = next_decoder_cache if use_cache else None\n\n    if not return_dict:\n        return tuple(\n            v\n            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n            if v is not None\n        )\n    return BaseModelOutputWithPast(\n        last_hidden_state = hidden_states,\n        past_key_values = next_cache,\n        hidden_states = all_hidden_states,\n        attentions = all_self_attns,\n    )\n\n\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825\ndef _LlamaModel_fast_forward_inference(\n    attention_fast_forward_inference = LlamaAttention_fast_forward_inference,\n    mlp_fast_forward_inference = fast_swiglu_inference,\n):\n    # This makes the attention and MLP customisable.\n    # Now for models like qwen3 or cohere which use custom attention operations, we can use this function\n    def 
LlamaModel_fast_forward_inference_custom(\n        self,\n        input_ids,\n        past_key_values,\n        position_ids,\n        attention_mask = None,\n        **kwargs,\n    ):\n        input_ids = input_ids[:, : self.max_seq_length]\n        bsz, q_len = input_ids.shape\n        hd = self.config.hidden_size\n        mlp_size = self.config.intermediate_size\n\n        X = self.model.embed_tokens(input_ids)\n        X = X.to(_get_dtype(dtype_from_config(self.config)))\n        bsz, q_len, hd = X.shape\n        assert q_len == 1\n        # Get saved buffers to reduce memory movement\n        residual = torch.empty(\n            (bsz, q_len, hd), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        _XX = torch.empty(\n            (2, bsz, q_len, hd), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        XX, XX2 = _XX[0], _XX[1]\n        variance = torch.empty(\n            (bsz, q_len, 1), dtype = torch.float32, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        temp_mlp = torch.empty(\n            (2, bsz, 1, mlp_size), dtype = X.dtype, device = f\"{DEVICE_TYPE_TORCH}:0\"\n        )\n        temp_gates, temp_ups = (\n            tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)),\n            tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)),\n        )\n\n        seq_len = past_key_values[0][0].shape[-2]\n        kv_seq_len = seq_len + 1\n        if attention_mask is not None:\n            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (bsz, q_len),\n                X,\n                seq_len,\n                sliding_window = getattr(self.config, \"sliding_window\", None),\n            )\n            # Pre-convert to bool once for all layers (avoids per-layer .eq(0))\n            if attention_mask is not None and attention_mask.dtype != torch.bool:\n                attention_mask = attention_mask.eq(0)\n        else:\n            attention_mask = None\n\n        # Compute rotary_seq_len once to avoid per-layer GPU-CPU sync from .item()\n        rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)\n\n        next_decoder_cache = []\n\n        for idx, decoder_layer in enumerate(self.model.layers):\n            device_index = getattr(decoder_layer, \"_per_layer_device_index\", 0)\n            X, residual, position_ids = move_to_device(\n                device_index, X, residual, position_ids\n            )\n            residual.copy_(X)  # residual = X\n            X = fast_rms_layernorm_inference(\n                decoder_layer.input_layernorm,\n                X,\n                XX = XX,\n                XX2 = XX2,\n                variance = variance,\n            )\n            X, present_key_value = attention_fast_forward_inference(\n                decoder_layer.self_attn,\n                hidden_states = X,\n                past_key_value = past_key_values[idx],\n                position_ids = position_ids,\n                attention_mask = attention_mask,\n                do_prefill = not hasattr(decoder_layer.self_attn, \"paged_attention\"),\n                rotary_seq_len = rotary_seq_len,\n            )\n            X += residual\n\n            residual.copy_(X)  # residual = X\n            X = fast_rms_layernorm_inference(\n                decoder_layer.post_attention_layernorm,\n                X,\n                XX = XX,\n                XX2 = XX2,\n                variance = variance,\n            
)\n            X = mlp_fast_forward_inference(\n                decoder_layer.mlp,\n                X,\n                temp_gate = temp_gates[device_index],\n                temp_up = temp_ups[device_index],\n            )\n            X += residual\n\n            next_decoder_cache.append(present_key_value)\n        X = fast_rms_layernorm_inference(\n            self.model.norm,\n            X,\n            XX = XX,\n            XX2 = XX2,\n            variance = variance,\n        )\n\n        return BaseModelOutputWithPast(\n            last_hidden_state = X,\n            past_key_values = next_decoder_cache,\n            hidden_states = [],\n            attentions = [],\n        )\n\n    return LlamaModel_fast_forward_inference_custom\n\n\n# For ensuring backwards compatibility, we create LlamaModel_fast_forward_inference that is consumed by other models\nLlamaModel_fast_forward_inference = _LlamaModel_fast_forward_inference()\n\n\ndef CausalLM_fast_forward(fast_forward_inference):\n    def _CausalLM_fast_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        causal_mask: Optional[BlockDiagonalCausalMask] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        num_logits_to_keep: Optional[int] = 0,\n        logits_to_keep: Optional[int] = 0,\n        *args,\n        **kwargs,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        if past_key_values is not None:\n            outputs = fast_forward_inference(\n                self,\n                input_ids,\n                past_key_values,\n                position_ids = position_ids,\n                attention_mask = attention_mask,\n                **kwargs,\n            )\n        else:\n            causal_mask = (\n                xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None\n            )\n\n            output_attentions = (\n                output_attentions\n                if output_attentions is not None\n                else self.config.output_attentions\n            )\n            output_hidden_states = (\n                output_hidden_states\n                if output_hidden_states is not None\n                else self.config.output_hidden_states\n            )\n            return_dict = (\n                return_dict if return_dict is not None else self.config.use_return_dict\n            )\n            # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n            self.model._has_no_labels = labels is None\n            outputs = self.model(\n                input_ids = input_ids,\n                causal_mask = causal_mask,\n                attention_mask = attention_mask,\n                position_ids = position_ids,\n                past_key_values = past_key_values,\n                inputs_embeds = inputs_embeds,\n                use_cache = use_cache,\n                output_attentions = output_attentions,\n                output_hidden_states = output_hidden_states,\n                return_dict = return_dict,\n                **kwargs,\n            )\n        hidden_states = outputs[0]\n\n        
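# Paths below: single-token decode (bsz == 1, q_len == 1) uses a fused mat-vec for the LM head;\n        # training with labels can skip materialising logits via the fused cross-entropy loss; and\n        # softcapping (if set) applies logits = t * tanh(logits / t) with t = final_logit_softcapping.\n        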
bsz, q_len, hd = hidden_states.shape\n        lm_head = self.lm_head.weight\n        lm_head_device = lm_head.device\n\n        logit_softcapping = getattr(self.config, \"final_logit_softcapping\", 0)\n        logit_scaling = getattr(self.config, \"logit_scale\", 0)\n        dtype = lm_head.dtype\n        num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)\n\n        # Move items to same device as lm_head\n        hidden_states = hidden_states.to(lm_head_device)\n        if labels is not None:\n            labels = labels.to(lm_head_device)\n\n        # Output last hidden states without logits if asked\n        if os.environ.get(\"UNSLOTH_RETURN_HIDDEN_STATES\", \"0\") == \"1\":\n            if num_logits_to_keep != 0:\n                hidden_states = hidden_states[:, -num_logits_to_keep:, :]\n            return CausalLMOutputWithPast(\n                loss = None,\n                logits = hidden_states,\n                past_key_values = outputs.past_key_values,\n                hidden_states = outputs.hidden_states,\n                attentions = outputs.attentions,\n            )\n\n        if bsz == 1 and q_len == 1:\n            logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))\n            logits = logits.unsqueeze(0).unsqueeze(0)\n        elif num_logits_to_keep != 0:\n            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(dtype))\n        else:\n            RETURN_LOGITS = os.environ.get(\"UNSLOTH_RETURN_LOGITS\", \"0\") == \"1\"\n            # < 1024 Normal Unsloth uses less VRAM!\n            if bsz * q_len <= 1024 and not RETURN_LOGITS:\n                # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage\n                RETURN_LOGITS = False\n\n            if not RETURN_LOGITS and labels is not None:\n                n_items = kwargs.get(\"num_items_in_batch\", None)\n                if n_items is None:\n                    n_items = kwargs.get(\"n_items\", None)\n\n                if self.config.model_type == \"falcon_h1\":\n                    hidden_states = hidden_states * self.config.lm_head_multiplier\n\n                ### DISABLED since T4 breaks\n                # OutOfResources: out of resource: shared memory, Required: 98304, Hardware limit: 65536. 
Reducing block sizes or `num_stages` may help.\n                # loss = fused_linear_cross_entropy(\n                #     hidden_states      = hidden_states,\n                #     lm_weight          = lm_head,\n                #     labels             = labels,\n                #     num_items_in_batch = n_items,\n                #     logit_softcapping  = logit_softcapping,\n                # )\n                loss = unsloth_fused_ce_loss(\n                    trainer = None,\n                    hidden_states = hidden_states,\n                    lm_head_weight = lm_head,\n                    lm_head_bias = None,\n                    labels = labels,\n                    mask = None,\n                    n_items = n_items,\n                    scaling = getattr(self, \"accelerator_scaler\", None),\n                    target_gb = None,\n                    torch_compile = True,\n                    logit_softcapping = logit_softcapping,\n                )\n                if not return_dict:\n                    output = (logits,) + outputs[1:]\n                    return (loss,) + output if loss is not None else output\n\n                output = CausalLMOutputWithPast(\n                    loss = loss,\n                    logits = EMPTY_LOGITS,\n                    past_key_values = outputs.past_key_values,\n                    hidden_states = outputs.hidden_states,\n                    attentions = outputs.attentions,\n                )\n                return output\n            pass\n            logits = self.lm_head(hidden_states.to(dtype))\n\n        logits = logits.to(_get_dtype(dtype_from_config(self.config)))\n        loss = None\n        logit_softcapping = getattr(self.config, \"final_logit_softcapping\", 0)\n        logit_scaling = getattr(self.config, \"logit_scale\", 0)\n        if self.config.model_type == \"granite\":\n            # granite uses logit_scaling as key and they divide by the scale unlike cohere\n            # notice that for granite, logits_scale is 16 and for cohere it is 0.125 (aka 1/8) in their respective configs\n            # granite: https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/granite/modeling_granite.py#L1103\n            # cohere: https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/cohere/modeling_cohere.py#L1176\n            logit_scaling = 1 / getattr(self.config, \"logits_scaling\", 1)\n        elif self.config.model_type == \"falcon_h1\":\n            logit_scaling = self.config.lm_head_multiplier\n\n        if labels is not None:\n            shift_logits = logits\n            # if not hasattr(self, \"extra_ignored_labels\"):\n            #     # Fixes https://github.com/unslothai/unsloth/issues/10\n            #     self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = \"cuda:0\")\n            # pass\n            shift_labels = torch.empty_like(labels)\n            shift_labels[..., :-1] = labels[..., 1:]\n            shift_labels[..., -1] = -100\n            mask_packed_sequence_boundaries(\n                shift_labels,\n                kwargs.get(\"packed_seq_lengths\"),\n            )\n            # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))\n            n_items = kwargs.get(\"num_items_in_batch\", None)\n            if n_items is None:\n                n_items = kwargs.get(\"n_items\", None)\n            loss = 
fast_cross_entropy_loss(\n                logits = shift_logits,\n                labels = shift_labels,\n                logit_softcapping = logit_softcapping,\n                logit_scaling = logit_scaling,\n                n_items = n_items,\n            )\n        else:\n            if logit_scaling != 0:\n                if logits.requires_grad:\n                    logits = logit_scaling * logits\n                else:\n                    logits *= logit_scaling\n            if logit_softcapping != 0:\n                if logits.requires_grad:\n                    logits = (1.0 / logit_softcapping) * logits\n                    logits = torch.tanh(logits)\n                    logits = logit_softcapping * logits\n                else:\n                    logits *= 1.0 / logit_softcapping\n                    logits.tanh_()\n                    logits *= logit_softcapping\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n        return CausalLMOutputWithPast(\n            loss = loss,\n            logits = logits,\n            past_key_values = outputs.past_key_values,\n            hidden_states = outputs.hidden_states,\n            attentions = outputs.attentions,\n        )\n\n    return _CausalLM_fast_forward\n\n\n@torch._disable_dynamo\ndef PeftModel_fast_forward(\n    self,\n    input_ids = None,\n    causal_mask = None,\n    attention_mask = None,\n    inputs_embeds = None,\n    labels = None,\n    output_attentions = None,\n    output_hidden_states = None,\n    return_dict = None,\n    task_ids = None,\n    num_logits_to_keep = 0,\n    logits_to_keep = 0,\n    **kwargs,\n):\n    is_classification = \"Classification\" in str(type(self.base_model.model))\n    if is_classification:\n        return self.base_model(\n            input_ids = input_ids,\n            attention_mask = attention_mask,\n            inputs_embeds = inputs_embeds,\n            labels = labels,\n            output_attentions = output_attentions,\n            output_hidden_states = output_hidden_states,\n            return_dict = return_dict,\n            **kwargs,\n        )\n    else:\n        return self.base_model(\n            input_ids = input_ids,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            inputs_embeds = inputs_embeds,\n            labels = labels,\n            output_attentions = output_attentions,\n            output_hidden_states = output_hidden_states,\n            return_dict = return_dict,\n            num_logits_to_keep = num_logits_to_keep,\n            logits_to_keep = logits_to_keep,\n            **kwargs,\n        )\n\n\ndef _get_rope_theta(config, default = 10000.0):\n    \"\"\"Get rope_theta from config, handling both transformers 4.x and 5.x.\"\"\"\n    try:\n        return config.rope_theta\n    except (AttributeError, KeyError):\n        pass\n    rp = getattr(config, \"rope_parameters\", None)\n    if isinstance(rp, dict):\n        return rp.get(\"rope_theta\", default)\n    return default\n\n\n# Solves https://github.com/unslothai/unsloth/issues/168\n# Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n# https://github.com/huggingface/transformers/pull/27931\n# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\nclass LlamaRotaryEmbedding(torch.nn.Module):\n    # Fixes 
https://github.com/huggingface/transformers/pull/28837\n    # https://github.com/microsoft/DeepSpeed/issues/4932\n    # The precision of RoPE buffers is not correct, so we cast to int64.\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 2048,\n        base = 10000,\n        device = None,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        super().__init__()\n        if config is not None:\n            # [TODO] Hack to pass in config - need to remove later\n            base = _get_rope_theta(config, default = base)\n            partial_rotary_factor = (\n                config.partial_rotary_factor\n                if hasattr(config, \"partial_rotary_factor\")\n                else 1.0\n            )\n            dim = getattr(config, \"head_dim\", None)\n            if dim is None:\n                dim = int((config.hidden_size // config.num_attention_heads))\n            device = DEVICE_TYPE_TORCH\n            max_position_embeddings = config.max_position_embeddings\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this\n        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)\n        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT\n        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT\n\n        # Normal Llama-3 RoPE\n        inv_freq = 1.0 / (\n            self.base\n            ** (\n                torch.arange(0, self.dim, 2, dtype = torch.int64, device = \"cpu\").float()\n                / self.dim\n            )\n        )\n        inv_freq = self._apply_inv_freq_scaling(inv_freq)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent = False)\n\n        # Build here to make `torch.jit.trace` work.\n        for device_idx in range(DEVICE_COUNT):\n            self._set_cos_sin_cache(\n                seq_len = self.current_rope_size,\n                device = torch.device(device_idx),\n                dtype = torch.get_default_dtype(),\n            )\n\n        # dummy so that patch_utils doesn't fail for now\n        self.cos_cached = torch.empty(\n            1, device = get_current_device(), dtype = torch.get_default_dtype()\n        )\n        self.sin_cached = torch.empty(\n            1, device = get_current_device(), dtype = torch.get_default_dtype()\n        )\n\n    def _apply_inv_freq_scaling(self, inv_freq):\n        \"\"\"Override to apply custom inv_freq scaling (e.g., extended RoPE).\"\"\"\n        return inv_freq\n\n    def _apply_time_scaling(self, t):\n        \"\"\"Override to apply custom time scaling (e.g., linear scaling).\"\"\"\n        return t\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and\n        # in FP32. 
They are applied (multiplied) in FP32 as well.\n        self.current_rope_size = seq_len\n        t = torch.arange(\n            self.current_rope_size, device = self.inv_freq.device, dtype = torch.int64\n        ).float()\n        t = self._apply_time_scaling(t)\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim = -1)\n        cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)\n        sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)\n        self.multi_gpu_cos_cached[device.index] = cos\n        self.multi_gpu_sin_cached[device.index] = sin\n        return cos, sin\n\n    def forward(self, x, position_ids = None, seq_len = None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len is not None and seq_len > self.current_rope_size:\n            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)\n\n        device_index = x.device.index\n        return (\n            self.multi_gpu_cos_cached[device_index][:seq_len],\n            self.multi_gpu_sin_cached[device_index][:seq_len],\n        )\n\n    def get_cached(self, seq_len = None, device_index = None):\n        if device_index is None:\n            device_index = get_current_device()\n        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[\n            device_index\n        ]\n\n    def extend_rope_embedding(self, x, seq_len):\n        if seq_len <= self.current_rope_size:\n            return\n        # Iteratively grow by increments of 8192\n        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192\n        for device_idx in range(DEVICE_COUNT):\n            self._set_cos_sin_cache(\n                self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype\n            )\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    # Fixes https://github.com/huggingface/transformers/pull/28837\n    # https://github.com/microsoft/DeepSpeed/issues/4932\n    # The precision of RoPE buffers is not correct, so we cast to int64.\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 2048,\n        base = 10000,\n        device = None,\n        scaling_factor = 1.0,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(\n            dim = dim,\n            max_position_embeddings = max_position_embeddings,\n            base = base,\n            device = device,\n            config = config,\n        )\n\n    def _apply_time_scaling(self, t):\n        \"\"\"Apply linear scaling to time indices.\"\"\"\n        return t / self.scaling_factor\n\n\n# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736\n# For Llama 3.1\nclass LlamaExtendedRotaryEmbedding(LlamaRotaryEmbedding):\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 2048,\n        base = 10000,\n        device = None,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        super().__init__(\n            dim = dim,\n            max_position_embeddings = max_position_embeddings,\n            base = base,\n            device = device,\n            config = config,\n        )\n\n    # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41\n    def _apply_inv_freq_scaling(self, freqs: torch.Tensor):\n        # Values obtained from grid search\n        scale_factor = 8\n        low_freq_factor = 1\n        high_freq_factor = 4\n        old_context_len = 8192  # original llama3 length\n\n        low_freq_wavelen = old_context_len / low_freq_factor\n        high_freq_wavelen = old_context_len / high_freq_factor\n        new_freqs = []\n        for freq in freqs:\n            wavelen = 2 * math.pi / freq\n            if wavelen < high_freq_wavelen:\n                new_freqs.append(freq)\n            elif wavelen > low_freq_wavelen:\n                new_freqs.append(freq / scale_factor)\n            else:\n                assert low_freq_wavelen != high_freq_wavelen\n                smooth = (old_context_len / wavelen - low_freq_factor) / (\n                    high_freq_factor - low_freq_factor\n                )\n                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)\n        return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)\n\n\nclass LongRopeRotaryEmbedding(torch.nn.Module):\n    # For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py\n    def __init__(\n        self,\n        dim = None,\n        max_position_embeddings = 131072,\n        original_max_position_embeddings = 4096,\n        base = 10000,\n        short_factor = None,\n        long_factor = None,\n        device = None,\n        config = None,  # [TODO] Hack to pass in config - need to remove later\n    ):\n        super().__init__()\n        assert short_factor is not None\n        assert long_factor is not None\n        assert type(original_max_position_embeddings) is int\n\n        if config is not None:\n            # [TODO] Hack to pass in config - need to remove later\n            base = _get_rope_theta(config, default = base)\n            
partial_rotary_factor = (\n                config.partial_rotary_factor\n                if hasattr(config, \"partial_rotary_factor\")\n                else 1.0\n            )\n            dim = int((config.hidden_size // config.num_attention_heads))\n            device = DEVICE_TYPE_TORCH\n            max_position_embeddings = config.max_position_embeddings\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.base = base\n        # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this\n        self.current_rope_size = min(\n            original_max_position_embeddings, self.max_position_embeddings\n        )\n        self.multi_gpu_short_cos_cached = [None] * DEVICE_COUNT\n        self.multi_gpu_short_sin_cached = [None] * DEVICE_COUNT\n        self.multi_gpu_long_cos_cached = [None] * DEVICE_COUNT\n        self.multi_gpu_long_sin_cached = [None] * DEVICE_COUNT\n\n        # Long RoPE similar to RoPE except short sequences have 1 cos / sin\n        # and long sequences have another cos / sin\n        inv_freq_shape = (\n            torch.arange(0, self.dim, 2, dtype = torch.int64, device = \"cpu\").float()\n            / self.dim\n        )\n        short_factor = torch.tensor(short_factor, device = \"cpu\", dtype = torch.float32)\n        long_factor = torch.tensor(long_factor, device = \"cpu\", dtype = torch.float32)\n        short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)\n        long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)\n\n        # Phi-3 Scale factor\n        scale = self.max_position_embeddings / self.original_max_position_embeddings\n        if scale <= 1.0:\n            scaling_factor = 1.0\n        else:\n            scaling_factor = math.sqrt(\n                1 + math.log(scale) / math.log(self.original_max_position_embeddings)\n            )\n        self.scaling_factor = scaling_factor\n\n        # Short and long inv_freq\n        self.register_buffer(\"short_inv_freq\", short_inv_freq, persistent = False)\n        self.register_buffer(\"long_inv_freq\", long_inv_freq, persistent = False)\n\n        # Build here to make `torch.jit.trace` work.\n        # Initialize short sequences cache for all devices\n        dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16\n        t = torch.arange(\n            original_max_position_embeddings,\n            device = self.short_inv_freq.device,\n            dtype = torch.int64,\n        ).float()\n        freqs = torch.outer(t, self.short_inv_freq)\n        emb = torch.cat((freqs, freqs), dim = -1)\n\n        for device_idx in range(DEVICE_COUNT):\n            device_obj = torch.device(device_idx)\n            cos_cached = (emb.cos() * self.scaling_factor).to(\n                dtype = dtype, device = device_obj, non_blocking = True\n            )\n            sin_cached = (emb.sin() * self.scaling_factor).to(\n                dtype = dtype, device = device_obj, non_blocking = True\n            )\n            self.multi_gpu_short_cos_cached[device_idx] = cos_cached\n            self.multi_gpu_short_sin_cached[device_idx] = sin_cached\n\n        # dummy so that patch_utils doesn't fail for now\n        self.short_cos_cached = torch.empty(\n            1, device = get_current_device(), dtype = torch.get_default_dtype()\n        )\n        self.short_sin_cached = torch.empty(\n            1, device = get_current_device(), 
dtype = torch.get_default_dtype()\n        )\n        self.long_cos_cached = torch.empty(\n            1, device = get_current_device(), dtype = torch.get_default_dtype()\n        )\n        self.long_sin_cached = torch.empty(\n            1, device = get_current_device(), dtype = torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and\n        # in FP32. They are applied (multiplied) in FP32 as well.\n        self.current_rope_size = seq_len\n\n        t = torch.arange(\n            self.current_rope_size, device = self.long_inv_freq.device, dtype = torch.int64\n        ).float()\n        # Long sequences\n        freqs = torch.outer(t, self.long_inv_freq)\n        emb = torch.cat((freqs, freqs), dim = -1)\n        cos_cached = (emb.cos() * self.scaling_factor).to(\n            dtype = dtype, device = device, non_blocking = True\n        )\n        sin_cached = (emb.sin() * self.scaling_factor).to(\n            dtype = dtype, device = device, non_blocking = True\n        )\n        self.multi_gpu_long_cos_cached[device.index] = cos_cached\n        self.multi_gpu_long_sin_cached[device.index] = sin_cached\n        return cos_cached, sin_cached\n\n    def forward(self, x, position_ids = None, seq_len = None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len is not None and seq_len > self.current_rope_size:\n            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)\n\n        device_index = x.device.index\n\n        if seq_len is not None and seq_len < self.original_max_position_embeddings:\n            return (\n                self.multi_gpu_short_cos_cached[device_index][:seq_len],\n                self.multi_gpu_short_sin_cached[device_index][:seq_len],\n            )\n        else:\n            return (\n                self.multi_gpu_long_cos_cached[device_index][:seq_len],\n                self.multi_gpu_long_sin_cached[device_index][:seq_len],\n            )\n\n    def get_cached(self, seq_len = None, device_index = None):\n        if device_index is None:\n            device_index = get_current_device()\n        if seq_len is not None and seq_len < self.original_max_position_embeddings:\n            return self.multi_gpu_short_cos_cached[\n                device_index\n            ], self.multi_gpu_short_sin_cached[device_index]\n        return self.multi_gpu_long_cos_cached[\n            device_index\n        ], self.multi_gpu_long_sin_cached[device_index]\n\n    def extend_rope_embedding(self, x, seq_len):\n        if seq_len <= self.current_rope_size:\n            return\n        # Iteratively grow by increments of 8192\n        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192\n        for device_idx in range(DEVICE_COUNT):\n            self._set_cos_sin_cache(\n                self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype\n            )\n\n\ndef unsloth_fast_generate(\n    self,\n    *args,\n    **kwargs,\n):\n    # If the model starts out in training mode, restore training mode after generation\n    restore_training_mode = self.training\n\n    FastLlamaModel.for_inference(self)\n\n    # Unpack BatchEncoding passed as input_ids for backwards compatibility.\n    # Old notebooks do model.generate(input_ids=tokenizer(...)) where the tokenizer\n    # output is a BatchEncoding (dict-like). 
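For example,\n    # tokenizer(\"Hi\", return_tensors = \"pt\") returns {\"input_ids\": tensor([[...]]),\n    # \"attention_mask\": tensor([[...]])} rather than a plain tensor. 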
Transformers v5 generate() calls\n    # .shape on it directly and crashes. Unpack into separate kwargs so both\n    # v4 and v5 work transparently.\n    _maybe_encoding = kwargs.get(\"input_ids\", None)\n    if (\n        _maybe_encoding is not None\n        and not isinstance(_maybe_encoding, torch.Tensor)\n        and hasattr(_maybe_encoding, \"items\")\n    ):\n        batch_data = kwargs.pop(\"input_ids\")\n        for key, val in batch_data.items():\n            kwargs.setdefault(key, val)\n\n    dtype = _get_dtype(dtype_from_config(self.config))\n\n    if hasattr(self, \"config\") and hasattr(self.config, \"max_position_embeddings\"):\n        if (\n            \"input_ids\" in kwargs\n            and kwargs[\"input_ids\"] is not None\n            and \"max_new_tokens\" in kwargs\n        ):\n            _ids = kwargs[\"input_ids\"]\n            if hasattr(_ids, \"shape\") and (\n                _ids.shape[-1] + kwargs[\"max_new_tokens\"]\n                > self.config.max_position_embeddings\n            ):\n                raise ValueError(\n                    f\"Unsloth: input length {_ids.shape[-1]} + max_new_tokens {kwargs['max_new_tokens']} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\\n\"\n                    \"You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.\"\n                )\n\n    # Must patch accelerate for Xformers\n    # if accelerate_new_send_to_device is not None:\n    #     import accelerate.utils.operations\n    #     accelerate.utils.operations.send_to_device = accelerate_new_send_to_device\n    # pass\n\n    # For newer HF\n    kwargs[\"cache_implementation\"] = \"dynamic\"\n    # For num_logits_to_keep\n    num_logits_to_keep = kwargs.get(\"num_logits_to_keep\", None)\n    logits_to_keep = kwargs.get(\"logits_to_keep\", None)\n    if num_logits_to_keep is None and logits_to_keep is None:\n        kwargs[\"num_logits_to_keep\"] = 1\n\n    # Remove token_type_ids\n    kwargs.pop(\"token_type_ids\", None)\n\n    # Check pad_token\n    model_eos_token_id = getattr(self.config, \"eos_token_id\", None)\n    if model_eos_token_id is not None and hasattr(model_eos_token_id, \"__iter__\"):\n        model_eos_token_id = model_eos_token_id[0]\n\n    kwargs[\"pad_token_id\"] = kwargs.pop(\"pad_token_id\", model_eos_token_id)\n\n    # Mixed precision autocast\n    with (\n        _get_inference_mode_context_manager(self),\n        torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),\n    ):\n        output = self._old_generate(*args, **kwargs)\n\n    # Return accelerate back\n    # if accelerate_new_send_to_device is not None:\n    #     accelerate.utils.operations.send_to_device = accelerate_old_send_to_device\n    # pass\n\n    if restore_training_mode:\n        FastLlamaModel.for_training(self)\n\n    return output\n\n\nclass FastLlamaModel:\n    @staticmethod\n    def _prepare_for_qat(model, qat_scheme):\n        model = _prepare_model_for_qat(model, qat_scheme)\n        return model\n\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_llama_rope_scaling(\n            model_name = \"llama\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            extended_rope_module = LlamaExtendedRotaryEmbedding,\n            attention_module = LlamaAttention,\n            longrope_module = LongRopeRotaryEmbedding,\n        )\n        if init_name is not None:\n            
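# The helper returns the patched __init__ as source code: exec it into module\n            # globals, then bind it to LlamaAttention via eval(init_name) below.\n            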
exec(function, globals())\n            LlamaAttention.__init__ = eval(init_name)\n        LlamaAttention.forward = LlamaAttention_fast_forward\n        LlamaSdpaAttention.forward = LlamaAttention_fast_forward\n        LlamaFlashAttention2.forward = LlamaAttention_fast_forward\n        LlamaDecoderLayer.forward = LlamaDecoderLayer_fast_forward\n        LlamaModel.forward = LlamaModel_fast_forward\n        LlamaForCausalLM.forward = CausalLM_fast_forward(\n            LlamaModel_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(LlamaForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.llama.modeling_llama\n\n        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = (\n            LlamaLinearScalingRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/llama-3-8b-bnb-4bit\",\n        max_seq_length = None,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        revision = None,\n        fast_inference = False,  # uses vLLM\n        gpu_memory_utilization = 0.5,\n        float8_kv_cache = False,\n        random_state = 3407,\n        max_lora_rank = 16,\n        disable_log_stats = False,\n        unsloth_vllm_standby = False,\n        num_labels = None,\n        qat_scheme = None,\n        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')\n        **kwargs,\n    ):\n        os.environ[\"UNSLOTH_USE_NEW_MODEL\"] = \"0\"\n        if trust_remote_code:\n            if fast_inference:\n                raise NotImplementedError(\n                    \"Unsloth: Fast inference does not support `trust_remote_code` yet.\"\n                )\n            print(\n                \"Unsloth: WARNING `trust_remote_code` is True.\\n\"\n                \"Are you certain you want to do remote code execution?\"\n            )\n        if fast_inference:\n            if not is_vLLM_available():\n                print(\"Unsloth: vLLM is not installed! 
Will use Unsloth inference!\")\n                fast_inference = False\n            if DEVICE_TYPE == \"cuda\":\n                major_version, minor_version = torch.cuda.get_device_capability()\n                if major_version < 7:\n                    print(\n                        \"Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!\"\n                    )\n                    fast_inference = False\n            elif DEVICE_TYPE == \"hip\":\n                fast_inference = True\n            if (\n                unsloth_vllm_standby\n                and os.environ.get(\"UNSLOTH_VLLM_STANDBY\", \"0\") == \"0\"\n            ):\n                raise RuntimeError(\n                    \"Unsloth: `unsloth_vllm_standby` is True, but  environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!\"\n                )\n\n        token = hf_login(token)\n        if model_patcher is None:\n            model_patcher = FastLlamaModel\n        SUPPORTS_BFLOAT16 = is_bfloat16_supported()\n\n        if DEVICE_TYPE == \"cuda\":\n            gpu_stats = torch.cuda.get_device_properties(0)\n            gpu_stats_name = (\n                gpu_stats.name + \". \" if gpu_stats.name != \"\" else \"NVIDIA GPU Device. \"\n            )\n            gpu_version = torch.version.cuda\n            gpu_stats_snippet = f\"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}.\"\n            try:\n                vllm_version = f\" vLLM: {importlib_version('vllm')}.\"\n            except:\n                vllm_version = \"\"\n        elif DEVICE_TYPE == \"hip\":\n            gpu_stats = torch.cuda.get_device_properties(0)\n            gpu_stats_name = resolve_hip_gpu_stats_name(gpu_stats)\n            gpu_version = torch.version.hip\n            gpu_stats_snippet = f\"ROCm Toolkit: {gpu_version}.\"\n            try:\n                vllm_version = f\" vLLM: {importlib_version('vllm')}.\"\n            except:\n                vllm_version = \"\"\n        elif DEVICE_TYPE == \"xpu\":\n            gpu_stats = torch.xpu.get_device_properties(0)\n            gpu_stats_name = (\n                gpu_stats.name + \". \" if gpu_stats.name != \"\" else \"Intel XPU Device. \"\n            )\n            gpu_version = torch.version.xpu\n            gpu_stats_snippet = f\"Intel Toolkit: {gpu_version}.\"\n            try:\n                vllm_version = f\" vLLM: {importlib_version('vllm')}.\"\n            except:\n                vllm_version = \"\"\n        else:\n            raise ValueError(f\"Unsloth: Unsupported device type: {DEVICE_TYPE}\")\n\n        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n\n        statistics = (\n            f\"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\\n\"\n            f\"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\\n\"\n            f\"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\\n\"\n            f\"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\\n\"\n            f' \"-____-\"     Free license: http://github.com/unslothai/unsloth'\n        )\n\n        print(statistics)\n\n        # Warn about fast transfers\n        if \"HF_HUB_ENABLE_HF_TRANSFER\" in os.environ:\n            old_hf_transfer = os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"]\n            if old_hf_transfer in (\"False\", \"false\"):\n                old_hf_transfer = \"0\"\n            if old_hf_transfer in (\"True\", \"true\"):\n                old_hf_transfer = \"1\"\n        else:\n            old_hf_transfer = \"0\"\n        if old_hf_transfer == \"1\":\n            print(\n                \"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\"\n            )\n        if old_hf_transfer != \"0\":\n            os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n\n        model_patcher.pre_patch()\n        # For debugging - we use a download counter to see if environments are not breaking or if HF is down\n        get_statistics(kwargs.get(\"local_files_only\", False))\n\n        if dtype is None:\n            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16\n        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:\n            logger.warning_once(\n                \"Device does not support bfloat16. Will change to float16.\"\n            )\n            dtype = torch.float16\n        # elif dtype == torch.float16 and SUPPORTS_BFLOAT16:\n        #     logger.warning_once(\"Device supports bfloat16 but you selected float16. Will change to bfloat16.\")\n        #     dtype = torch.bfloat16\n\n        assert (\n            dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32\n        )\n\n        # RoPE Scaling\n        model_config = AutoConfig.from_pretrained(\n            model_name,\n            token = token,\n            attn_implementation = \"sdpa\",\n        )\n        model_config.model_name = model_name\n        model_max_seq_length = model_config.max_position_embeddings\n\n        verify_fp8_support_if_applicable(model_config)\n\n        # Check if RoPE Scaling is even allowed\n        model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]\n        IS_FALCON_H1 = model_config.model_type.startswith(\"falcon_h1\")\n\n        preferred_attn_impl = (\n            prefer_flex_attn_if_supported(model_function, model_config) or \"eager\"\n        )\n\n        has_rope_scaling = False\n        try:\n            with open(inspect.getfile(model_function), \"r\", encoding = \"utf-8\") as file:\n                has_rope_scaling = \"self.config.rope_scaling\" in file.read()\n        except:\n            pass\n        has_rope_scaling = True\n\n        # If max_seq_length is not specified, use maximum from config\n        if max_seq_length is None:\n            max_seq_length = model_max_seq_length\n\n        if (rope_scaling is None) and (max_seq_length > model_max_seq_length):\n            rope_scaling = max_seq_length / model_max_seq_length\n\n            if fast_inference:\n                raise NotImplementedError(\n                    \"Unsloth: Fast inference does not yet work with RoPE Scaling.\"\n                )\n\n            logger.warning_once(\n                f\"Unsloth: {model_name} can only handle sequence lengths of at most \"\n                f\"{model_max_seq_length}.\\nBut with kaiokendev's RoPE scaling of \"\n                f\"{round(rope_scaling, 3)}, it can magically be extended to \"\n                
f\"{max_seq_length}!\"\n            )\n\n            # Warn RoPE scaling isn't allowed\n            if not has_rope_scaling:\n                raise RuntimeError(\n                    f\"However, {model_name} doesn't support RoPE Scaling!\\n\"\n                    \"Please file a feature request at https://github.com/unslothai/unsloth.\"\n                )\n\n            rope_scaling = {\n                \"type\": \"linear\",\n                \"factor\": rope_scaling,\n            }\n\n            # Add to kwargs\n            kwargs[\"rope_scaling\"] = rope_scaling\n\n        bnb_config = None\n        if load_in_4bit:\n            llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy()\n            if IS_FALCON_H1:\n                # we cannot quantize out_proj layer due to mamba kernels: https://github.com/tiiuae/Falcon-H1/issues/13#issuecomment-2918671274\n                llm_int8_skip_modules.append(\"out_proj\")\n            bnb_config = BitsAndBytesConfig(\n                load_in_4bit = True,\n                bnb_4bit_use_double_quant = True,\n                bnb_4bit_quant_type = \"nf4\",\n                bnb_4bit_compute_dtype = dtype,\n                llm_int8_skip_modules = llm_int8_skip_modules,\n            )\n\n        # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12\n        # RoPE Scaling's max_position_embeddings must be updated\n        max_position_embeddings = max(max_seq_length, model_max_seq_length)\n        kwargs.pop(\"attn_implementation\", None)  # No need since we auto call it\n\n        # Cannot be None, since HF now checks for the config\n        if load_in_4bit:\n            kwargs[\"quantization_config\"] = bnb_config\n\n        kwargs = add_dtype_kwargs(dtype, kwargs)\n\n        raise_handler = RaiseUninitialized()\n        if num_labels is not None:\n            model = AutoModelForSequenceClassification.from_pretrained(\n                model_name,\n                device_map = device_map,\n                # torch_dtype             = dtype, # transformers changed torch_dtype to dtype\n                num_labels = num_labels,\n                # quantization_config     = bnb_config,\n                token = token,\n                max_position_embeddings = max_position_embeddings,\n                trust_remote_code = trust_remote_code,\n                attn_implementation = preferred_attn_impl,\n                **kwargs,\n            )\n        elif not fast_inference:\n            model = AutoModelForCausalLM.from_pretrained(\n                model_name,\n                device_map = device_map,\n                # torch_dtype             = dtype, # transformers changed torch_dtype to dtype\n                # quantization_config     = bnb_config,\n                token = token,\n                max_position_embeddings = max_position_embeddings,\n                trust_remote_code = trust_remote_code,\n                attn_implementation = preferred_attn_impl,\n                **kwargs,\n            )\n            model.fast_generate = make_fast_generate_wrapper(model.generate)\n            model.fast_generate_batches = None\n        else:\n            from unsloth_zoo.vllm_utils import (\n                load_vllm,\n                get_vllm_state_dict,\n                convert_vllm_to_huggingface,\n                generate_batches,\n            )\n\n            fp8_mode = None\n            if load_in_fp8 != False:\n                fp8_mode = _get_fp8_mode_and_check_settings(\n                    load_in_fp8,\n                    
fast_inference,\n                )\n\n            allowed_args = inspect.getfullargspec(load_vllm).args\n            load_vllm_kwargs = dict(\n                model_name = model_name,\n                config = model_config,\n                gpu_memory_utilization = gpu_memory_utilization,\n                max_seq_length = max_seq_length,\n                dtype = dtype,\n                float8_kv_cache = float8_kv_cache,\n                enable_lora = True,\n                max_lora_rank = max_lora_rank,\n                disable_log_stats = disable_log_stats,\n                use_bitsandbytes = load_in_4bit,\n                unsloth_vllm_standby = unsloth_vllm_standby,\n                fp8_mode = fp8_mode,\n            )\n            for allowed_arg in allowed_args:\n                if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:\n                    load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]\n            pass\n\n            # Load vLLM first\n            llm = load_vllm(**load_vllm_kwargs)\n\n            # Convert to HF format\n            _, quant_state_dict = get_vllm_state_dict(\n                llm,\n                config = model_config,\n                load_in_fp8 = load_in_fp8,\n            )\n            model = convert_vllm_to_huggingface(\n                quant_state_dict, model_config, dtype, bnb_config\n            )\n            model.vllm_engine = llm\n            model.fast_generate = model.vllm_engine.generate\n            model.fast_generate_batches = functools.partial(\n                generate_batches, model.vllm_engine\n            )\n        raise_handler.remove()\n        # Return old flag\n        os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = old_hf_transfer\n\n        # Counteract saved tokenizers\n        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name\n        tokenizer = load_correct_tokenizer(\n            tokenizer_name = tokenizer_name,\n            model_max_length = max_position_embeddings,\n            padding_side = \"right\",\n            token = token,\n            trust_remote_code = trust_remote_code,\n            fix_tokenizer = fix_tokenizer,\n        )\n\n        model, tokenizer = patch_tokenizer(model, tokenizer)\n        model, tokenizer = model_patcher.post_patch(\n            model, tokenizer, correct_dtype = dtype\n        )\n\n        # Patch up QKV / O and MLP\n        for idx, layer in enumerate(model.model.layers):\n            layer.self_attn.apply_qkv = original_apply_qkv\n            layer.self_attn.apply_o = original_apply_o\n\n        # Patch Trainer\n        from transformers.trainer import Trainer\n\n        try:\n            if Trainer._inner_training_loop.__name__ != \"_fast_inner_training_loop\":\n                inner_training_loop = inspect.getsource(Trainer._inner_training_loop)\n                Trainer._original_training_loop = inner_training_loop\n            else:\n                inner_training_loop = Trainer._original_training_loop\n        except:\n            raise RuntimeError(\"Unsloth: Unsuccessfully patched inner_training_loop\")\n\n        import transformers.trainer\n\n        items_in_trainer = dir(transformers.trainer)\n        good_items = []\n        for item in items_in_trainer:\n            if item in inner_training_loop:\n                good_items.append(item)\n        exec(\n            \"from transformers.trainer import (\"\n            + \", \".join(x for x in good_items)\n            + \")\",\n            globals(),\n        )\n\n        start = 
re.search(\n            r\"logger\\.info\\([\\\"\\'].+?Running training\", inner_training_loop\n        ).span(0)[0]\n        end = inner_training_loop.find(\"\\n\\n\", start)\n        original_debug = inner_training_loop[start:end]\n        spaces = re.search(r\"\\n([\\s\\t]{1,})\", original_debug).group(0)[1:]\n        front_spaces = re.match(r\"([\\s\\t]{1,})\", inner_training_loop).group(0)\n\n        # Cannot use \\\\ since it will cause a SyntaxWarning in Python 3.12\n        # Instead use chr(92) == \\\\\n        debug_info = \"\"\"debug_info = \\\\\n        f\"==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\\\n\"\\\\\n        f\"   {chr(92)}{chr(92)}   /|    Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\\\n\"\\\\\n        f\"O^O/ {chr(92)}_/ {chr(92)}    Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\\\n\"\\\\\n        f\"{chr(92)}        /    Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\\\n\"\\\\\n        f' \"-____-\"     Trainable parameters = {get_model_param_count(model, trainable_only=True):,} of {get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)'\n        logger.warning(debug_info)\n        import gc\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n            else:\n                torch.cuda.empty_cache()\"\"\"\n\n        debug_info = debug_info.split(\"\\n\")\n        debug_info = \"\\n\".join(\n            [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]\n        )\n        inner_training_loop = inner_training_loop.replace(original_debug, debug_info)\n\n        debug_info = \"\"\"n_total_devices = total_train_batch_size // \\\\\n            args.gradient_accumulation_steps // self._train_batch_size\n        if n_total_devices > 1:\n            logger.warning_once('Unsloth is running with multi GPUs - the effective batch size is multiplied by ' + str(n_total_devices))\n        debug_info =\"\"\"\n        debug_info = debug_info.split(\"\\n\")\n        debug_info = \"\\n\".join(\n            [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]\n        )\n        inner_training_loop = inner_training_loop.replace(\"debug_info =\", debug_info, 1)\n\n        front_spaces = re.match(r\"[\\t\\s]{1,}\", inner_training_loop).group(0)\n        inner_training_loop = re.sub(\n            r\"^\" + front_spaces, \"\", inner_training_loop, flags = re.MULTILINE\n        )\n        inner_training_loop = inner_training_loop.replace(\n            \"train_dataloader = tpu_spmd_dataloader(train_dataloader)\",\n            \"raise RuntimeError('Unsloth: TPUs are not yet supported!')\",\n        )\n        inner_training_loop = inner_training_loop.replace(\n            \"_inner_training_loop\",\n            \"_fast_inner_training_loop\",\n            1,\n        )\n        inner_training_loop = inner_training_loop.replace(\n            \"is_torch_tpu_available()\",\n            \"False\",\n        )\n        exec(inner_training_loop, globals())\n        Trainer._inner_training_loop = _fast_inner_training_loop\n\n        # Save max_seq_length\n        model.max_seq_length = max_seq_length\n        
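# Walk every nested wrapper (e.g. LlamaForCausalLM -> LlamaModel) so each level\n        # reports the same max_seq_length.\n        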
m = model\n        while hasattr(m, \"model\"):\n            m.max_seq_length = max_seq_length\n            m = m.model\n        m.max_seq_length = max_seq_length\n        # Save to modules as well\n        for module in model.modules():\n            module.max_seq_length = max_seq_length\n\n        # We check the tokenizer first for errors\n        if fix_tokenizer:\n            tokenizer = check_tokenizer(\n                model = model,\n                tokenizer = tokenizer,\n                model_name = model_name,\n                model_max_length = max_position_embeddings,\n                padding_side = \"right\",\n                token = token,\n            )\n        patch_saving_functions(tokenizer)\n\n        # Fix up config for transformers uploading PEFT\n        # Not necessary anymore since we require transformers>=4.37!\n        if False:\n            name = model.config._name_or_path\n            if name.startswith(\"unsloth/\") and name.endswith(\"-bnb-4bit\"):\n                name = name[: len(name) - len(\"-bnb-4bit\")]\n                model.config.update({\"_name_or_path\": name})\n\n        # Log Unsloth version for future fastpaths for inference\n        model.config.update({\"unsloth_version\": __version__})\n\n        # Add save modules\n        patch_saving_functions(model)\n        Trainer._inner_training_loop = _fast_inner_training_loop\n\n        # Fix gradient accumulation\n        patch_gradient_accumulation_fix(Trainer)\n\n        # Save tokenizer for inference purposes\n        tokenizer.padding_side = \"left\"  # Force inference\n        internal_model = model\n        while hasattr(internal_model, \"model\"):\n            internal_model._saved_temp_tokenizer = tokenizer\n            # Also set is_loaded_in_8bit to disable incorrect DDP\n            internal_model.is_loaded_in_8bit = True\n\n            internal_model = internal_model.model\n        internal_model._saved_temp_tokenizer = tokenizer\n        # Also set is_loaded_in_8bit to disable incorrect DDP\n        internal_model.is_loaded_in_8bit = True\n\n        # For transformers > 4.47.1, we need to add rotary_emb to all attention layers\n        if IS_ATTENTION_REFACTOR or hasattr(model.model, \"rotary_emb\"):\n            rotary_emb = model.model.rotary_emb\n            for layer in model.model.layers:\n                layer.self_attn.rotary_emb = rotary_emb\n\n        # Add for_inference and for_training\n        model.for_training = functools.partial(FastLlamaModel.for_training, model)\n        model.for_inference = functools.partial(FastLlamaModel.for_inference, model)\n        m = model\n        while hasattr(m, \"model\"):\n            m.for_training = functools.partial(FastBaseModel.for_training, m)\n            m.for_inference = functools.partial(FastBaseModel.for_inference, m)\n            m = m.model\n\n        # Patch generate\n        is_classification = \"Classification\" in str(type(model))\n        if not is_classification and model.generate.__name__ != \"unsloth_fast_generate\":\n            model._old_generate = model.generate\n            unsloth_fast_generate.__doc__ = model._old_generate.__doc__\n            model.generate = types.MethodType(unsloth_fast_generate, model)\n        # Set weight[padding_idx] = 0 for embeddings that are NOT tied with the\n        # lm_head. 
When weights are tied, zeroing the padding row also zeros\n        # the corresponding lm_head row, forcing logit = 0 for the pad token.\n        # This is higher than the (negative) logits for real tokens in models\n        # like Gemma, causing the decoder to emit <pad> and produce gibberish.\n        # Skip entirely if eos_token == pad_token to avoid zeroing EOS embedding.\n        eos_token_id = (\n            getattr(tokenizer, \"eos_token_id\", None) if tokenizer is not None else None\n        )\n        pad_token_id = (\n            getattr(tokenizer, \"pad_token_id\", None) if tokenizer is not None else None\n        )\n        if tokenizer is not None and eos_token_id != pad_token_id:\n            lm_head = getattr(model, \"lm_head\", None)\n            lm_head_weight = (\n                getattr(lm_head, \"weight\", None) if lm_head is not None else None\n            )\n            with torch.no_grad():\n                for name, module in model.named_modules():\n                    if type(module) is torch.nn.Embedding:\n                        if (\n                            getattr(module, \"weight\", None) is not None\n                            and getattr(module, \"padding_idx\", None) is not None\n                        ):\n                            if module.padding_idx < module.weight.shape[0]:\n                                # Skip if tied to lm_head\n                                if (\n                                    lm_head_weight is not None\n                                    and module.weight.data_ptr()\n                                    == lm_head_weight.data_ptr()\n                                ):\n                                    continue\n                                module.weight[module.padding_idx] = 0\n        return model, tokenizer\n\n    @staticmethod\n    def post_patch(model, tokenizer, correct_dtype = None):\n        model, tokenizer = patch_model_and_tokenizer(\n            model, tokenizer, downcast_rope = True, correct_dtype = correct_dtype\n        )\n        return model, tokenizer\n\n    @staticmethod\n    def get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"q_proj\",\n            \"k_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"up_proj\",\n            \"down_proj\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0.0,\n        bias = \"none\",\n        layers_to_transform = None,\n        layers_pattern = None,\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        max_seq_length = 2048,  # not used anymore\n        use_rslora = False,\n        modules_to_save = None,\n        init_lora_weights = True,\n        loftq_config = {},\n        temporary_location = \"_unsloth_temporary_saved_buffers\",\n        qat_scheme = None,\n        target_parameters = None,  # For MoE expert layers (nn.Parameter)\n        ensure_weight_tying = False,\n        **kwargs,\n    ):\n        if os.environ.get(\"UNSLOTH_USE_NEW_MODEL\", \"0\") == \"1\":\n            # Check for other PEFT args in kwargs\n            for peft_arg, flag in (\n                (\"finetune_vision_layers\", False),\n                (\"finetune_language_layers\", True),\n                (\"finetune_attention_modules\", True),\n                (\"finetune_mlp_modules\", True),\n            ):\n                if peft_arg not in kwargs:\n                    kwargs[peft_arg] = flag\n            return 
FastBaseModel.get_peft_model(\n                model = model,\n                r = r,\n                target_modules = target_modules,\n                lora_alpha = lora_alpha,\n                lora_dropout = lora_dropout,\n                bias = bias,\n                layers_to_transform = layers_to_transform,\n                layers_pattern = layers_pattern,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                random_state = random_state,\n                max_seq_length = max_seq_length,\n                use_rslora = use_rslora,\n                modules_to_save = modules_to_save,\n                init_lora_weights = init_lora_weights,\n                loftq_config = loftq_config,\n                temporary_location = temporary_location,\n                target_parameters = target_parameters,\n                ensure_weight_tying = ensure_weight_tying,\n                **kwargs,\n            )\n        if os.environ.get(\"UNSLOTH_ENABLE_FULL_FINETUNING\", \"0\") == \"1\":\n            print(\n                \"Unsloth: Full finetuning is enabled, so .get_peft_model has no effect\"\n            )\n            return model\n        transformers_set_seed(random_state)\n\n        # Apply gradient checkpointing with smart heuristics\n        max_seq = getattr(model, \"max_seq_length\", 512)\n        dtype = model.get_input_embeddings().weight.dtype\n        use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(\n            use_gradient_checkpointing, max_seq, dtype\n        )\n\n        if type(r) is not int:\n            raise TypeError(f\"Unsloth: Rank of {str(r)} must be an integer.\")\n        if r <= 0:\n            raise TypeError(f\"Unsloth: Rank of {str(r)} must be larger than 0.\")\n\n        if isinstance(model, PeftModelForCausalLM) or isinstance(\n            model, PeftModelForSequenceClassification\n        ):\n            # Check if exactly the same and then pass through!\n            assert hasattr(model, \"peft_config\")\n\n            peft_config = model.peft_config[\"default\"].to_dict()\n            check_parameters = [\n                \"r\",\n                \"lora_alpha\",\n                \"lora_dropout\",\n                \"bias\",\n                \"layers_to_transform\",\n                \"layers_pattern\",\n                \"use_rslora\",\n                \"init_lora_weights\",\n            ]\n            check_all = True\n            for param in check_parameters:\n                check_all = check_all and (peft_config[param] == eval(param))\n\n            # Check save_modules\n            old_target_modules = list(peft_config[\"target_modules\"])\n            modules_to_save = peft_config[\"modules_to_save\"]\n            if modules_to_save is None:\n                modules_to_save = {}\n            modules_to_save = list(modules_to_save)\n            old_target_modules += modules_to_save\n\n            # Combine all\n            new_target_modules = list(target_modules) + list(\n                modules_to_save if modules_to_save is not None else []\n            )\n\n            # Now check!\n            new_target_modules = set(new_target_modules)\n            check_all = check_all and (\n                len(set(old_target_modules) ^ new_target_modules) == 0\n            )\n\n            check_all = check_all and (\n                (loftq_config == {} or loftq_config is None)\n                and (\n                    peft_config[\"loftq_config\"] == {}\n                    or peft_config[\"loftq_config\"] is 
None\n                )\n            )\n\n            if check_all:\n                # Simply pass through!\n                logger.warning(\n                    \"Unsloth: Already have LoRA adapters! We shall skip this step.\"\n                )\n\n                # Offload!\n                # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)\n                if \"embed_tokens\" in new_target_modules:\n                    print(\n                        \"Unsloth: Training embed_tokens in mixed precision to save VRAM\"\n                    )\n\n                    _offload_frozen_module_for_training(\n                        model.get_input_embeddings(), DEVICE_TYPE_TORCH\n                    )\n\n                if \"lm_head\" in new_target_modules:\n                    print(\"Unsloth: Training lm_head in mixed precision to save VRAM\")\n\n                    _offload_frozen_module_for_training(\n                        model.get_output_embeddings(), DEVICE_TYPE_TORCH\n                    )\n\n                return model\n            else:\n                raise TypeError(\n                    \"Unsloth: Your model already has LoRA adapters. Your new parameters are different.\"\n                )\n\n        if loftq_config is None:\n            loftq_config = {}\n\n        signature = str(inspect.signature(LoraConfig))\n        SUPPORTS_LOFTQ = \"loftq_config\" in signature\n        SUPPORTS_RSLORA = \"use_rslora\" in signature\n\n        if lora_dropout != 0:\n            logger.warning_once(\n                f\"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\\n\"\n                f\"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.\"\n            )\n\n        if bias != \"none\":\n            logger.warning_once(\n                f\"Unsloth: bias = `none` is supported for fast patching. 
You are using bias = {bias}.\\n\"\n                f\"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.\"\n            )\n\n        if not (\n            type(init_lora_weights) is bool\n            or init_lora_weights == \"gaussian\"\n            or init_lora_weights == \"loftq\"\n            or init_lora_weights == \"corda\"\n        ):\n            raise ValueError(\n                'Unsloth: `init_lora_weights` must be either [True, False, \"gaussian\", \"loftq\", \"corda\"].'\n            )\n\n        if init_lora_weights == \"loftq\":\n            if not SUPPORTS_LOFTQ:\n                import peft\n\n                raise RuntimeError(\n                    f\"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\\n\"\n                    \"Please install PEFT 0.7.2 or higher.\\n\"\n                    \"You can also install from source: `pip install git+https://github.com/huggingface/peft.git\"\n                )\n\n            if loftq_config == {}:\n                from peft import LoftQConfig\n\n                logger.warning_once(\n                    \"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\\n\"\n                    \"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`.\"\n                )\n                loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)\n\n            if hasattr(model.config, \"quantization_config\"):\n                raise ValueError(\n                    \"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\\n\"\n                    \"Reload your model without any quantization by setting `load_in_4bit = False`.\"\n                )\n\n        assert type(use_rslora) is bool\n        if use_rslora:\n            if not SUPPORTS_RSLORA:\n                # We manually check for PEFT\n                import peft\n\n                raise RuntimeError(\n                    f\"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\\n\"\n                    \"Please install PEFT 0.7.2 or higher.\\n\"\n                    \"You can also install from source: `pip install git+https://github.com/huggingface/peft.git\"\n                )\n\n        accepted_modules = frozenset(\n            (\n                \"lm_head\",\n                \"q_proj\",\n                \"k_proj\",\n                \"v_proj\",\n                \"o_proj\",\n                \"gate_proj\",\n                \"up_proj\",\n                \"down_proj\",\n            ),\n        )\n        model.config.update({\"unsloth_version\": __version__})\n\n        if type(modules_to_save) is tuple:\n            modules_to_save = list(modules_to_save)\n\n        train_lm_head = False\n        train_embed_tokens = False\n        final_modules = []\n        for module in target_modules:\n            if module == \"embed_tokens\":\n                # logger.warning_once(\n                #     \"Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. 
\"\\\n                #     \"Luckily, we shall do it for you!\"\n                # )\n                train_embed_tokens = True\n                if modules_to_save is None:\n                    modules_to_save = [\"embed_tokens\"]\n                else:\n                    modules_to_save.append(\"embed_tokens\")\n\n            else:\n                try:\n                    assert module in accepted_modules\n                    final_modules.append(module)\n                except AssertionError as e:\n                    final_modules.append(module)\n                    print(\n                        \"Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\\n\"\n                        \"Beware - your finetuning might be noticeably slower!\"\n                    )\n                pass\n\n        # Check if we added new tokens!\n        if hasattr(model, \"_need_to_train_embeddings\"):\n            # Check if embed_tokens/lm_head are already being trained\n            # (either as LoRA targets in final_modules or via modules_to_save)\n            _embed_already_trained = (\n                train_embed_tokens or \"embed_tokens\" in final_modules\n            )\n            _lm_head_already_trained = train_lm_head or \"lm_head\" in final_modules\n            if not _lm_head_already_trained or not _embed_already_trained:\n                print(\n                    \"Unsloth: You added new tokens but did not specify if you wanted to \"\n                    \"train the lm_head and embed_tokens.\\nWe must turn it on for you.\"\n                )\n\n                # Only add to modules_to_save if not already a LoRA target\n                if not _embed_already_trained:\n                    train_embed_tokens = True\n                    if modules_to_save is None:\n                        modules_to_save = [\"embed_tokens\"]\n                    elif \"embed_tokens\" not in modules_to_save:\n                        modules_to_save.append(\"embed_tokens\")\n\n                if not _lm_head_already_trained:\n                    train_lm_head = True\n                    if modules_to_save is None:\n                        modules_to_save = [\"lm_head\"]\n                    elif \"lm_head\" not in modules_to_save:\n                        modules_to_save.append(\"lm_head\")\n\n        # Check for Llama-3\n        # if hasattr(model._saved_temp_tokenizer, \"_using_llama3_template\"):\n        #     if not train_embed_tokens and not train_lm_head:\n        #         raise RuntimeError(\"\")\n\n        # First fix untrained tokens\n        # Wrong - can cause reserved tokens to pop out!!\n        # if train_embed_tokens or train_lm_head:\n        #     fix_untrained_tokens(model, eps = 1e-16)\n        # pass\n\n        # Check modules_to_save\n        if modules_to_save is not None:\n            for module in modules_to_save:\n                if module == \"lm_head\":\n                    train_lm_head = True\n                elif module == \"embed_tokens\":\n                    train_embed_tokens = True\n                else:\n                    raise TypeError(\n                        f\"Unsloth: Module = {module} is not allowed. 
Only 'lm_head' and 'embed_tokens' is allowed.\"\n                    )\n        if isinstance(modules_to_save, (tuple, list)):\n            modules_to_save = list(set(modules_to_save))\n\n        vllm_engine = None\n        if hasattr(model, \"vllm_engine\"):\n            # Fast inference!\n            vllm_engine = model.vllm_engine\n            vllm_fast_generate = model.fast_generate\n            vllm_fast_generate_batches = model.fast_generate_batches\n\n            if modules_to_save is not None:\n                raise NotImplementedError(\n                    \"Unsloth: Currently fast inference does not work with training embeddings or lm_head.\"\n                )\n\n            if bias != \"none\":\n                raise NotImplementedError(\n                    \"Unsloth: Currently fast inference does not work with using biases for LoRA.\"\n                )\n\n        # Does not get lora yet, so get name from model, not base model\n        is_classification = \"Classification\" in str(type(model))\n\n        # Auto-detect MoE models and populate target_parameters for expert layers\n        if target_parameters is None:\n            target_parameters = get_moe_target_parameters(model, target_modules)\n\n        arguments = dict(\n            r = r,\n            lora_alpha = lora_alpha,\n            target_modules = final_modules,\n            lora_dropout = lora_dropout,\n            bias = bias,\n            task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,\n            layers_to_transform = layers_to_transform,\n            init_lora_weights = init_lora_weights,\n            loftq_config = loftq_config,\n            use_rslora = use_rslora,\n            modules_to_save = modules_to_save,\n            target_parameters = target_parameters,\n            ensure_weight_tying = ensure_weight_tying,\n            **kwargs,\n        )\n        if not SUPPORTS_LOFTQ:\n            del arguments[\"loftq_config\"]\n        if not SUPPORTS_RSLORA:\n            del arguments[\"use_rslora\"]\n\n        _saved_temp_tokenizer = model._saved_temp_tokenizer\n\n        lora_config = LoraConfig(**arguments)\n        # First offload lm_head and embed_tokens to disk\n        input_embeddings_device = model.get_input_embeddings().weight.device\n        if is_classification:\n            output_embeddings_device = model.score.weight.device\n        else:\n            output_embeddings_device = model.get_output_embeddings().weight.device\n\n        if use_gradient_checkpointing == \"unsloth\":\n            if train_embed_tokens:\n                print(\"Unsloth: Offloading input_embeddings to disk to save VRAM\")\n                offload_input_embeddings(model, temporary_location)\n\n            # Remove old items to save VRAM\n            for _ in range(3):\n                gc.collect()\n                clean_gpu_cache()\n\n            if train_lm_head:\n                print(\"Unsloth: Offloading output_embeddings to disk to save VRAM\")\n                offload_output_embeddings(model, temporary_location)\n\n            # Remove old items to save VRAM\n            for _ in range(3):\n                gc.collect()\n                clean_gpu_cache()\n\n        model = _get_peft_model(model, lora_config)\n        # Fix LoraConfig.auto_mapping is None\n        fix_lora_auto_mapping(model)\n\n        # Apply QAT + LoRA if specified\n        if qat_scheme is not None:\n            print(\"Unsloth: Applying QAT to mitigate quantization degradation\")\n            model = 
FastLlamaModel._prepare_for_qat(model, qat_scheme)\n\n        model._saved_temp_tokenizer = _saved_temp_tokenizer\n\n        model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)\n\n        if ensure_weight_tying:\n            try:\n                input_embeddings = model.get_input_embeddings()\n                output_embeddings = model.get_output_embeddings()\n\n                if input_embeddings is not None and output_embeddings is not None:\n\n                    def _retie_parameter(target_module, source_module):\n                        if not hasattr(source_module, \"weight\"):\n                            return\n                        weight = source_module.weight\n                        # Remove existing registration to avoid \"attribute already exists\"\n                        if \"weight\" in getattr(target_module, \"_parameters\", {}):\n                            target_module._parameters.pop(\"weight\")\n                        if hasattr(target_module, \"weight\"):\n                            try:\n                                delattr(target_module, \"weight\")\n                            except Exception as exc:\n                                logger.warning_once(\n                                    f\"Unsloth: Could not delete existing weight attr during retie on \"\n                                    f\"{type(target_module).__name__}: {exc}\"\n                                )\n                        target_module.register_parameter(\"weight\", weight)\n\n                    # Tie trainable copies created by ModulesToSaveWrapper first (these are used in forward)\n                    if hasattr(input_embeddings, \"modules_to_save\") and hasattr(\n                        output_embeddings, \"modules_to_save\"\n                    ):\n                        if hasattr(\n                            input_embeddings.modules_to_save, \"default\"\n                        ) and hasattr(output_embeddings.modules_to_save, \"default\"):\n                            _retie_parameter(\n                                output_embeddings.modules_to_save.default,\n                                input_embeddings.modules_to_save.default,\n                            )\n\n                    # Tie original_module references as well if present\n                    if hasattr(input_embeddings, \"original_module\") and hasattr(\n                        output_embeddings, \"original_module\"\n                    ):\n                        _retie_parameter(\n                            output_embeddings.original_module,\n                            input_embeddings.original_module,\n                        )\n            except Exception as e:\n                logger.warning_once(\n                    f\"Unsloth: Failed to ensure weight tying between embeddings and lm_head: {e}\"\n                )\n\n        if train_embed_tokens:\n            print(\"Unsloth: Training embed_tokens in mixed precision to save VRAM\")\n            assert hasattr(model.get_input_embeddings(), \"modules_to_save\")\n\n            _offload_frozen_module_for_training(\n                model.get_input_embeddings(), DEVICE_TYPE_TORCH, offload_device = None\n            )\n\n        if train_lm_head:\n            print(\"Unsloth: Training lm_head in mixed precision to save VRAM\")\n            assert hasattr(model.get_output_embeddings(), \"modules_to_save\")\n\n            _offload_frozen_module_for_training(\n                model.get_output_embeddings(), DEVICE_TYPE_TORCH, 
offload_device = None\n            )\n\n        # Patch tokenizer to pad to the right\n        internal_model = model\n        while hasattr(internal_model, \"model\"):\n            if hasattr(internal_model, \"_saved_temp_tokenizer\"):\n                internal_model._saved_temp_tokenizer.padding_side = \"right\"\n            # Also set is_loaded_in_8bit to disable incorrect DDP\n            internal_model.is_loaded_in_8bit = True\n            internal_model = internal_model.model\n        if hasattr(internal_model, \"_saved_temp_tokenizer\"):\n            internal_model._saved_temp_tokenizer.padding_side = \"right\"\n        # Also set is_loaded_in_8bit to disable incorrect DDP\n        internal_model.is_loaded_in_8bit = True\n\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            clean_gpu_cache()\n\n        patch_peft_fast_inference(model)\n\n        # Add for_inference and for_training\n        model.for_training = functools.partial(FastLlamaModel.for_training, model)\n        model.for_inference = functools.partial(FastLlamaModel.for_inference, model)\n        m = model\n        while hasattr(m, \"model\"):\n            m.for_training = functools.partial(FastBaseModel.for_training, m)\n            m.for_inference = functools.partial(FastBaseModel.for_inference, m)\n            m = m.model\n        return model\n\n    @staticmethod\n    def patch_peft_model(\n        model,\n        use_gradient_checkpointing = \"unsloth\",\n    ):\n        if os.environ.get(\"UNSLOTH_USE_NEW_MODEL\", \"0\") == \"1\":\n            return FastBaseModel.patch_peft_model(\n                model = model,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n            )\n        if not isinstance(model, PeftModelForCausalLM) and not isinstance(\n            model, PeftModelForSequenceClassification\n        ):\n            raise TypeError(\n                \"Unsloth: Your model needs to call `.get_peft_model` first!\"\n            )\n\n        # Get activation function\n        model_type = model.config.model_type\n\n        if model_type == \"llama\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"mistral\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"qwen2\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"gemma\":\n            apply_lora_mlp = apply_lora_mlp_geglu_approx\n        elif model_type == \"gemma2\":\n            apply_lora_mlp = apply_lora_mlp_geglu_approx\n        elif model_type == \"cohere\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"granite\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"qwen3\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"falcon_h1\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        elif model_type == \"qwen3moe\":\n            apply_lora_mlp = apply_lora_mlp_swiglu\n        else:\n            raise NotImplementedError(f\"Unsloth: {model_type} is not yet implemented!\")\n\n        model = prepare_model_for_kbit_training(\n            model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n            use_reentrant = True,\n        )\n\n        # Fix up config for transformers uploading PEFT\n        for active_adapter in model.peft_config.keys():\n            # Not necessary since we requires transformers >= 4.37\n            if False:\n              
  name = model.peft_config[active_adapter].base_model_name_or_path\n                if name.startswith(\"unsloth/\") and name.endswith(\"-bnb-4bit\"):\n                    name = name[: len(name) - len(\"-bnb-4bit\")]\n                    model.peft_config[active_adapter].base_model_name_or_path = name\n                pass\n            # Add revision to enable future fast inference paths\n            # [TODO] Bugs out!see https://github.com/unslothai/unsloth/issues/492\n            # model.peft_config[active_adapter].revision = f\"unsloth\"\n\n        from transformers.trainer import Trainer\n\n        if Trainer._inner_training_loop.__name__ != \"_fast_inner_training_loop\":\n            raise RuntimeError(\n                \"Unsloth: Unsuccessfully patched Trainer! Please file a bug report!\"\n            )\n\n        # Fix loftq issues\n        # loftq_config must not = None, but rather {}\n        all_configs = model.peft_config\n        for key, current_config in all_configs.items():\n            if (\n                hasattr(current_config, \"loftq_config\")\n                and current_config.loftq_config is None\n            ):\n                new_args = current_config.__dict__\n                new_args[\"loftq_config\"] = {}\n                current_config = current_config.__class__(**new_args)\n                all_configs[key] = current_config\n\n        # Do patching\n        n_mlp = 0\n        n_qkv = 0\n        n_o = 0\n\n        active_adapter = (\n            model.active_adapters[0]\n            if hasattr(model, \"active_adapters\")\n            else model.active_adapter\n        )\n\n        # Get dropout and bias\n        lora_dropout = model.peft_config[active_adapter].lora_dropout\n        bias = model.peft_config[active_adapter].bias\n\n        # We also do not inplace edit QKV for Cohere!\n        _apply_lora_mlp = (\n            functools.partial(apply_lora_mlp, inplace = False)\n            if model_type == \"cohere\"\n            else apply_lora_mlp\n        )\n\n        if lora_dropout == 0 and bias == \"none\":\n            for idx, layer in enumerate(model.model.model.layers):\n                if model_type != \"falcon_h1\":\n                    # LoRAMLP.apply doesn't have functionality for gate and down multipliers yet.\n                    # Don't patch falcon h1 for the time being.\n\n                    # MLP patching\n                    mlp_module = layer.mlp\n                    gate_proj = mlp_module.gate_proj\n                    up_proj = mlp_module.up_proj\n                    down_proj = mlp_module.down_proj\n\n                    if (\n                        hasattr(gate_proj, \"lora_A\")\n                        and hasattr(up_proj, \"lora_A\")\n                        and hasattr(down_proj, \"lora_A\")\n                        and (getattr(gate_proj, \"base_layer\", gate_proj).bias is None)\n                        and (getattr(up_proj, \"base_layer\", up_proj).bias is None)\n                        and (getattr(down_proj, \"base_layer\", down_proj).bias is None)\n                        and (\n                            len(getattr(gate_proj, \"lora_magnitude_vector\", []) or [])\n                            == 0\n                        )\n                        and (\n                            len(getattr(up_proj, \"lora_magnitude_vector\", []) or [])\n                            == 0\n                        )\n                        and (\n                            len(getattr(down_proj, \"lora_magnitude_vector\", []) or [])\n     
                       == 0\n                        )\n                    ):\n                        # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module\n                        if hasattr(mlp_module, \"_unsloth_forward\"):\n                            # then we've patched the mlp to use TiledMLP\n                            mlp_module._unsloth_forward = types.MethodType(\n                                _apply_lora_mlp, mlp_module\n                            )\n                        else:\n                            mlp_module.forward = types.MethodType(\n                                _apply_lora_mlp, mlp_module\n                            )\n                        n_mlp += 1\n                    else:\n                        logger.warning_once(\n                            \"Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\\n\"\n                            \"are not enabled or a bias term (like in Qwen) is used.\"\n                        )\n\n                # QKV attention patching\n                q_proj = layer.self_attn.q_proj\n                k_proj = layer.self_attn.k_proj\n                v_proj = layer.self_attn.v_proj\n                if (\n                    hasattr(q_proj, \"lora_A\")\n                    and hasattr(k_proj, \"lora_A\")\n                    and hasattr(v_proj, \"lora_A\")\n                    and (getattr(q_proj, \"base_layer\", q_proj).bias is None)\n                    and (getattr(k_proj, \"base_layer\", k_proj).bias is None)\n                    and (getattr(v_proj, \"base_layer\", v_proj).bias is None)\n                    and (len(getattr(q_proj, \"lora_magnitude_vector\", []) or []) == 0)\n                    and (len(getattr(k_proj, \"lora_magnitude_vector\", []) or []) == 0)\n                    and (len(getattr(v_proj, \"lora_magnitude_vector\", []) or []) == 0)\n                ):\n                    layer.self_attn.apply_qkv = apply_lora_qkv\n                    n_qkv += 1\n                else:\n                    if model_type == \"qwen2\":\n                        n_qkv += 1\n                    else:\n                        logger.warning_once(\n                            \"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\\n\"\n                            \"are not enabled or a bias term (like in Qwen) is used.\"\n                        )\n\n                # O attention patching\n                o_proj = layer.self_attn.o_proj\n                if (\n                    hasattr(o_proj, \"lora_A\")\n                    and (getattr(o_proj, \"base_layer\", o_proj).bias is None)\n                    and (len(getattr(o_proj, \"lora_magnitude_vector\", []) or []) == 0)\n                ):\n                    layer.self_attn.apply_o = apply_lora_o\n                    n_o += 1\n                else:\n                    logger.warning_once(\n                        \"Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\\n\"\n                        \"are not enabled or a bias term (like in Qwen) is used.\"\n                    )\n\n        logger.warning_once(\n            f\"Unsloth {__version__} patched {len(model.model.model.layers)} layers with \"\n            f\"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.\",\n        )\n        patch_saving_functions(model)\n\n        # 
Patch cross entropy loss labels\n        # Fixes https://github.com/unslothai/unsloth/issues/10\n        max_seq_length = model.max_seq_length\n        # extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = \"cuda:0\")\n        # model.model.extra_ignored_labels = extra_ignored_labels\n        internal_model = model\n        while hasattr(internal_model, \"model\"):\n            internal_model.max_seq_length = max_seq_length\n            internal_model = internal_model.model\n        internal_model.max_seq_length = max_seq_length\n        # Save to modules as well\n        for module in model.modules():\n            module.max_seq_length = max_seq_length\n\n        # Patch tokenizer to pad to the right\n        internal_model = model\n        while hasattr(internal_model, \"model\"):\n            if hasattr(internal_model, \"_saved_temp_tokenizer\"):\n                internal_model._saved_temp_tokenizer.padding_side = \"right\"\n            internal_model = internal_model.model\n        if hasattr(internal_model, \"_saved_temp_tokenizer\"):\n            internal_model._saved_temp_tokenizer.padding_side = \"right\"\n\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            clean_gpu_cache()\n\n        patch_peft_fast_inference(model)\n\n        # Add for_inference and for_training\n        model.for_training = functools.partial(FastLlamaModel.for_training, model)\n        model.for_inference = functools.partial(FastLlamaModel.for_inference, model)\n        m = model\n        while hasattr(m, \"model\"):\n            m.for_training = functools.partial(FastBaseModel.for_training, m)\n            m.for_inference = functools.partial(FastBaseModel.for_inference, m)\n            m = m.model\n        return model\n\n    @staticmethod\n    def for_inference(model):\n        if not hasattr(model, \"parameters\"):\n            raise TypeError(\n                \"Unsloth: I think you're passing a tokenizer, not the model to for_inference!\"\n            )\n\n        def _for_inference(m):\n            if hasattr(m, \"gradient_checkpointing\"):\n                m.gradient_checkpointing = False\n            if hasattr(m, \"training\"):\n                m.training = False\n            # Pad tokenizer to the left\n            if hasattr(m, \"_saved_temp_tokenizer\"):\n                m._saved_temp_tokenizer.padding_side = \"left\"\n            # Set a flag for generation!\n            m._flag_for_generation = True\n\n        m = model\n        while hasattr(m, \"model\"):\n            _for_inference(m)\n            m = m.model\n        _for_inference(m)\n        model.eval()  # to turn off training on modules deeper in\n\n        # Since transformers 4.53, must turn off explicitly\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing\"):\n                module.gradient_checkpointing = False\n\n        # Also disable training for embeddings for NEFTune\n        if hasattr(model, \"get_input_embeddings\"):\n            embeddings = model.get_input_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = False\n        if hasattr(model, \"get_output_embeddings\"):\n            embeddings = model.get_output_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = False\n        return model\n\n    @staticmethod\n    def for_training(model, use_gradient_checkpointing = True):\n        if not hasattr(model, 
\"parameters\"):\n            raise TypeError(\n                \"Unsloth: I think you're passing a tokenizer, not the model to for_training!\"\n            )\n\n        # Delete all fast inference loras\n        for param in model.parameters():\n            if hasattr(param, \"_fast_lora\"):\n                del param._fast_lora\n\n        def _for_training(m):\n            if hasattr(m, \"gradient_checkpointing\"):\n                m.gradient_checkpointing = use_gradient_checkpointing\n            if hasattr(m, \"training\"):\n                m.training = True\n            # Pad tokenizer to the left\n            if hasattr(m, \"_saved_temp_tokenizer\"):\n                m._saved_temp_tokenizer.padding_side = \"right\"\n            # Set a flag for generation!\n            if hasattr(m, \"_flag_for_generation\"):\n                del m._flag_for_generation\n\n        m = model\n        while hasattr(m, \"model\"):\n            _for_training(m)\n            m = m.model\n        _for_training(m)\n        model.train()  # to turn on training on modules deeper in\n\n        # Since transformers 4.53, must turn on explicitly\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing\"):\n                module.gradient_checkpointing = use_gradient_checkpointing\n\n        # Also re-enable training for embeddings for NEFTune\n        if hasattr(model, \"get_input_embeddings\"):\n            embeddings = model.get_input_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = True\n        if hasattr(model, \"get_output_embeddings\"):\n            embeddings = model.get_output_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = True\n        return model\n\n\nfrom .rl import PatchFastRL\n\nPatchFastRL(FastLanguageModel = FastLlamaModel)\n"
  },
  {
    "path": "unsloth/models/llama4.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# from unsloth_studio.models import patch_llama4\n# patch_llama4()\n"
  },
  {
    "path": "unsloth/models/loader.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom ._utils import (\n    _prepare_model_for_qat,\n    is_bfloat16_supported,\n    is_vLLM_available,\n    HAS_FLASH_ATTENTION,\n    HAS_FLASH_ATTENTION_SOFTCAPPING,\n    USE_MODELSCOPE,\n    get_transformers_model_type,\n    hf_login,\n)\nfrom .granite import FastGraniteModel\nfrom .llama import FastLlamaModel, logger\nfrom .mistral import FastMistralModel\nfrom .qwen2 import FastQwen2Model\nfrom .qwen3 import FastQwen3Model\nfrom .qwen3_moe import FastQwen3MoeModel\nfrom .cohere import FastCohereModel\nfrom transformers import AutoConfig\nfrom transformers import __version__ as transformers_version\nfrom peft import PeftConfig, PeftModel\nfrom .loader_utils import (\n    _get_fp8_mode_and_check_settings,\n    _offline_quantize_to_fp8,\n    _tag_model_with_fp8_torchao_config,\n    get_model_name,\n    prepare_device_map,\n)\nimport os, contextlib, sys\n\ntry:\n    from huggingface_hub import get_token\nexcept:\n    try:\n        from huggingface_hub.utils import get_token\n    except:\n        # For older versions of huggingface_hub\n        from huggingface_hub.utils._token import get_token\nfrom huggingface_hub import HfFileSystem\nimport importlib.util\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n    ALLOW_BITSANDBYTES,\n)\n\n# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!\nfrom unsloth_zoo.utils import Version, _get_dtype\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom unsloth_zoo.tiled_mlp import patch_tiled_mlp\n\ntransformers_version = Version(transformers_version)\nSUPPORTS_FOURBIT = transformers_version >= Version(\"4.37\")\nSUPPORTS_GEMMA = transformers_version >= Version(\"4.38\")\nSUPPORTS_GEMMA2 = transformers_version >= Version(\"4.42\")\nSUPPORTS_LLAMA31 = transformers_version >= Version(\"4.43.2\")\nSUPPORTS_LLAMA32 = transformers_version > Version(\"4.45.0\")\nSUPPORTS_GRANITE = transformers_version >= Version(\"4.46.0\")\nSUPPORTS_QWEN3 = transformers_version >= Version(\"4.50.3\")\nSUPPORTS_QWEN3_MOE = transformers_version >= Version(\"4.50.3\")\nSUPPORTS_FALCON_H1 = transformers_version >= Version(\"4.53.0\")\nSUPPORTS_GEMMA3N = transformers_version >= Version(\"4.53.0\")\nSUPPORTS_GPTOSS = transformers_version >= Version(\"4.55.0\")\n# Transformers v5 meta-device loading corrupts non-persistent buffers (inv_freq).\n# See _fix_rope_inv_freq() below for details.\n_NEEDS_ROPE_FIX = transformers_version >= Version(\"5.0.0\")\nif SUPPORTS_GEMMA:\n    from .gemma import FastGemmaModel\nif SUPPORTS_GEMMA2:\n    from .gemma2 import FastGemma2Model\nif SUPPORTS_FALCON_H1:\n    from .falcon_h1 import FastFalconH1Model\nimport torch\nfrom ._utils import (\n    patch_compiling_bitsandbytes,\n    patch_model_and_tokenizer,\n    prepare_model_for_kbit_training,\n    
apply_unsloth_gradient_checkpointing,\n    patch_compiled_autograd,\n    process_vision_info,\n    unsloth_compile_transformers,\n    fast_inference_setup,\n)\n\nglobal FORCE_FLOAT32\n# Forces float32 precision since float16 goes to infinity\nFORCE_FLOAT32 = [\n    \"gemma3,\",  # Add comma bc gemma3 will match gemma3n\n    \"gemma3text\",  # Gemma3TextModel (EmbeddingGemma, standalone text-only Gemma3)\n    \"gemma3n\",\n    \"gpt_oss\",\n    \"qwen3_5\",  # Qwen3.5 GDN layers produce NaN grad norms in float16 training\n]\n\nglobal DISABLE_COMPILE_MODEL_NAMES\n# Must be alphabetically sorted for each entry\nDISABLE_COMPILE_MODEL_NAMES = [\n    \"aya_vision\",\n    \"modernbert\",\n    \"granite,llava_next\",  # Granite-vision 3\n]\n\nglobal DISABLE_SDPA_MODEL_NAMES\n# Disables some SDPA modules since it's wrong\nDISABLE_SDPA_MODEL_NAMES = [\n    \"gemma3,\",  # Add comma bc gemma3 will match gemma3n\n    \"gemma3_text\",  # Gemma3TextModel (EmbeddingGemma) - substring match, keep underscore\n]\n\n\ndef _fix_rope_inv_freq(model):\n    \"\"\"Fix inv_freq corruption caused by transformers v5 meta-device loading.\n\n    Transformers v5 initializes models on the meta device, then\n    _move_missing_keys_from_meta_to_device() (modeling_utils.py) replaces ALL\n    non-persistent buffers with torch.empty_like() -- uninitialized memory.\n\n    Vanilla transformers restores inv_freq via _init_weights() which checks for\n    hasattr(module, \"original_inv_freq\"). Unsloth's LlamaRotaryEmbedding and\n    subclasses do not have this attribute, so inv_freq stays corrupted. This\n    produces wrong positional encodings and causes 5-11x higher training loss.\n\n    This function recomputes inv_freq from the stored base and dim, applies\n    any model-specific scaling, and rebuilds the cos/sin caches.\n\n    Only runs on transformers >= 5.0.0. No-op on v4.\n    \"\"\"\n    if not _NEEDS_ROPE_FIX:\n        return model\n\n    for name, module in model.named_modules():\n        # Unsloth's LlamaRotaryEmbedding and subclasses (Extended, LinearScaling,\n        # Granite). Native v5 rotary classes (Gemma3, etc.) 
have original_inv_freq\n        # which v5's _init_weights() uses to restore inv_freq, so they are fine.\n        if (\n            hasattr(module, \"inv_freq\")\n            and hasattr(module, \"base\")\n            and hasattr(module, \"dim\")\n            and hasattr(module, \"_apply_inv_freq_scaling\")\n            and hasattr(module, \"multi_gpu_cos_cached\")\n        ):\n            inv_freq = 1.0 / (\n                module.base\n                ** (\n                    torch.arange(\n                        0, module.dim, 2, dtype = torch.int64, device = \"cpu\"\n                    ).float()\n                    / module.dim\n                )\n            )\n            inv_freq = module._apply_inv_freq_scaling(inv_freq)\n            module.inv_freq = inv_freq\n            for device_idx in range(len(module.multi_gpu_cos_cached)):\n                if module.multi_gpu_cos_cached[device_idx] is not None:\n                    module._set_cos_sin_cache(\n                        seq_len = module.current_rope_size,\n                        device = torch.device(device_idx),\n                        dtype = torch.get_default_dtype(),\n                    )\n\n        # LongRopeRotaryEmbedding (Phi-3.5 style with short_inv_freq + long_inv_freq)\n        elif (\n            hasattr(module, \"short_inv_freq\")\n            and hasattr(module, \"long_inv_freq\")\n            and hasattr(module, \"base\")\n            and hasattr(module, \"dim\")\n        ):\n            config = getattr(model, \"config\", None)\n            rope_scaling = getattr(config, \"rope_scaling\", None) if config else None\n            if rope_scaling is not None:\n                short_factor = rope_scaling.get(\"short_factor\", None)\n                long_factor = rope_scaling.get(\"long_factor\", None)\n                if short_factor is not None and long_factor is not None:\n                    inv_freq_shape = (\n                        torch.arange(\n                            0, module.dim, 2, dtype = torch.int64, device = \"cpu\"\n                        ).float()\n                        / module.dim\n                    )\n                    sf = torch.tensor(short_factor, device = \"cpu\", dtype = torch.float32)\n                    lf = torch.tensor(long_factor, device = \"cpu\", dtype = torch.float32)\n                    module.short_inv_freq = 1.0 / (sf * module.base**inv_freq_shape)\n                    module.long_inv_freq = 1.0 / (lf * module.base**inv_freq_shape)\n\n                    dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16\n                    t = torch.arange(\n                        module.original_max_position_embeddings,\n                        device = module.short_inv_freq.device,\n                        dtype = torch.int64,\n                    ).float()\n                    freqs = torch.outer(t, module.short_inv_freq)\n                    emb = torch.cat((freqs, freqs), dim = -1)\n                    for device_idx in range(len(module.multi_gpu_short_cos_cached)):\n                        if module.multi_gpu_short_cos_cached[device_idx] is not None:\n                            device_obj = torch.device(device_idx)\n                            module.multi_gpu_short_cos_cached[device_idx] = (\n                                emb.cos() * module.scaling_factor\n                            ).to(dtype = dtype, device = device_obj, non_blocking = True)\n                            module.multi_gpu_short_sin_cached[device_idx] = (\n                          
      emb.sin() * module.scaling_factor\n                            ).to(dtype = dtype, device = device_obj, non_blocking = True)\n    return model\n\n\nclass FastLanguageModel(FastLlamaModel):\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/Llama-3.2-1B-Instruct\",\n        max_seq_length = 2048,\n        dtype = None,\n        load_in_4bit = True,  # 4bit QLoRA\n        load_in_8bit = False,  # 8bit  LoRA\n        load_in_16bit = False,  # 16bit LoRA\n        full_finetuning = False,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        trust_remote_code = False,\n        use_gradient_checkpointing = \"unsloth\",\n        resize_model_vocab = None,\n        revision = None,\n        use_exact_model_name = False,\n        offload_embedding = False,\n        float32_mixed_precision = None,  # Forces float32 mixed precision\n        fast_inference = False,  # uses vLLM\n        gpu_memory_utilization = 0.5,\n        float8_kv_cache = False,\n        random_state = 3407,\n        max_lora_rank = 64,\n        disable_log_stats = True,\n        qat_scheme = None,\n        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')\n        unsloth_tiled_mlp = False,\n        *args,\n        **kwargs,\n    ):\n        # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)\n        quantization_config = kwargs.get(\"quantization_config\", None)\n        if quantization_config is not None:\n            if isinstance(quantization_config, dict):\n                q_load_in_4bit = quantization_config.get(\"load_in_4bit\", False)\n                q_load_in_8bit = quantization_config.get(\"load_in_8bit\", False)\n            else:\n                q_load_in_4bit = getattr(quantization_config, \"load_in_4bit\", False)\n                q_load_in_8bit = getattr(quantization_config, \"load_in_8bit\", False)\n            if q_load_in_4bit:\n                load_in_4bit = True\n                load_in_8bit = False\n            if q_load_in_8bit:\n                load_in_8bit = True\n                load_in_4bit = False\n\n        # Login to allow private models\n        token = hf_login(token)\n        # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.\n        if dtype is None and quantization_config is not None:\n            bnb_compute_dtype = None\n            if isinstance(quantization_config, dict):\n                if quantization_config.get(\"load_in_4bit\", False):\n                    bnb_compute_dtype = quantization_config.get(\n                        \"bnb_4bit_compute_dtype\", None\n                    )\n            else:\n                if getattr(quantization_config, \"load_in_4bit\", False):\n                    bnb_compute_dtype = getattr(\n                        quantization_config, \"bnb_4bit_compute_dtype\", None\n                    )\n            if isinstance(bnb_compute_dtype, str):\n                bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)\n            if isinstance(bnb_compute_dtype, torch.dtype):\n                dtype = bnb_compute_dtype\n\n        # Distributed-safe device placement for quantized models.\n        # In multi-GPU (torchrun), each rank must load the model on its own device\n        # to avoid Accelerate device relocation errors with quantized weights.\n        is_quantized = load_in_4bit or load_in_8bit or load_in_fp8\n        if is_quantized and isinstance(device_map, str):\n            
distributed_device_map, is_dist = prepare_device_map()\n            if is_dist:\n                device_map = distributed_device_map\n\n        if load_in_8bit or full_finetuning or qat_scheme is not None:\n            return FastModel.from_pretrained(\n                model_name = model_name,\n                max_seq_length = max_seq_length,\n                dtype = dtype,\n                load_in_4bit = load_in_4bit,\n                load_in_8bit = load_in_8bit,\n                load_in_16bit = load_in_16bit,\n                full_finetuning = full_finetuning,\n                token = token,\n                device_map = device_map,\n                rope_scaling = rope_scaling,  # [TODO] No effect\n                fix_tokenizer = fix_tokenizer,  # [TODO] No effect\n                trust_remote_code = trust_remote_code,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                resize_model_vocab = resize_model_vocab,  # [TODO] No effect\n                revision = revision,\n                return_logits = False,  # Return logits\n                fullgraph = True,  # No graph breaks\n                use_exact_model_name = use_exact_model_name,\n                offload_embedding = offload_embedding,\n                float32_mixed_precision = float32_mixed_precision,\n                # Pass vLLM/inference parameters\n                fast_inference = fast_inference,\n                gpu_memory_utilization = gpu_memory_utilization,\n                float8_kv_cache = float8_kv_cache,\n                random_state = random_state,\n                max_lora_rank = max_lora_rank,\n                disable_log_stats = disable_log_stats,\n                qat_scheme = qat_scheme,\n                load_in_fp8 = load_in_fp8,\n                unsloth_tiled_mlp = unsloth_tiled_mlp,\n                *args,\n                **kwargs,\n            )\n\n        if isinstance(dtype, str) and dtype in [\"float16\", \"bfloat16\"]:\n            dtype = getattr(torch, dtype)\n        assert (\n            dtype is None\n            or dtype == torch.float16\n            or dtype == torch.bfloat16\n            or dtype == torch.float32\n        )\n\n        if fast_inference:\n            if importlib.util.find_spec(\"vllm\") is None:\n                raise ImportError(\n                    \"Unsloth: Please install vLLM before enabling `fast_inference`!\\n\"\n                    \"You can do this in a terminal via `pip install vllm`\"\n                )\n            if DEVICE_TYPE_TORCH == \"cuda\":\n                for i in range(DEVICE_COUNT):\n                    # [TODO] DGX Spark vLLM breaks\n                    if \"NVIDIA GB10\" in str(torch.cuda.get_device_name(i)).upper():\n                        print(\n                            \"Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\\n\"\n                            \"Defaulting to native Unsloth inference.\"\n                        )\n                        fast_inference = False\n                        break\n\n        # Check if 4bit is allowed specifically for AMD\n        if not ALLOW_BITSANDBYTES and not use_exact_model_name:\n            if load_in_4bit or load_in_8bit or model_name.lower().endswith(\"-bnb-4bit\"):\n                print(\n                    \"Unsloth: AMD currently is not stable with 4bit bitsandbytes. 
Disabling for now.\"\n                )\n            load_in_4bit = False\n\n        # Find FP8, BnB 4bit, other mapped names\n        old_model_name = model_name\n        fp8_mode = None\n        if not use_exact_model_name:\n            new_model_name = get_model_name(\n                model_name,\n                load_in_4bit = load_in_4bit,\n                load_in_fp8 = load_in_fp8,\n                token = token,\n                trust_remote_code = trust_remote_code,\n            )\n            if new_model_name is None and load_in_fp8 != False:\n                fp8_mode = _get_fp8_mode_and_check_settings(\n                    load_in_fp8,\n                    fast_inference,\n                    full_finetuning,\n                    load_in_4bit,\n                    load_in_8bit,\n                    load_in_16bit,\n                )\n                model_name = _offline_quantize_to_fp8(model_name, fp8_mode)\n            else:\n                assert new_model_name is not None\n                model_name = new_model_name\n                # If mapper resolved to a pre-quantized FP8 model, disable\n                # on-the-fly quantization to avoid double quantization\n                if load_in_fp8 != False and new_model_name != old_model_name:\n                    load_in_fp8 = False\n\n        # Check if pre-quantized models are allowed\n        # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)\n        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(\n            (\"-unsloth-bnb-4bit\", \"-bnb-4bit\")\n        ):\n            model_name = model_name.lower().removesuffix(\"-unsloth-bnb-4bit\")\n            model_name = model_name.lower().removesuffix(\"-bnb-4bit\")\n        # Change -BF16 to all False for 4bit, 8bit etc\n        if model_name.lower().endswith(\"-bf16\"):\n            load_in_4bit = False\n            load_in_8bit = False\n            load_in_fp8 = False\n            load_in_16bit = True\n\n        if USE_MODELSCOPE and not os.path.exists(model_name):\n            from modelscope import snapshot_download\n\n            model_name = snapshot_download(model_name)\n\n        # First check if it's a normal model via AutoConfig\n        from huggingface_hub.utils import (\n            disable_progress_bars,\n            enable_progress_bars,\n            are_progress_bars_disabled,\n        )\n\n        was_disabled = are_progress_bars_disabled()\n        disable_progress_bars()\n\n        autoconfig_error = None\n        peft_error = None\n        model_config = None\n        peft_config = None\n        try:\n            model_config = AutoConfig.from_pretrained(\n                model_name,\n                token = token,\n                revision = revision,\n                trust_remote_code = trust_remote_code,\n            )\n            is_model = True\n        except ImportError:\n            raise\n        except Exception as error:\n            autoconfig_error = str(error)\n            if \"architecture\" in autoconfig_error:\n                if \"qwen3_5\" in autoconfig_error:\n                    raise ImportError(\n                        f\"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.5.\\n\"\n                        f\"The minimum required version is 5.2.0.\\n\"\n                        f'Try `pip install --upgrade \"transformers>=5.2.0\"`\\n'\n                        f\"to obtain the latest transformers build, then restart this session.\"\n      
              )\n                raise ValueError(\n                    f\"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\\n\"\n                    f\"Please update transformers via `pip install --upgrade transformers` and try again.\"\n                )\n            is_model = False\n        try:\n            peft_config = PeftConfig.from_pretrained(\n                model_name,\n                token = token,\n                revision = revision,\n                trust_remote_code = trust_remote_code,\n            )\n            is_peft = True\n        except ImportError:\n            raise\n        except Exception as error:\n            peft_error = str(error)\n            if \"architecture\" in peft_error:\n                raise ValueError(\n                    f\"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\\n\"\n                    f\"Please update transformers via `pip install --upgrade transformers` and try again.\"\n                )\n            is_peft = False\n\n        # Old transformers versions check\n        both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32\n\n        # Error out if both LoRA and normal model config exists.\n        if both_exist:\n            raise RuntimeError(\n                \"Unsloth: Your repo has a LoRA adapter and a base model.\\n\"\n                \"You have 2 files `config.json` and `adapter_config.json`.\\n\"\n                \"We must only allow one config file.\\n\"\n                \"Please separate the LoRA and base models to 2 repos.\"\n            )\n        model_types = get_transformers_model_type(\n            peft_config if peft_config is not None else model_config,\n            trust_remote_code = trust_remote_code,\n        )\n        if len(model_types) == 1:\n            model_type = model_types[0]\n        else:\n            # Leave as tuple if more than one arch\n            model_type = model_types\n\n        # New transformers need to check manually.\n        if SUPPORTS_LLAMA32:\n            # Check if folder exists locally\n            if os.path.isdir(model_name):\n                exist_adapter_config = os.path.exists(\n                    os.path.join(model_name, \"adapter_config.json\")\n                )\n                exist_config = os.path.exists(os.path.join(model_name, \"config.json\"))\n                both_exist = exist_adapter_config and exist_config\n            else:\n                # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.\n                files = HfFileSystem(token = token).glob(f\"{model_name}/*.json\")\n                files = list(os.path.split(x)[-1] for x in files)\n                if (\n                    sum(x == \"adapter_config.json\" or x == \"config.json\" for x in files)\n                    >= 2\n                ):\n                    both_exist = True\n\n        if not is_model and not is_peft:\n            error = autoconfig_error if autoconfig_error is not None else peft_error\n            # Old transformers version\n            if \"rope_scaling\" in error.lower() and not SUPPORTS_LLAMA31:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\\n\"\n                    f\"This includes Llama 3.1. 
The minimum required version is 4.43.2\\n\"\n                    f'Try `pip install --upgrade \"transformers>=4.43.2\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n            # Create a combined error message showing both failures\n            combined_error = (\n                \"Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\\n\\n\"\n                f\"AutoConfig error: {autoconfig_error}\\n\\n\"\n                f\"PeftConfig error: {peft_error}\\n\\n\"\n            )\n            raise RuntimeError(combined_error)\n\n        # Get base model for PEFT:\n        if is_peft:\n            # Check base model again for PEFT\n            model_name = peft_config.base_model_name_or_path\n            if not use_exact_model_name:\n                model_name = get_model_name(\n                    model_name,\n                    load_in_4bit = load_in_4bit,\n                    load_in_fp8 = load_in_fp8,\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n            # Check if pre-quantized models are allowed\n            # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)\n            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(\n                (\"-unsloth-bnb-4bit\", \"-bnb-4bit\")\n            ):\n                model_name = model_name.lower().removesuffix(\"-unsloth-bnb-4bit\")\n                model_name = model_name.lower().removesuffix(\"-bnb-4bit\")\n            # Change -BF16 to all False for 4bit, 8bit etc\n            if model_name.lower().endswith(\"-bf16\"):\n                load_in_4bit = False\n                load_in_8bit = False\n                load_in_fp8 = False\n                load_in_16bit = True\n\n            model_config = AutoConfig.from_pretrained(\n                model_name,\n                token = token,\n                trust_remote_code = trust_remote_code,\n            )\n\n        if not was_disabled:\n            enable_progress_bars()\n\n        if model_type == \"llama\":\n            scaling_type = None\n            if getattr(model_config, \"rope_scaling\", None) is not None:\n                scaling_type1 = model_config.rope_scaling.get(\"type\", None)\n                scaling_type2 = model_config.rope_scaling.get(\"rope_type\", None)\n                scaling_type = (\n                    scaling_type1 if scaling_type1 is not None else scaling_type2\n                )\n\n            if scaling_type == \"llama3\" and not SUPPORTS_LLAMA31:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\\n\"\n                    f\"The minimum required version is 4.43.2\\n\"\n                    f'Try `pip install --upgrade \"transformers>=4.43.2\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n\n            dispatch_model = FastLlamaModel\n\n        elif model_type == \"mistral\":\n            dispatch_model = FastMistralModel\n        elif model_type == \"gemma\":\n            if not SUPPORTS_GEMMA:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\\n\"\n                    f\"The minimum required version is 4.38.\\n\"\n                    f'Try `pip install --upgrade 
\"transformers>=4.38\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n            dispatch_model = FastGemmaModel\n        elif model_type == \"gemma2\":\n            if not SUPPORTS_GEMMA2:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\\n\"\n                    f\"The minimum required version is 4.42.3.\\n\"\n                    f'Try `pip install --upgrade \"transformers>=4.42.3\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n            # Also check for softcapping support in flash-attn which is faster!\n            if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:\n                print(\n                    \"Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\\n\"\n                    \"To install flash-attn, do the below:\\n\"\n                    '\\npip install --no-deps --upgrade \"flash-attn>=2.6.3\"'\n                )\n            elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:\n                print(\n                    \"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\\n\"\n                    \"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\\n\"\n                    \"To update flash-attn, do the below:\\n\"\n                    '\\npip install --no-deps --upgrade \"flash-attn>=2.6.3\"'\n                )\n\n            dispatch_model = FastGemma2Model\n        elif model_type == \"qwen2\":\n            dispatch_model = FastQwen2Model\n        elif model_type == \"qwen3\":  # or model_type == \"qwen3_moe\":\n            if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\\n\"\n                    f\"The minimum required version is 4.50.3.\\n\"\n                    f'Try `pip install --upgrade \"transformers>=4.50.3\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n            dispatch_model = (\n                FastQwen3Model if model_type == \"qwen3\" else FastQwen3MoeModel\n            )\n        # elif model_type == \"falcon_h1\":\n        #     dispatch_model = FastFalconH1Model\n        #     if not SUPPORTS_FALCON_H1:\n        #         raise ImportError(\n        #             f\"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\\n\"\\\n        #             f\"The minimum required version is 4.50.3.\\n\"\\\n        #             f'Try `pip install --upgrade \"transformers>=4.50.3\"`\\n'\\\n        #             f\"to obtain the latest transformers build, then restart this session.\"\\\n        #         )\n        # Temporary disable optimized Cohere until errors match\n        # elif model_type == \"cohere\":\n        #     dispatch_model = FastCohereModel\n        # Temporary disable optimized Granite until errors match\n        # elif model_type == \"granite\":\n        #     dispatch_model = FastGraniteModel\n        else:\n            return FastModel.from_pretrained(\n                model_name = old_model_name,\n                max_seq_length = max_seq_length,\n                dtype = dtype,\n                load_in_4bit = 
load_in_4bit,\n                load_in_8bit = load_in_8bit,\n                load_in_16bit = load_in_16bit,\n                full_finetuning = full_finetuning,\n                token = token,\n                device_map = device_map,\n                rope_scaling = rope_scaling,  # [TODO] No effect\n                fix_tokenizer = fix_tokenizer,  # [TODO] No effect\n                trust_remote_code = trust_remote_code,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                resize_model_vocab = resize_model_vocab,  # [TODO] No effect\n                revision = revision,\n                return_logits = False,  # Return logits\n                fullgraph = True,  # No graph breaks\n                use_exact_model_name = use_exact_model_name,\n                offload_embedding = offload_embedding,\n                float32_mixed_precision = float32_mixed_precision,\n                # Pass vLLM/inference parameters\n                fast_inference = fast_inference,\n                gpu_memory_utilization = gpu_memory_utilization,\n                float8_kv_cache = float8_kv_cache,\n                random_state = random_state,\n                max_lora_rank = max_lora_rank,\n                disable_log_stats = disable_log_stats,\n                qat_scheme = qat_scheme,\n                load_in_fp8 = load_in_fp8,\n                unsloth_tiled_mlp = unsloth_tiled_mlp,\n                *args,\n                **kwargs,\n            )\n\n        # Apply gradient checkpointing with smart heuristics\n        use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(\n            use_gradient_checkpointing, max_seq_length, dtype\n        )\n\n        # Check if this is local model since the tokenizer gets overwritten\n        if (\n            os.path.exists(os.path.join(old_model_name, \"tokenizer_config.json\"))\n            and os.path.exists(os.path.join(old_model_name, \"tokenizer.json\"))\n            and os.path.exists(os.path.join(old_model_name, \"special_tokens_map.json\"))\n        ):\n            tokenizer_name = old_model_name\n        else:\n            tokenizer_name = kwargs.pop(\"tokenizer_name\", None)\n\n        if fast_inference:\n            fast_inference, model_name = fast_inference_setup(model_name, model_config)\n\n        load_in_4bit_kwargs = load_in_4bit\n        load_in_8bit_kwargs = load_in_8bit\n        if quantization_config is not None and not fast_inference:\n            load_in_4bit_kwargs = False\n            load_in_8bit_kwargs = False\n\n        model, tokenizer = dispatch_model.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = _get_dtype(dtype),\n            load_in_4bit = load_in_4bit_kwargs,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = dispatch_model,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            revision = revision if not is_peft else None,\n            fast_inference = fast_inference,\n            gpu_memory_utilization = gpu_memory_utilization,\n            float8_kv_cache = float8_kv_cache,\n            random_state = random_state,\n            max_lora_rank = max_lora_rank,\n            disable_log_stats = disable_log_stats,\n            load_in_fp8 = load_in_fp8,\n            *args,\n            **kwargs,\n        )\n\n        if 
resize_model_vocab is not None:\n            model.resize_token_embeddings(resize_model_vocab)\n\n        # In case the model supports tagging, add the unsloth tag.\n        if hasattr(model, \"add_model_tags\"):\n            model.add_model_tags(\n                [\n                    \"unsloth\",\n                ]\n            )\n        if hasattr(tokenizer, \"add_model_tags\"):\n            tokenizer.add_model_tags(\n                [\n                    \"unsloth\",\n                ]\n            )\n\n        if load_in_4bit:\n            # Fix up bitsandbytes config, but respect user-provided quantization_config\n            if quantization_config is None:\n                compute_dtype = dtype_from_config(model.config)\n                quantization_config = {\n                    # Sometimes compute_dtype is not a string!!\n                    \"bnb_4bit_compute_dtype\": compute_dtype,\n                    \"bnb_4bit_quant_type\": \"nf4\",\n                    \"bnb_4bit_use_double_quant\": True,\n                    \"llm_int8_enable_fp32_cpu_offload\": False,\n                    \"llm_int8_has_fp16_weight\": False,\n                    \"llm_int8_skip_modules\": None,\n                    \"llm_int8_threshold\": 6.0,\n                    \"load_in_4bit\": True,\n                    \"load_in_8bit\": False,\n                    \"quant_method\": \"bitsandbytes\",\n                }\n                model.config.update({\"quantization_config\": quantization_config})\n            else:\n                if hasattr(quantization_config, \"to_dict\"):\n                    model.config.update(\n                        {\"quantization_config\": quantization_config.to_dict()}\n                    )\n                elif isinstance(quantization_config, dict):\n                    model.config.update({\"quantization_config\": quantization_config})\n\n        if load_in_fp8 != False:\n            _tag_model_with_fp8_torchao_config(model, fp8_mode)\n\n        if is_peft:\n            # From https://github.com/huggingface/peft/issues/184\n            # Now add PEFT adapters\n            model = PeftModel.from_pretrained(\n                model,\n                old_model_name,\n                token = token,\n                revision = revision,\n                is_trainable = True,\n                trust_remote_code = trust_remote_code,\n            )\n            # Patch it as well!\n            model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)\n\n        # Patch Tiled MLP\n        # to turn on set UNSLOTH_TILED_MLP to \"arctic\", \"target\", or \"target:{GB}\"\"\n        patch_tiled_mlp_choice = os.environ.get(\n            \"UNSLOTH_TILED_MLP\", \"arctic\" if unsloth_tiled_mlp else \"0\"\n        )\n        if patch_tiled_mlp_choice != \"0\" or unsloth_tiled_mlp:\n            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)\n\n        model = _fix_rope_inv_freq(model)\n        return model, tokenizer\n\n\nfrom ..kernels import (\n    patch_loss_functions,\n    post_patch_loss_function,\n)\nfrom .vision import FastBaseModel\nfrom transformers import (\n    AutoModelForCausalLM,\n)\n\ntry:\n    from transformers import AutoModelForImageTextToText\n\n    AutoModelForVision2Seq = AutoModelForImageTextToText\nexcept:\n    from transformers import AutoModelForVision2Seq\n\n\nclass FastModel(FastBaseModel):\n    @staticmethod\n    def _prepare_for_qat(model, qat_scheme):\n        model = _prepare_model_for_qat(model, qat_scheme)\n        return 
model\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n        max_seq_length = 2048,\n        dtype = None,\n        load_in_4bit = True,  # 4bit QLoRA\n        load_in_8bit = False,  # 8bit  LoRA\n        load_in_16bit = False,  # 16bit LoRA\n        full_finetuning = False,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,  # [TODO] No effect\n        fix_tokenizer = True,  # [TODO] No effect\n        trust_remote_code = False,\n        use_gradient_checkpointing = \"unsloth\",\n        resize_model_vocab = None,  # [TODO] No effect\n        revision = None,\n        return_logits = False,  # Return logits\n        fullgraph = True,  # No graph breaks\n        use_exact_model_name = False,\n        auto_model = None,\n        whisper_language = None,\n        whisper_task = None,\n        unsloth_force_compile = False,\n        offload_embedding = False,\n        float32_mixed_precision = None,  # Forces float32 mixed precision\n        # Add the missing vLLM/inference parameters\n        fast_inference = False,  # uses vLLM\n        gpu_memory_utilization = 0.5,\n        float8_kv_cache = False,\n        random_state = 3407,\n        max_lora_rank = 64,\n        disable_log_stats = True,\n        qat_scheme = None,\n        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')\n        unsloth_tiled_mlp = False,\n        target_parameters = None,  # For MoE expert parameters\n        *args,\n        **kwargs,\n    ):\n        # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)\n        quantization_config = kwargs.get(\"quantization_config\", None)\n        if quantization_config is not None:\n            if isinstance(quantization_config, dict):\n                q_load_in_4bit = quantization_config.get(\"load_in_4bit\", False)\n                q_load_in_8bit = quantization_config.get(\"load_in_8bit\", False)\n            else:\n                q_load_in_4bit = getattr(quantization_config, \"load_in_4bit\", False)\n                q_load_in_8bit = getattr(quantization_config, \"load_in_8bit\", False)\n            if q_load_in_4bit:\n                load_in_4bit = True\n                load_in_8bit = False\n            if q_load_in_8bit:\n                load_in_8bit = True\n                load_in_4bit = False\n\n        # Login to allow private models\n        token = hf_login(token)\n        if whisper_language is not None:\n            assert type(whisper_language) is str\n        if whisper_task is not None:\n            assert type(whisper_task) is str\n        # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.\n        if dtype is None and quantization_config is not None:\n            bnb_compute_dtype = None\n            if isinstance(quantization_config, dict):\n                if quantization_config.get(\"load_in_4bit\", False):\n                    bnb_compute_dtype = quantization_config.get(\n                        \"bnb_4bit_compute_dtype\", None\n                    )\n            else:\n                if getattr(quantization_config, \"load_in_4bit\", False):\n                    bnb_compute_dtype = getattr(\n                        quantization_config, \"bnb_4bit_compute_dtype\", None\n                    )\n            if isinstance(bnb_compute_dtype, str):\n                bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)\n            if isinstance(bnb_compute_dtype, torch.dtype):\n                dtype 
= bnb_compute_dtype\n        SUPPORTS_BFLOAT16 = is_bfloat16_supported()\n        if dtype is None:\n            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16\n        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:\n            logger.warning_once(\n                \"Device does not support bfloat16. Will change to float16.\"\n            )\n            dtype = torch.float16\n        assert dtype in (torch.float16, torch.bfloat16, torch.float32)\n        assert load_in_fp8 in (True, False, \"block\")\n\n        patch_compiled_autograd()\n        patch_compiling_bitsandbytes()\n\n        if full_finetuning and (load_in_4bit or load_in_8bit):\n            print(\n                \"Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.\"\n            )\n            load_in_4bit = False\n            load_in_8bit = False\n            load_in_fp8 = False\n            load_in_16bit = False\n\n        if (\n            int(load_in_4bit)\n            + int(load_in_8bit)\n            + int(load_in_16bit)\n            + int(load_in_fp8 != False)\n            >= 2\n        ):\n            raise RuntimeError(\n                \"Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\\n\"\n                \"Also, we by default set `load_in_4bit = True`.\\n\"\n                \"If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`\\n\"\n                \"If you want 16bit LoRA finetuning, set `load_in_16bit = True`\"\n            )\n\n        if qat_scheme is not None and not full_finetuning:\n            raise ValueError(\n                \"Specifying `qat_scheme` in `FastLanguageModel.from_pretrained(...)` is only \"\n                \"compatible with `full_finetuning=True`. If you wish to use QAT with LoRA, \"\n                \"please pass in `qat_scheme` in `FastLanguageModel.get_peft_model(...)` instead.\"\n            )\n        if qat_scheme == \"phone-deployment\":\n            qat_scheme = \"int8-int4\"\n\n        # Distributed-safe device placement for quantized models.\n        # In multi-GPU (torchrun), each rank must load the model on its own device\n        # to avoid Accelerate device relocation errors with quantized weights.\n        is_quantized = load_in_4bit or load_in_8bit or load_in_fp8\n        if is_quantized and isinstance(device_map, str):\n            distributed_device_map, is_dist = prepare_device_map()\n            if is_dist:\n                device_map = distributed_device_map\n\n        # Check if 4bit is allowed specifically for AMD\n        if not ALLOW_BITSANDBYTES and not use_exact_model_name:\n            if load_in_4bit or load_in_8bit or model_name.lower().endswith(\"-bnb-4bit\"):\n                print(\n                    \"Unsloth: AMD currently is not stable with 4bit bitsandbytes. 
Disabling for now.\"\n                )\n            load_in_4bit = False\n\n        if fast_inference:\n            if importlib.util.find_spec(\"vllm\") is None:\n                raise ImportError(\n                    \"Unsloth: Please install vLLM before enabling `fast_inference`!\\n\"\n                    \"You can do this in a terminal via `pip install vllm`\"\n                )\n            if DEVICE_TYPE_TORCH == \"cuda\":\n                for i in range(DEVICE_COUNT):\n                    # [TODO] DGX Spark vLLM breaks\n                    if \"NVIDIA GB10\" in str(torch.cuda.get_device_name(i)).upper():\n                        print(\n                            \"Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\\n\"\n                            \"Defaulting to native Unsloth inference.\"\n                        )\n                        fast_inference = False\n                        break\n\n        # Find FP8, BnB 4bit, other mapped names\n        old_model_name = model_name\n        fp8_mode = None\n        if not use_exact_model_name:\n            new_model_name = get_model_name(\n                model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8\n            )\n            if new_model_name is None and load_in_fp8 != False:\n                fp8_mode = _get_fp8_mode_and_check_settings(\n                    load_in_fp8,\n                    fast_inference,\n                    full_finetuning,\n                    load_in_4bit,\n                    load_in_8bit,\n                    load_in_16bit,\n                )\n                model_name = _offline_quantize_to_fp8(model_name, fp8_mode)\n            else:\n                assert new_model_name is not None\n                model_name = new_model_name\n                # If mapper resolved to a pre-quantized FP8 model, disable\n                # on-the-fly quantization to avoid double quantization\n                if load_in_fp8 != False and new_model_name != old_model_name:\n                    load_in_fp8 = False\n\n        # Check if pre-quantized models are allowed\n        # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)\n        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(\n            (\"-unsloth-bnb-4bit\", \"-bnb-4bit\")\n        ):\n            model_name = model_name.lower().removesuffix(\"-unsloth-bnb-4bit\")\n            model_name = model_name.lower().removesuffix(\"-bnb-4bit\")\n        # Change -BF16 to all False for 4bit, 8bit etc\n        if model_name.lower().endswith(\"-bf16\"):\n            load_in_4bit = False\n            load_in_8bit = False\n            load_in_fp8 = False\n            load_in_16bit = True\n\n        # Check modelscope\n        if USE_MODELSCOPE and not os.path.exists(model_name):\n            from modelscope import snapshot_download\n\n            model_name = snapshot_download(model_name)\n\n        # First check if it's a normal model via AutoConfig\n        from huggingface_hub.utils import (\n            disable_progress_bars,\n            enable_progress_bars,\n            are_progress_bars_disabled,\n        )\n\n        was_disabled = are_progress_bars_disabled()\n        disable_progress_bars()\n\n        autoconfig_error = None\n        peft_error = None\n        model_config = None\n        peft_config = None\n        try:\n            model_config = AutoConfig.from_pretrained(\n                model_name,\n                token = 
token,\n                revision = revision,\n                trust_remote_code = trust_remote_code,\n            )\n            is_model = True\n        except ImportError:\n            raise\n        except Exception as error:\n            autoconfig_error = str(error)\n            if \"architecture\" in autoconfig_error:\n                if \"qwen3_5\" in autoconfig_error:\n                    raise ImportError(\n                        f\"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.5.\\n\"\n                        f\"The minimum required version is 5.2.0.\\n\"\n                        f'Try `pip install --upgrade \"transformers>=5.2.0\"`\\n'\n                        f\"to obtain the latest transformers build, then restart this session.\"\n                    )\n                raise ValueError(\n                    f\"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\\n\"\n                    f\"Please update transformers via `pip install --upgrade transformers` and try again.\"\n                )\n            is_model = False\n        try:\n            peft_config = PeftConfig.from_pretrained(\n                model_name,\n                token = token,\n                revision = revision,\n                trust_remote_code = trust_remote_code,\n            )\n            is_peft = True\n        except ImportError:\n            raise\n        except Exception as error:\n            peft_error = str(error)\n            if \"architecture\" in peft_error:\n                raise ValueError(\n                    f\"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\\n\"\n                    f\"Please update transformers via `pip install --upgrade transformers` and try again.\"\n                )\n            is_peft = False\n        # Old transformers versions check\n        both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32\n        # Error out if both LoRA and normal model config exists.\n        if both_exist:\n            raise RuntimeError(\n                \"Unsloth: Your repo has a LoRA adapter and a base model.\\n\"\n                \"You have 2 files `config.json` and `adapter_config.json`.\\n\"\n                \"We must only allow one config file.\\n\"\n                \"Please separate the LoRA and base models to 2 repos.\"\n            )\n        model_types = get_transformers_model_type(\n            peft_config if peft_config is not None else model_config,\n            trust_remote_code = trust_remote_code,\n        )\n        model_types_all = \",\".join(model_types) + \",\"\n\n        # Save model types and loading method\n        lowered_model_name = model_name.lower()\n        string = os.environ.get(\"UNSLOTH_MODEL_NAME\", \"\") + model_types_all\n        if load_in_4bit:\n            string += \"_load_in_4bit_\"\n        if load_in_8bit:\n            string += \"_load_in_8bit_\"\n        if load_in_16bit:\n            string += \"_load_in_16bit_\"\n        if load_in_fp8:\n            string += \"load_in_fp8\"\n        os.environ[\"UNSLOTH_MODEL_NAME\"] = string\n\n        # Check versions\n        LATEST = \"\\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`\"\n        NIGHTLY = '\\nPlease use nightly transformers via pip install --upgrade \"transformers>=4.49.0\"`'\n        # Pixtral\n        if \"pixtral\" in model_types_all and transformers_version < Version(\"4.49.0\"):\n            raise 
RuntimeError(\n                \"Unsloth: Pixtral only works on transformers >= 4.49.0.\" + LATEST\n            )\n        # Qwen 2.5\n        elif \"qwen2_5\" in model_types_all and transformers_version < Version(\"4.49.0\"):\n            raise RuntimeError(\n                \"Unsloth: Qwen 2.5 only works on transformers >= 4.49.0.\" + LATEST\n            )\n        # Gemma 3N must be before Gemma 3\n        elif \"gemma3n\" in model_types_all:\n            if transformers_version < Version(\"4.53.0\"):\n                raise RuntimeError(\n                    \"Unsloth: Gemma 3N only works on transformers >= 4.53.0\" + LATEST\n                )\n            os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"\n            os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                \"float16;torch.float16;torch.float16;\"\n                \"if name.endswith('norm'): \"\n                \"module._pre_set_compute_dtype = torch.float32\\n\"\n                \";\"\n                \"from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConv_Embed_forwards; patch_Gemma3nConv_Embed_forwards()\"\n            )\n            # Set norms to float32 since anyways they get upcasted to float32\n            # common in both gemma-3 and gemma-3n\n            os.environ[\"UNSLOTH_HIGH_PRECISION_LAYERNORM\"] = \"1\"\n        # Gemma 3\n        elif \"gemma3\" in model_types_all:\n            if transformers_version < Version(\"4.50.0.dev0\"):\n                raise RuntimeError(\n                    \"Unsloth: Gemma 3 only works on transformers >= 4.50.0.\" + NIGHTLY\n                )\n            # Set norms to float32 since anyways they get upcasted to float32\n            # common in both gemma-3 and gemma-3n\n            os.environ[\"UNSLOTH_HIGH_PRECISION_LAYERNORM\"] = \"1\"\n            # ROCm/HIP: Gemma3 compiled forward produces NaN on RDNA GPUs\n            # (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.).\n            # Disable torch.compile for model forward; loss compilation is fine.\n            # See https://github.com/unslothai/unsloth/issues/3385\n            from unsloth.kernels.utils import is_rdna\n\n            if is_rdna():\n                os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"partial\"\n        # Cohere\n        elif \"cohere2\" in model_types_all and transformers_version < Version(\n            \"4.50.0.dev0\"\n        ):\n            raise RuntimeError(\n                \"Unsloth: Cohere's Command model only works on transformers >= 4.50.0.\"\n                + NIGHTLY\n            )\n        # Sesame\n        elif \"csm\" in model_types_all:\n            os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"partial\"  # Inference is too slow\n            os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"  # Sesame fails\n            os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                \"all;torch.float32;torch.float16;\"\n                \"if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)\"\n                \";\"\n            )\n        # Granite 4\n        elif \"granitemoehybrid\" in model_types_all:\n            # Granite-4 rms norms are stored as 16 bit, but we upcast\n            os.environ[\"UNSLOTH_HIGH_PRECISION_LAYERNORM\"] = \"1\"\n            os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"\n        # Olmo 2\n        elif \"olmo2\" in model_types_all and transformers_version < Version(\n            \"4.50.0.dev0\"\n        ):\n            raise RuntimeError(\n                
\"Unsloth: OLMo-2 only works on transformers >= 4.50.0.\" + NIGHTLY\n            )\n        elif \"falcon_h1\" in model_types_all:\n            # Falcon must use float32 Triton ie TRITON_F32_DEFAULT = 'ieee'\n            # since Mamba kernels error out on using lower precision\n            os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                \"float16;torch.float32;torch.float16;\"\n                \"if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)\"\n                \";\"\n                \"os.environ['TRITON_F32_DEFAULT'] = 'ieee'\"\n            )\n        elif \"nemotron_h\" in model_types_all:\n            # NemotronH (hybrid Mamba-2 + Transformer) uses same Mamba kernels as Falcon-H1\n            # Mamba kernels need float32 Triton precision\n            os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                \"float16;torch.float32;torch.float16;\"\n                \"if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)\"\n                \";\"\n                \"os.environ['TRITON_F32_DEFAULT'] = 'ieee'\"\n            )\n        elif \"gpt_oss\" in model_types_all:\n            os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"\n            if not load_in_4bit:\n                # Only upcast MoE biases for MXFP4, not BnB\n                # Set norms to float32 since anyways they get upcasted to float32\n                os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                    \"all;None;None;\"\n                    \"x = 'gate_up_proj_bias'\\n\"\n                    \"if hasattr(module, x): \"\n                    \"setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\\n\"\n                    \"\"\n                    \"x = 'down_proj_bias'\\n\"\n                    \"if hasattr(module, x): \"\n                    \"setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\\n\"\n                    \"\"\n                    \";\"\n                )\n            else:\n                # Set down projection compute dtype to be float32 for float16 machines\n                # Set norms to float32 since anyways they get upcasted to float32\n                os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"] = (\n                    \"torch.float16;torch.bfloat16;torch.float16;\"\n                    \"if ('down_projs' in name) and hasattr(module, 'weight') and \"\n                    \"torch.amax(dequantize_module_weight(module)) >= 0:\"\n                    \"module._pre_set_compute_dtype = torch.float32\\n\"\n                    \"\"\n                    \"if ('mlp.router' in name) and hasattr(module, 'weight'):\"\n                    \"module._pre_set_compute_dtype = torch.float32\\n\"\n                    \";\"\n                )\n            # Set norms to float32 since anyways they get upcasted to float32\n            os.environ[\"UNSLOTH_HIGH_PRECISION_LAYERNORM\"] = \"1\"\n        else:\n            for check_model_name in DISABLE_COMPILE_MODEL_NAMES:\n                if check_model_name in lowered_model_name:\n                    os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"partial\"\n                    os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"\n      
              if transformers_version < Version(\"4.50.0.dev0\"):\n                        raise RuntimeError(\n                            f\"Unsloth: {check_model_name} only works on transformers >= 4.50.0.\"\n                            + NIGHTLY\n                        )\n                    break\n\n        if auto_model is not None:\n            # All other models need to disable static cache\n            os.environ[\"UNSLOTH_DISABLE_STATIC_GENERATION\"] = \"1\"\n\n        # New transformers need to check manually.\n        if SUPPORTS_LLAMA32:\n            # Check if folder exists locally\n            if os.path.isdir(model_name):\n                exist_adapter_config = os.path.exists(\n                    os.path.join(model_name, \"adapter_config.json\")\n                )\n                exist_config = os.path.exists(os.path.join(model_name, \"config.json\"))\n                both_exist = exist_adapter_config and exist_config\n            else:\n                files = HfFileSystem(token = token).glob(f\"{model_name}/*.json\")\n                files = list(os.path.split(x)[-1] for x in files)\n                if (\n                    sum(x == \"adapter_config.json\" or x == \"config.json\" for x in files)\n                    >= 2\n                ):\n                    both_exist = True\n\n        if not is_model and not is_peft:\n            error = autoconfig_error if autoconfig_error is not None else peft_error\n            # Old transformers version\n            if \"rope_scaling\" in error.lower() and not SUPPORTS_LLAMA31:\n                raise ImportError(\n                    f\"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\\n\"\n                    f\"This includes Llama 3.1. The minimum required version is 4.43.2\\n\"\n                    f'Try `pip install --upgrade \"transformers>=4.43.2\"`\\n'\n                    f\"to obtain the latest transformers build, then restart this session.\"\n                )\n            # Create a combined error message showing both failures\n            combined_error = (\n                \"Unsloth: Failed to load model. 
Both AutoConfig and PeftConfig loading failed.\\n\\n\"\n                f\"AutoConfig error: {autoconfig_error}\\n\\n\"\n                f\"PeftConfig error: {peft_error}\\n\\n\"\n            )\n            raise RuntimeError(combined_error)\n\n        # Get base model for PEFT:\n        if is_peft:\n            # Check base model again for PEFT\n            model_name = peft_config.base_model_name_or_path\n            if not use_exact_model_name:\n                model_name = get_model_name(model_name, load_in_4bit)\n            # Check if pre-quantized models are allowed\n            # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)\n            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(\n                (\"-unsloth-bnb-4bit\", \"-bnb-4bit\")\n            ):\n                model_name = model_name.lower().removesuffix(\"-unsloth-bnb-4bit\")\n                model_name = model_name.lower().removesuffix(\"-bnb-4bit\")\n            # Change -BF16 to all False for 4bit, 8bit etc\n            if model_name.lower().endswith(\"-bf16\"):\n                load_in_4bit = False\n                load_in_8bit = False\n                load_in_fp8 = False\n                load_in_16bit = True\n\n            model_config = AutoConfig.from_pretrained(\n                model_name,\n                token = token,\n                trust_remote_code = trust_remote_code,\n            )\n\n        if not was_disabled:\n            enable_progress_bars()\n\n        do_logging = os.environ.get(\"UNSLOTH_ENABLE_LOGGING\", \"0\") == \"1\"\n        if do_logging:\n            redirector = contextlib.nullcontext()\n        else:\n            redirector = contextlib.redirect_stdout(open(os.devnull, \"w\"))\n\n        model_types = [\"siglip\"] + model_types\n        # Set forced float32 env flag\n        os.environ[\"UNSLOTH_FORCE_FLOAT32\"] = \"0\"\n        do_forced_float32 = False\n        for model_type_arch in model_types:\n            if model_type_arch != \"siglip\":\n                break\n        global FORCE_FLOAT32\n        for disable_name in FORCE_FLOAT32:\n            # add comma to model_types_all matching in case of exact match for end\n            if (\n                disable_name.lower()\n                == model_type_arch.lower().replace(\"-\", \"\").replace(\"_\", \"\")\n                or disable_name.lower() in model_types_all\n            ) and ((dtype == torch.float16) or not SUPPORTS_BFLOAT16):\n                os.environ[\"UNSLOTH_FORCE_FLOAT32\"] = \"1\"\n                dtype = torch.bfloat16  # Change to bfloat16 loading\n                break\n        # Apply gradient checkpointing with smart heuristics\n        use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(\n            use_gradient_checkpointing, max_seq_length, dtype\n        )\n        with redirector:\n            patch_loss_functions(torch_compile = False)\n            model_types, supports_sdpa = unsloth_compile_transformers(\n                dtype = dtype,\n                model_name = model_name,\n                model_types = model_types,\n                token = token,\n                sdpa_dynamic_mask = True,\n                sdpa_bool_masks = True,\n                sdpa_gqa_replace = True,\n                sdpa_dynamic_compile = True,\n                compile_attention = True,\n                disable_causal_masks = True,\n                compile_torch_modules = True,\n                compile_custom_modules = True,\n             
   compile_function_calls = True,\n                fuse_lm_head = True,\n                gradient_checkpointing = True,\n                manual_replacements = True,\n                fast_lora_forwards = True,\n                fast_residual_stream = False,\n                accurate_accumulation = True,\n                epilogue_fusion = True,\n                max_autotune = False,\n                shape_padding = True,\n                cudagraphs = False,\n                debug = False,\n                fullgraph = fullgraph,\n                import_from_cache = False,\n                disable = False,\n                return_logits = return_logits,\n                trust_remote_code = trust_remote_code,\n                unsloth_force_compile = unsloth_force_compile,\n            )\n        # Fix SDPA issues\n        for model_type in DISABLE_SDPA_MODEL_NAMES:\n            if model_type in model_types_all:\n                supports_sdpa = False\n\n        # Check if this is local model since the tokenizer gets overwritten\n        if (\n            os.path.exists(os.path.join(old_model_name, \"tokenizer_config.json\"))\n            and os.path.exists(os.path.join(old_model_name, \"tokenizer.json\"))\n            and os.path.exists(os.path.join(old_model_name, \"special_tokens_map.json\"))\n        ):\n            tokenizer_name = old_model_name\n        else:\n            tokenizer_name = kwargs.pop(\"tokenizer_name\", None)\n\n        # Check if VLM\n        architectures = getattr(model_config, \"architectures\", None)\n        if architectures is None:\n            architectures = []\n        is_vlm = any(x.endswith(\"ForConditionalGeneration\") for x in architectures)\n        is_vlm = is_vlm or hasattr(model_config, \"vision_config\")\n        if auto_model is None:\n            if is_vlm:\n                # Check if the model's auto_map supports the VLM auto class.\n                # Some VL models (e.g. 
Nemotron-VL) only register AutoModelForCausalLM\n                # in their auto_map, not AutoModelForImageTextToText/AutoModelForVision2Seq.\n                _auto_map = getattr(model_config, \"auto_map\", {}) or {}\n                _vlm_class_name = AutoModelForVision2Seq.__name__\n                if (\n                    \"AutoModelForCausalLM\" in _auto_map\n                    and _vlm_class_name not in _auto_map\n                ):\n                    auto_model = AutoModelForCausalLM\n                else:\n                    auto_model = AutoModelForVision2Seq\n            else:\n                auto_model = AutoModelForCausalLM\n\n        load_in_4bit_kwargs = load_in_4bit\n        load_in_8bit_kwargs = load_in_8bit\n        if quantization_config is not None and not fast_inference:\n            load_in_4bit_kwargs = False\n            load_in_8bit_kwargs = False\n\n        model, tokenizer = FastBaseModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = _get_dtype(dtype),\n            load_in_4bit = load_in_4bit_kwargs,\n            load_in_8bit = load_in_8bit_kwargs,\n            load_in_16bit = load_in_16bit,\n            full_finetuning = full_finetuning,\n            token = token,\n            device_map = device_map,\n            trust_remote_code = trust_remote_code,\n            revision = revision if not is_peft else None,\n            model_types = model_types,\n            tokenizer_name = tokenizer_name,\n            auto_model = auto_model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n            supports_sdpa = supports_sdpa,\n            whisper_language = whisper_language,\n            whisper_task = whisper_task,\n            auto_config = model_config,\n            offload_embedding = offload_embedding,\n            float32_mixed_precision = float32_mixed_precision,\n            # Pass vLLM/inference parameters\n            fast_inference = fast_inference,\n            gpu_memory_utilization = gpu_memory_utilization,\n            float8_kv_cache = float8_kv_cache,\n            random_state = random_state,\n            max_lora_rank = max_lora_rank,\n            disable_log_stats = disable_log_stats,\n            load_in_fp8 = load_in_fp8,\n            *args,\n            **kwargs,\n        )\n\n        if resize_model_vocab is not None:\n            model.resize_token_embeddings(resize_model_vocab)\n\n        # In case the model supports tagging, add the unsloth tag.\n        if hasattr(model, \"add_model_tags\"):\n            model.add_model_tags(\n                [\n                    \"unsloth\",\n                ]\n            )\n        if hasattr(tokenizer, \"add_model_tags\"):\n            tokenizer.add_model_tags(\n                [\n                    \"unsloth\",\n                ]\n            )\n\n        if load_in_4bit:\n            # Fix up bitsandbytes config, but respect user-provided quantization_config\n            if quantization_config is None:\n                compute_dtype = dtype_from_config(model.config)\n                quantization_config = {\n                    # Sometimes compute_dtype is not a string!!\n                    \"bnb_4bit_compute_dtype\": compute_dtype,\n                    \"bnb_4bit_quant_type\": \"nf4\",\n                    \"bnb_4bit_use_double_quant\": True,\n                    \"llm_int8_enable_fp32_cpu_offload\": False,\n                    \"llm_int8_has_fp16_weight\": False,\n                    
\"llm_int8_skip_modules\": None,\n                    \"llm_int8_threshold\": 6.0,\n                    \"load_in_4bit\": True,\n                    \"load_in_8bit\": False,\n                    \"quant_method\": \"bitsandbytes\",\n                }\n                model.config.update({\"quantization_config\": quantization_config})\n            else:\n                if hasattr(quantization_config, \"to_dict\"):\n                    model.config.update(\n                        {\"quantization_config\": quantization_config.to_dict()}\n                    )\n                elif isinstance(quantization_config, dict):\n                    model.config.update({\"quantization_config\": quantization_config})\n\n        if load_in_fp8 != False:\n            _tag_model_with_fp8_torchao_config(model, fp8_mode)\n\n        if is_peft:\n            # From https://github.com/huggingface/peft/issues/184\n            # Now add PEFT adapters\n            model = PeftModel.from_pretrained(\n                model,\n                old_model_name,\n                token = token,\n                revision = revision,\n                is_trainable = True,\n                trust_remote_code = trust_remote_code,\n            )\n            # Patch it as well!\n            model = FastBaseModel.post_patch_model(\n                model, use_gradient_checkpointing, trust_remote_code = trust_remote_code\n            )\n\n        # Apply QAT if specified\n        if qat_scheme is not None:\n            print(\"Unsloth: Applying QAT to mitigate quantization degradation\")\n            model = FastModel._prepare_for_qat(model, qat_scheme)\n\n        # Patch Tiled MLP\n        # to turn on set UNSLOTH_TILED_MLP to \"arctic\", \"target\", or \"target:{GB}\"\"\n        patch_tiled_mlp_choice = os.environ.get(\n            \"UNSLOTH_TILED_MLP\", \"arctic\" if unsloth_tiled_mlp else \"0\"\n        )\n        if patch_tiled_mlp_choice != \"0\" or unsloth_tiled_mlp:\n            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)\n\n        model = _fix_rope_inv_freq(model)\n        return model, tokenizer\n\n\nclass FastVisionModel(FastModel):\n    pass\n\n\nclass FastTextModel(FastModel):\n    pass\n"
  },
  {
    "path": "unsloth/models/loader_utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom ..device_type import DEVICE_TYPE_TORCH\nimport importlib\nimport os\nimport torch\nimport re\nimport tempfile\nfrom typing import Union\nfrom .mapper import (\n    INT_TO_FLOAT_MAPPER,\n    FLOAT_TO_INT_MAPPER,\n    MAP_TO_UNSLOTH_16bit,\n    FLOAT_TO_FP8_BLOCK_MAPPER,\n    FLOAT_TO_FP8_ROW_MAPPER,\n)\n\n# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!\nfrom transformers import __version__ as transformers_version\nfrom unsloth.models._utils import TorchAOConfig\nfrom unsloth_zoo.utils import Version\nimport gc\n\ntransformers_version = Version(transformers_version)\nSUPPORTS_FOURBIT = transformers_version >= Version(\"4.37\")\n\nLOCAL_RANK_KEYS = (\"LOCAL_RANK\", \"RANK\")\nWORLD_SIZE_KEYS = (\"WORLD_SIZE\",)\n\nBAD_MAPPINGS = {\n    \"unsloth/Qwen3-32B-unsloth-bnb-4bit\".lower(): \"unsloth/Qwen3-32B-bnb-4bit\".lower(),  # 32B dynamic quant is way too big\n    \"unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit\".lower(): \"unsloth/Qwen3-30B-A3B\".lower(),  # HF loads MoEs too slowly\n    \"unsloth/Qwen3-30B-A3B-bnb-4bit\".lower(): \"unsloth/Qwen3-30B-A3B\".lower(),  # We rather do it on the fly\n    \"unsloth/Qwen3-30B-A3B-Base-unsloth-bnb-4bit\".lower(): \"unsloth/Qwen3-30B-A3B-Base\".lower(),  # HF loads MoEs too slowly\n    \"unsloth/Qwen3-30B-A3B-Base-bnb-4bit\".lower(): \"unsloth/Qwen3-30B-A3B-Base\".lower(),  # We rather do it on the fly\n}\n\n\ndef _get_torchao_fp8_config(fp8_mode):\n    # Import lazily so an optional, broken vLLM install does not break plain `import unsloth`.\n    from unsloth_zoo.vllm_utils import _get_torchao_fp8_config as _impl\n\n    return _impl(fp8_mode)\n\n\ndef _get_env_int(keys):\n    for key in keys:\n        value = os.environ.get(key)\n        if value is None:\n            continue\n        try:\n            return int(value)\n        except ValueError:\n            continue\n    return None\n\n\ndef _infer_distributed_ranks():\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        try:\n            return torch.distributed.get_rank(), torch.distributed.get_world_size()\n        except Exception:\n            pass\n    return _get_env_int(LOCAL_RANK_KEYS), _get_env_int(WORLD_SIZE_KEYS)\n\n\ndef is_distributed():\n    rank, world_size = _infer_distributed_ranks()\n    return (world_size or 1) > 1 or (rank is not None and rank > 0)\n\n\ndef prepare_device_map():\n    rank, world_size = _infer_distributed_ranks()\n    distributed = (world_size or 1) > 1 or (rank is not None and rank > 0)\n    if not distributed:\n        return None, False\n\n    local_rank = 0 if rank is None else rank\n    device_map = {\"\": f\"{DEVICE_TYPE_TORCH}:{local_rank}\"}\n    try:\n        if DEVICE_TYPE_TORCH == \"cuda\":\n            torch.cuda.set_device(local_rank)\n        elif DEVICE_TYPE_TORCH == \"xpu\" and hasattr(torch, \"xpu\"):\n            
torch.xpu.set_device(local_rank)\n    except Exception:\n        pass\n    return device_map, True\n\n\ndef __get_model_name(\n    model_name,\n    load_in_4bit = True,\n    INT_TO_FLOAT_MAPPER = None,\n    FLOAT_TO_INT_MAPPER = None,\n    MAP_TO_UNSLOTH_16bit = None,\n    load_in_fp8 = False,\n    FLOAT_TO_FP8_BLOCK_MAPPER = None,\n    FLOAT_TO_FP8_ROW_MAPPER = None,\n):\n    model_name = str(model_name)\n    lower_model_name = model_name.lower()\n\n    assert load_in_fp8 in (True, False, \"block\")\n    if load_in_fp8 != False:\n        if load_in_fp8 == True and (os.environ.get(\"UNSLOTH_HAS_FBGEMM\", \"0\") == \"1\"):\n            if lower_model_name in FLOAT_TO_FP8_ROW_MAPPER:\n                # Faster row scaling only works if FBGEMM works!\n                return FLOAT_TO_FP8_ROW_MAPPER[lower_model_name]\n            elif lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER:\n                # Otherwise we use the slower blockwise type\n                return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name]\n        else:\n            if lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER:\n                return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name]\n        # Mapper didn't find a pre-quantized model.\n        # For vllm >= 0.12.0, we can quantize the model to FP8 on the fly,\n        # so just return the original model name. Older vllm versions will\n        # fall through to offline quantization via _offline_quantize_to_fp8.\n        if importlib.util.find_spec(\"vllm\") is not None:\n            import vllm\n\n            if Version(vllm.__version__) >= Version(\"0.12.0\"):\n                return model_name\n        return None\n\n    elif not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:\n        model_name = INT_TO_FLOAT_MAPPER[lower_model_name]\n        print(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support native \"\n            f\"4bit loading.\\nThe minimum required version is 4.37.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.37\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\\n\"\n            f\"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading).\"\n        )\n        return model_name\n\n    elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:\n        new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name]\n        # logger.warning_once(\n        #     f\"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\\n\"\\\n        #     f\"`load_in_4bit = False`. 
We shall load `{new_model_name}` instead.\"\n        # )\n        return new_model_name\n\n    elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit:\n        new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name]\n        return new_model_name\n\n    elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:\n        # Support returning original full -bnb-4bit name if specified specifically\n        # since we'll map it to the dynamic version instead\n        if lower_model_name.endswith(\"-bnb-4bit\"):\n            return lower_model_name\n\n        new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]\n        # logger.warning_once(\n        #     f\"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\\n\"\\\n        #     f\"We shall load `{new_model_name}` for 4x faster loading.\"\n        # )\n        return new_model_name\n\n    return None\n\n\ndef _get_new_mapper():\n    try:\n        import requests\n\n        new_mapper = \"https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py\"\n        with requests.get(new_mapper, timeout = 3) as new_mapper:\n            new_mapper = new_mapper.text\n        new_mapper = new_mapper[new_mapper.find(\"__INT_TO_FLOAT_MAPPER\") :]\n        new_mapper = (\n            new_mapper.replace(\"INT_TO_FLOAT_MAPPER\", \"NEW_INT_TO_FLOAT_MAPPER\")\n            .replace(\"FLOAT_TO_INT_MAPPER\", \"NEW_FLOAT_TO_INT_MAPPER\")\n            .replace(\"MAP_TO_UNSLOTH_16bit\", \"NEW_MAP_TO_UNSLOTH_16bit\")\n        )\n\n        exec(new_mapper, globals())\n        return (\n            NEW_INT_TO_FLOAT_MAPPER,\n            NEW_FLOAT_TO_INT_MAPPER,\n            NEW_MAP_TO_UNSLOTH_16bit,\n        )\n    except:\n        return {}, {}, {}\n\n\ndef _resolve_with_mappers(\n    model_name,\n    load_in_4bit,\n    load_in_fp8,\n    int_to_float,\n    float_to_int,\n    map_to_unsloth_16bit,\n):\n    return __get_model_name(\n        model_name = model_name,\n        load_in_4bit = load_in_4bit,\n        INT_TO_FLOAT_MAPPER = int_to_float,\n        FLOAT_TO_INT_MAPPER = float_to_int,\n        MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit,\n        load_in_fp8 = load_in_fp8,\n        FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,\n        FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,\n    )\n\n\ndef get_model_name(\n    model_name,\n    load_in_4bit = True,\n    load_in_fp8 = False,\n    token = None,\n    trust_remote_code = False,\n):\n    assert load_in_fp8 in (True, False, \"block\")\n    new_model_name = _resolve_with_mappers(\n        model_name = model_name,\n        load_in_4bit = load_in_4bit,\n        load_in_fp8 = load_in_fp8,\n        int_to_float = INT_TO_FLOAT_MAPPER,\n        float_to_int = FLOAT_TO_INT_MAPPER,\n        map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit,\n    )\n    # In the rare case, we convert bad model names to other names\n    # For eg too large dynamic quants or MoEs\n    if (\n        new_model_name is not None\n        and type(new_model_name) is str\n        and new_model_name.lower() in BAD_MAPPINGS\n    ):\n        new_model_name = BAD_MAPPINGS[new_model_name.lower()]\n\n    if (\n        new_model_name is None\n        and model_name.count(\"/\") == 1\n        and model_name[0].isalnum()\n    ):\n        # Try checking if a new Unsloth version allows it!\n        NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (\n            _get_new_mapper()\n        )\n        upgraded_model_name = _resolve_with_mappers(\n          
  model_name = model_name,\n            load_in_4bit = load_in_4bit,\n            load_in_fp8 = load_in_fp8,\n            int_to_float = NEW_INT_TO_FLOAT_MAPPER,\n            float_to_int = NEW_FLOAT_TO_INT_MAPPER,\n            map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit,\n        )\n        if upgraded_model_name is not None:\n            raise NotImplementedError(\n                f\"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\\n\\n\"\n                \"pip uninstall unsloth unsloth_zoo -y\\n\"\n                'pip install --upgrade --no-cache-dir \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\\n'\n                'pip install --upgrade --no-cache-dir \"git+https://github.com/unslothai/unsloth-zoo.git\"\\n'\n            )\n\n    if new_model_name is None:\n        new_model_name = model_name\n\n    return new_model_name\n\n\ndef _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:\n    \"\"\"\n    Quantizes the model to fp8 using torchao and saving the quantized model to a\n    temporary location. Return the path to the quantized model.\n\n    Note: For vllm >= 0.12.0, we should dynamically quantize the model in vllm instead:\n\n      llm = LLM(\n        ...\n        hf_overrides={\"quantization_config_file\": \"torchao_config.json\"},\n      )\n    \"\"\"\n    temp_dir = tempfile.gettempdir()\n    new_model_name = model_name.split(\"/\")[-1] + \"-fp8-\" + fp8_mode\n    new_model_name = os.path.join(temp_dir, new_model_name)\n    print(\n        f\"Unsloth: Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead\"\n    )\n\n    if not os.path.isdir(new_model_name):\n        from transformers import (\n            AutoModelForCausalLM,\n            AutoModelForImageTextToText,\n            AutoTokenizer,\n            AutoProcessor,\n            TorchAoConfig,\n            AutoConfig,\n        )\n\n        qconfig = _get_torchao_fp8_config(fp8_mode)\n        qconfig = TorchAoConfig(qconfig)\n        config = AutoConfig.from_pretrained(model_name)\n        is_vlm = any(\n            x.endswith((\"ForConditionalGeneration\", \"ForVisionText2Text\"))\n            for x in config.architectures\n        )\n        is_vlm = is_vlm or hasattr(config, \"vision_config\")\n        auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM\n        auto_processor = AutoProcessor if is_vlm else AutoTokenizer\n        model = auto_model.from_pretrained(\n            model_name,\n            torch_dtype = \"auto\",\n            device_map = \"auto\",\n            quantization_config = qconfig,\n        )\n        tokenizer = auto_processor.from_pretrained(model_name)\n        model.save_pretrained(new_model_name, safe_serialization = False)\n        del model\n        for _ in range(2):\n            torch.cuda.empty_cache()\n            gc.collect()\n        tokenizer.save_pretrained(new_model_name)\n    return new_model_name\n\n\ndef _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):\n    \"\"\"\n    Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.\n    \"\"\"\n    try:\n        base_config = _get_torchao_fp8_config(fp8_mode)\n        model.torchao_config = TorchAOConfig(\n            qat_scheme = None,\n            base_config_and_filter_fns = [(base_config, None)],\n        )\n    except:\n        pass\n\n\ndef _get_fp8_mode_and_check_settings(\n    load_in_fp8: Union[bool, str],\n    fast_inference: 
bool,\n    full_finetuning: bool = False,\n    load_in_4bit: bool = False,\n    load_in_8bit: bool = False,\n    load_in_16bit: bool = False,\n) -> str:\n    \"\"\"\n    Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings\n    and environment. Currently this feature requires:\n\n    1. H100 GPUs or after\n    2. torchao 0.15.0+ (or nightly)\n    3. torch 2.9.0+\n    4. If fbgemm_gpu_genai is installed, require 1.4.1+\n\n    Returns the fp8 mode, one of \"row\" or \"block\".\n    \"\"\"\n    assert load_in_fp8 is not False\n    if load_in_fp8 is True:\n        fp8_mode = \"row\"  # default\n    else:\n        fp8_mode = load_in_fp8\n\n    # Check user settings\n    if fp8_mode not in [\"row\", \"block\"]:\n        raise ValueError(\n            f\"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'\"\n        )\n    if full_finetuning:\n        raise ValueError(\n            \"Unsloth: `load_in_fp8` is not compatible with full finetuning\"\n        )\n    if load_in_4bit or load_in_8bit or load_in_16bit:\n        raise ValueError(\n            \"Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`\",\n        )\n\n    # Check if this is Hopper or above\n    if not (\n        torch.cuda.is_available()\n        and torch.version.cuda\n        and torch.cuda.get_device_capability() >= (9, 0)\n    ):\n        raise ValueError(\n            \"Unsloth: On the fly `load_in_fp8` requires H100 GPUs or after. Try `unsloth/Qwen3-8B` instead.\"\n        )\n\n    # Check if torch >= 2.9.0\n    if Version(torch.__version__) < Version(\"2.9.0\"):\n        raise ValueError(\n            \"Unsloth: On the fly `load_in_fp8` requires torch 2.9.0+. Try `unsloth/Qwen3-8B` instead.\"\n        )\n\n    # Check if torchao has this PR: https://github.com/pytorch/ao/pull/3158,\n    # which will be released in 0.15.0.\n    if importlib.util.find_spec(\"torchao\") is None:\n        raise ValueError(\n            \"Unsloth: Please install torchao for on the fly float8 to work! Try `unsloth/Qwen3-8B` instead.\"\n        )\n    import torchao\n\n    error_message = (\n        \"Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\\n\"\n        f\"You have torchao version={torchao.__version__}\\n\"\n        \"Use `pip install --upgrade --force-reinstall torchao`\"\n    )\n    if Version(torchao.__version__) < Version(\"0.15.0\"):\n        raise ValueError(error_message)\n\n    # If fbgemm_gpu_genai is installed and old, disable FBGEMM and use Triton instead\n    if (\n        importlib.util.find_spec(\"fbgemm_gpu\") is not None\n        and importlib.util.find_spec(\"fbgemm_gpu.experimental\") is not None\n    ):\n        import fbgemm_gpu.experimental.gen_ai\n\n        if Version(fbgemm_gpu.__version__) < Version(\"1.4.1\"):\n            # Old FBGEMM version - disable and use Triton kernels instead\n            os.environ[\"UNSLOTH_HAS_FBGEMM\"] = \"0\"\n            from unsloth_zoo.log import logger\n\n            logger.info(\n                f\"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu.__version__} is old for FP8 loading. \"\n                f\"Using Triton kernels instead.\"\n            )\n    return fp8_mode\n"
  },
  {
    "path": "unsloth/models/mapper.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"INT_TO_FLOAT_MAPPER\",\n    \"FLOAT_TO_INT_MAPPER\",\n    \"MAP_TO_UNSLOTH_16bit\",\n    \"FLOAT_TO_FP8_BLOCK_MAPPER\",\n    \"FLOAT_TO_FP8_ROW_MAPPER\",\n]\n\n__INT_TO_FLOAT_MAPPER = \\\n{\n    \"unsloth/mistral-7b-bnb-4bit\" : (\n        \"unsloth/mistral-7b\",\n        \"mistralai/Mistral-7B-v0.1\",\n    ),\n    \"unsloth/llama-2-7b-bnb-4bit\" : (\n        \"unsloth/llama-2-7b\",\n        \"meta-llama/Llama-2-7b-hf\",\n    ),\n    \"unsloth/llama-2-13b-bnb-4bit\" : (\n        \"unsloth/llama-2-13b\",\n        \"meta-llama/Llama-2-13b-hf\",\n    ),\n    \"unsloth/codellama-34b-bnb-4bit\" : (\n        \"codellama/CodeLlama-34b-hf\",\n    ),\n    \"unsloth/zephyr-sft-bnb-4bit\" : (\n        \"unsloth/zephyr-sft\",\n        \"HuggingFaceH4/mistral-7b-sft-beta\",\n    ),\n    \"unsloth/tinyllama-bnb-4bit\" : (\n        \"unsloth/tinyllama\",\n        \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\",\n    ),\n    \"unsloth/tinyllama-chat-bnb-4bit\" : (\n        \"unsloth/tinyllama-chat\",\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n    ),\n    \"unsloth/mistral-7b-instruct-v0.1-bnb-4bit\" : (\n        \"unsloth/mistral-7b-instruct-v0.1\",\n        \"mistralai/Mistral-7B-Instruct-v0.1\",\n    ),\n    \"unsloth/mistral-7b-instruct-v0.2-bnb-4bit\" : (\n        \"unsloth/mistral-7b-instruct-v0.2\",\n        \"mistralai/Mistral-7B-Instruct-v0.2\",\n    ),\n    \"unsloth/llama-2-7b-chat-bnb-4bit\" : (\n        \"unsloth/llama-2-7b-chat\",\n        \"meta-llama/Llama-2-7b-chat-hf\",\n    ),\n    \"unsloth/llama-2-7b-chat-bnb-4bit\" : (\n        \"unsloth/llama-2-7b-chat\",\n        \"meta-llama/Llama-2-7b-chat-hf\",\n    ),\n    \"unsloth/Mixtral-8x7B-v0.1-unsloth-bnb-4bit\" : (\n        \"unsloth/Mixtral-8x7B-v0.1\",\n        \"mistralai/Mixtral-8x7B-v0.1\",\n        \"unsloth/Mixtral-8x7B-v0.1-bnb-4bit\",\n    ),\n    \"unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit\" : (\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1\",\n        \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit\",\n    ),\n    \"unsloth/codellama-7b-bnb-4bit\" : (\n        \"unsloth/codellama-7b\",\n        \"codellama/CodeLlama-7b-hf\",\n    ),\n    \"unsloth/codellama-13b-bnb-4bit\" : (\n        \"codellama/CodeLlama-13b-hf\",\n    ),\n    \"unsloth/yi-6b-bnb-4bit\" : (\n        \"unsloth/yi-6b\",\n        \"01-ai/Yi-6B\",\n    ),\n    \"unsloth/solar-10.7b-bnb-4bit\" : (\n        \"upstage/SOLAR-10.7B-v1.0\",\n    ),\n    \"unsloth/gemma-7b-bnb-4bit\" : (\n        \"unsloth/gemma-7b\",\n        \"google/gemma-7b\",\n    ),\n    \"unsloth/gemma-2b-bnb-4bit\" : (\n        \"unsloth/gemma-2b\",\n        \"google/gemma-2b\",\n    ),\n    \"unsloth/gemma-7b-it-bnb-4bit\" : (\n        \"unsloth/gemma-7b-it\",\n        \"google/gemma-7b-it\",\n    ),\n    \"unsloth/gemma-2b-bnb-4bit\" : (\n        
\"unsloth/gemma-2b-it\",\n        \"google/gemma-2b-it\",\n    ),\n    \"unsloth/mistral-7b-v0.2-bnb-4bit\" : (\n        \"unsloth/mistral-7b-v0.2\",\n        \"alpindale/Mistral-7B-v0.2-hf\",\n    ),\n    \"unsloth/gemma-1.1-2b-it-bnb-4bit\" : (\n        \"unsloth/gemma-1.1-2b-it\",\n        \"google/gemma-1.1-2b-it\",\n    ),\n    \"unsloth/gemma-1.1-7b-it-bnb-4bit\" : (\n        \"unsloth/gemma-1.1-7b-it\",\n        \"google/gemma-1.1-7b-it\",\n    ),\n    \"unsloth/Starling-LM-7B-beta\" : (\n        \"unsloth/Starling-LM-7B-beta\",\n        \"Nexusflow/Starling-LM-7B-beta\",\n    ),\n    \"unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit\" : (\n        \"unsloth/Hermes-2-Pro-Mistral-7B\",\n        \"NousResearch/Hermes-2-Pro-Mistral-7B\",\n    ),\n    \"unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit\" : (\n        \"unsloth/OpenHermes-2.5-Mistral-7B\",\n        \"teknium/OpenHermes-2.5-Mistral-7B\",\n    ),\n    \"unsloth/codegemma-2b-bnb-4bit\" : (\n        \"unsloth/codegemma-2b\",\n        \"google/codegemma-2b\",\n    ),\n    \"unsloth/codegemma-7b-bnb-4bit\" : (\n        \"unsloth/codegemma-7b\",\n        \"google/codegemma-7b\",\n    ),\n    \"unsloth/codegemma-7b-it-bnb-4bit\" : (\n        \"unsloth/codegemma-7b-it\",\n        \"google/codegemma-7b-it\",\n    ),\n    \"unsloth/llama-3-8b-bnb-4bit\" : (\n        \"unsloth/llama-3-8b\",\n        \"meta-llama/Meta-Llama-3-8B\",\n    ),\n    \"unsloth/llama-3-8b-Instruct-bnb-4bit\" : (\n        \"unsloth/llama-3-8b-Instruct\",\n        \"meta-llama/Meta-Llama-3-8B-Instruct\",\n    ),\n    \"unsloth/llama-3-70b-bnb-4bit\" : (\n        \"meta-llama/Meta-Llama-3-70B\",\n    ),\n    \"unsloth/llama-3-70b-Instruct-bnb-4bit\" : (\n        \"meta-llama/Meta-Llama-3-70B-Instruct\",\n    ),\n    \"unsloth/Phi-3-mini-4k-instruct-bnb-4bit\" : (\n        \"unsloth/Phi-3-mini-4k-instruct\",\n        \"microsoft/Phi-3-mini-4k-instruct\",\n    ),\n    \"unsloth/mistral-7b-v0.3-bnb-4bit\" : (\n        \"unsloth/mistral-7b-v0.3\",\n        \"mistralai/Mistral-7B-v0.3\",\n    ),\n    \"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\" : (\n        \"unsloth/mistral-7b-instruct-v0.3\",\n        \"mistralai/Mistral-7B-Instruct-v0.3\",\n    ),\n    \"unsloth/Phi-3-medium-4k-instruct-bnb-4bit\" : (\n        \"unsloth/Phi-3-medium-4k-instruct\",\n        \"microsoft/Phi-3-medium-4k-instruct\",\n    ),\n    \"unsloth/Qwen2-0.5B-bnb-4bit\" : (\n        \"unsloth/Qwen2-0.5B\",\n        \"Qwen/Qwen2-0.5B\",\n    ),\n    \"unsloth/Qwen2-0.5B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2-0.5B-Instruct\",\n        \"Qwen/Qwen2-0.5B-Instruct\",\n    ),\n    \"unsloth/Qwen2-1.5B-bnb-4bit\" : (\n        \"unsloth/Qwen2-1.5B\",\n        \"Qwen/Qwen2-1.5B\",\n    ),\n    \"unsloth/Qwen2-1.5B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2-1.5B-Instruct\",\n        \"Qwen/Qwen2-1.5B-Instruct\",\n    ),\n    \"unsloth/Qwen2-7B-bnb-4bit\" : (\n        \"unsloth/Qwen2-7B\",\n        \"Qwen/Qwen2-7B\",\n    ),\n    \"unsloth/Qwen2-7B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2-7B-Instruct\",\n        \"Qwen/Qwen2-7B-Instruct\",\n    ),\n    \"unsloth/Qwen2-70B-bnb-4bit\" : (\n        \"Qwen/Qwen2-70B\",\n    ),\n    \"unsloth/Qwen2-70B-Instruct-bnb-4bit\" : (\n        \"Qwen/Qwen2-70B-Instruct\",\n    ),\n    \"mistralai/Codestral-22B-v0.1\" : (\n        \"mistral-community/Codestral-22B-v0.1\",\n    ),\n    \"unsloth/gemma-2-9b-bnb-4bit\" : (\n        \"unsloth/gemma-2-9b\",\n        \"google/gemma-2-9b\",\n    ),\n    \"unsloth/gemma-2-27b-bnb-4bit\" : (\n        
\"unsloth/gemma-2-27b\",\n        \"google/gemma-2-27b\",\n    ),\n    \"unsloth/gemma-2-9b-it-bnb-4bit\" : (\n        \"unsloth/gemma-2-9b-it\",\n        \"google/gemma-2-9b-it\",\n    ),\n    \"unsloth/gemma-2-27b-it-bnb-4bit\" : (\n        \"unsloth/gemma-2-27b-it\",\n        \"google/gemma-2-27b-it\",\n    ),\n    \"unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit\" : ( # Old Phi pre July\n        \"unsloth/Phi-3-mini-4k-instruct-v0\",\n    ),\n    \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\" : ( # New 12b Mistral models\n        \"unsloth/Mistral-Nemo-Instruct-2407\",\n        \"mistralai/Mistral-Nemo-Instruct-2407\",\n    ),\n    \"unsloth/Mistral-Nemo-Base-2407-bnb-4bit\" : ( # New 12b Mistral models\n        \"unsloth/Mistral-Nemo-Base-2407\",\n        \"mistralai/Mistral-Nemo-Base-2407\",\n    ),\n    \"unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit\" : (\n        \"unsloth/Meta-Llama-3.1-8B\",\n        \"meta-llama/Meta-Llama-3.1-8B\",\n        \"unsloth/Meta-Llama-3.1-8B-bnb-4bit\",\n    ),\n    \"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"RedHatAI/Llama-3.1-8B-Instruct-FP8\",\n            \"unsloth/Llama-3.1-8B-Instruct-FP8-Block\",\n            \"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Meta-Llama-3.1-8B-Instruct\",\n            \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n            \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Llama-3.1-8B-unsloth-bnb-4bit\" : (\n        \"unsloth/Llama-3.1-8B\",\n        \"meta-llama/Llama-3.1-8B\",\n        \"unsloth/Llama-3.1-8B-bnb-4bit\",\n    ),\n    \"unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"RedHatAI/Llama-3.1-8B-Instruct-FP8\",\n            \"unsloth/Llama-3.1-8B-Instruct-FP8-Block\",\n            \"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Llama-3.1-8B-Instruct\",\n            \"meta-llama/Llama-3.1-8B-Instruct\",\n            \"unsloth/Llama-3.1-8B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Meta-Llama-3.1-70B-bnb-4bit\" : (\n        \"unsloth/Meta-Llama-3.1-70B\",\n        \"meta-llama/Meta-Llama-3.1-70B\",\n    ),\n    \"unsloth/Meta-Llama-3.1-405B-bnb-4bit\" : (\n        \"meta-llama/Meta-Llama-3.1-405B\",\n    ),\n    \"unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit\" : (\n        \"meta-llama/Meta-Llama-3.1-405B-Instruct\",\n    ),\n    \"unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit\" : (\n        \"unsloth/Meta-Llama-3.1-70B-Instruct\",\n        \"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n    ),\n    \"unsloth/Mistral-Large-Instruct-2407-bnb-4bit\" : (\n        \"mistralai/Mistral-Large-Instruct-2407\",\n    ),\n    \"unsloth/gemma-2-2b-bnb-4bit\" : (\n        \"unsloth/gemma-2-2b\",\n        \"google/gemma-2-2b\",\n    ),\n    \"unsloth/gemma-2-2b-it-bnb-4bit\" : (\n        \"unsloth/gemma-2-2b-it\",\n        \"google/gemma-2-2b-it\",\n    ),\n    \"unsloth/Phi-3.5-mini-instruct-bnb-4bit\" : (\n        \"unsloth/Phi-3.5-mini-instruct\",\n        \"microsoft/Phi-3.5-mini-instruct\",\n    ),\n    \"unsloth/c4ai-command-r-08-2024-bnb-4bit\" : (\n        \"CohereForAI/c4ai-command-r-08-2024\",\n    ),\n    \"unsloth/c4ai-command-r-plus-08-2024-bnb-4bit\" : (\n        \"CohereForAI/c4ai-command-r-plus-08-2024\",\n    ),\n    \"unsloth/Llama-3.1-Storm-8B-bnb-4bit\" : (\n        \"unsloth/Llama-3.1-Storm-8B\",\n        \"akjindal53244/Llama-3.1-Storm-8B\",\n    ),\n    
\"unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit\" : (\n        \"unsloth/Hermes-3-Llama-3.1-8B\",\n        \"NousResearch/Hermes-3-Llama-3.1-8B\",\n    ),\n    \"unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit\" : (\n        \"unsloth/Hermes-3-Llama-3.1-70B\",\n        \"NousResearch/Hermes-3-Llama-3.1-70B\",\n    ),\n    \"unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit\" : (\n        \"NousResearch/Hermes-3-Llama-3.1-405B\",\n    ),\n    \"unsloth/SmolLM-135M-bnb-4bit\" : (\n        \"unsloth/SmolLM-135M\",\n        \"HuggingFaceTB/SmolLM-135M\",\n    ),\n    \"unsloth/SmolLM-360M-bnb-4bit\" : (\n        \"unsloth/SmolLM-360M\",\n        \"HuggingFaceTB/SmolLM-360M\",\n    ),\n    \"unsloth/SmolLM-1.7B-bnb-4bit\" : (\n        \"unsloth/SmolLM-1.7B\",\n        \"HuggingFaceTB/SmolLM-1.7B\",\n    ),\n    \"unsloth/SmolLM-135M-Instruct-bnb-4bit\" : (\n        \"unsloth/SmolLM-135M-Instruct\",\n        \"HuggingFaceTB/SmolLM-135M-Instruct\",\n    ),\n    \"unsloth/SmolLM-360M-Instruct-bnb-4bit\" : (\n        \"unsloth/SmolLM-360M-Instruct\",\n        \"HuggingFaceTB/SmolLM-360M-Instruct\",\n    ),\n    \"unsloth/SmolLM-1.7B-Instruct-bnb-4bit\" : (\n        \"unsloth/SmolLM-1.7B-Instruct\",\n        \"HuggingFaceTB/SmolLM-1.7B-Instruct\",\n    ),\n    \"unsloth/Mistral-Small-Instruct-2409-bnb-4bit\" : (\n        \"unsloth/Mistral-Small-Instruct-2409\",\n        \"mistralai/Mistral-Small-Instruct-2409\",\n    ),\n    \"unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-0.5B-Instruct\",\n        \"unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-3B-Instruct\",\n        \"Qwen/Qwen2.5-3B-Instruct\",\n        \"unsloth/Qwen2.5-3B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-7B-Instruct\",\n        \"Qwen/Qwen2.5-7B-Instruct\",\n        \"unsloth/Qwen2.5-7B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-14B-Instruct\",\n        \"Qwen/Qwen2.5-14B-Instruct\",\n        \"unsloth/Qwen2.5-14B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-32B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-32B-Instruct\",\n        \"Qwen/Qwen2.5-32B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-72B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-72B-Instruct\",\n        \"Qwen/Qwen2.5-72B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-0.5B\",\n        \"Qwen/Qwen2.5-0.5B\",\n        \"unsloth/Qwen2.5-0.5B-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-1.5B\",\n        \"Qwen/Qwen2.5-1.5B\",\n        \"unsloth/Qwen2.5-1.5B-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-3B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-3B\",\n        \"Qwen/Qwen2.5-3B\",\n        \"unsloth/Qwen2.5-3B-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-7B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-7B\",\n        \"Qwen/Qwen2.5-7B\",\n        \"unsloth/Qwen2.5-7B-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-14B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-14B\",\n        \"Qwen/Qwen2.5-14B\",\n        \"unsloth/Qwen2.5-14B-bnb-4bit\",\n    ),\n    
\"unsloth/Qwen2.5-32B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-32B\",\n        \"Qwen/Qwen2.5-32B\",\n    ),\n    \"unsloth/Qwen2.5-72B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-72B\",\n        \"Qwen/Qwen2.5-72B\",\n    ),\n    \"unsloth/Qwen2.5-Math-1.5B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-1.5B\",\n        \"Qwen/Qwen2.5-Math-1.5B\",\n    ),\n    \"unsloth/Qwen2.5-Math-7B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-7B\",\n        \"Qwen/Qwen2.5-Math-7B\",\n    ),\n    \"unsloth/Qwen2.5-Math-72B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-72B\",\n        \"Qwen/Qwen2.5-Math-72B\",\n    ),\n    \"unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Math-1.5B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-7B-Instruct\",\n        \"Qwen/Qwen2.5-Math-7B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Math-72B-Instruct\",\n        \"Qwen/Qwen2.5-Math-72B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-0.5B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-0.5B\",\n        \"Qwen/Qwen2.5-Coder-0.5B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-1.5B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-1.5B\",\n        \"Qwen/Qwen2.5-Coder-1.5B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-3B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-3B\",\n        \"Qwen/Qwen2.5-Coder-3B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-7B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-7B\",\n        \"Qwen/Qwen2.5-Coder-7B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-14B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-14B\",\n        \"Qwen/Qwen2.5-Coder-14B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-32B-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-32B\",\n        \"Qwen/Qwen2.5-Coder-32B\",\n    ),\n    \"unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-0.5B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-3B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-3B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-7B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-14B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n    ),\n    \"unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-Coder-32B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n    ),\n    \"unsloth/Llama-3.2-1B-unsloth-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-1B\",\n        \"meta-llama/Llama-3.2-1B\",\n        \"unsloth/Llama-3.2-1B-bnb-4bit\",\n    ),\n    \"unsloth/Llama-3.2-3B-unsloth-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-3B\",\n        \"meta-llama/Llama-3.2-3B\",\n        \"unsloth/Llama-3.2-3B-bnb-4bit\",\n    ),\n    \"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\": (\n            \"RedHatAI/Llama-3.2-1B-Instruct-FP8\",\n            \"unsloth/Llama-3.2-1B-Instruct-FP8-Block\",\n            \"unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Llama-3.2-1B-Instruct\",\n            
\"meta-llama/Llama-3.2-1B-Instruct\",\n            \"unsloth/Llama-3.2-1B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\": (\n            \"RedHatAI/Llama-3.2-3B-Instruct-FP8\",\n            \"unsloth/Llama-3.2-3B-Instruct-FP8-Block\",\n            \"unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Llama-3.2-3B-Instruct\",\n            \"meta-llama/Llama-3.2-3B-Instruct\",\n            \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit\" : (\n        \"unsloth/Llama-3.1-Nemotron-70B-Instruct\",\n        \"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF\",\n    ),\n    \"unsloth/Qwen2-VL-2B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-2B-Instruct\",\n        \"Qwen/Qwen2-VL-2B-Instruct\",\n        \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-7B-Instruct\",\n        \"Qwen/Qwen2-VL-7B-Instruct\",\n        \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-72B-Instruct\",\n        \"Qwen/Qwen2-VL-72B-Instruct\",\n    ),\n    \"unsloth/Qwen2-VL-2B-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-2B\",\n        \"Qwen/Qwen2-VL-2B\",\n    ),\n    \"unsloth/Qwen2-VL-7B-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-7B\",\n        \"Qwen/Qwen2-VL-7B\",\n    ),\n    \"unsloth/Qwen2-VL-72B-bnb-4bit\" : (\n        \"unsloth/Qwen2-VL-72B\",\n        \"Qwen/Qwen2-VL-72B\",\n    ),\n    \"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-90B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n    ),\n    \"unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-11B-Vision\",\n        \"meta-llama/Llama-3.2-11B-Vision\",\n        \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n    ),\n    \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\" : (\n        \"unsloth/Llama-3.2-90B-Vision\",\n        \"meta-llama/Llama-3.2-90B-Vision\",\n    ),\n    \"unsloth/Pixtral-12B-2409-unsloth-bnb-4bit\" : (\n        \"unsloth/Pixtral-12B-2409\",\n        \"mistralai/Pixtral-12B-2409\",\n        \"unsloth/Pixtral-12B-2409-bnb-4bit\",\n    ),\n    \"unsloth/Pixtral-12B-2409-Base-bnb-4bit\" : (\n        \"unsloth/Pixtral-12B-Base-2409\",\n        \"mistralai/Pixtral-12B-Base-2409\",\n    ),\n    \"unsloth/llava-1.5-7b-hf-bnb-4bit\" : (\n        \"unsloth/llava-1.5-7b-hf\",\n        \"llava-hf/llava-1.5-7b-hf\",\n    ),\n    \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\" : (\n        \"unsloth/llava-v1.6-mistral-7b-hf\",\n        \"llava-hf/llava-v1.6-mistral-7b-hf\",\n    ),\n    \"unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit\" : (\n        \"unsloth/Llama-3.1-Tulu-3-8B\",\n        \"allenai/Llama-3.1-Tulu-3-8B\",\n    ),\n    \"unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit\" : (\n        \"unsloth/Llama-3.1-Tulu-3-70B\",\n        \"allenai/Llama-3.1-Tulu-3-70B\",\n    ),\n    \"unsloth/QwQ-32B-Preview-bnb-4bit\" : (\n        \"unsloth/QwQ-32B-Preview\",\n        \"Qwen/QwQ-32B-Preview\",\n    ),\n    \"unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n      
      \"RedHatAI/Llama-3.3-70B-Instruct-FP8\",\n            \"unsloth/Llama-3.3-70B-Instruct-FP8-Block\",\n            \"unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Llama-3.3-70B-Instruct\",\n            \"meta-llama/Llama-3.3-70B-Instruct\",\n            \"unsloth/Llama-3.3-70B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/phi-4-unsloth-bnb-4bit\" : (\n        \"unsloth/phi-4\",\n        \"microsoft/phi-4\",\n        \"unsloth/phi-4-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Qwen-32B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Qwen-14B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B\",\n        \"unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Qwen-7B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n        \"unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n        \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Llama-8B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n        \"unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-Distill-Llama-70B-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-Distill-Llama-70B\",\n        \"deepseek-ai/DeepSeek-R1-Distill-Llama-70B\",\n    ),\n    \"unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit\" : (\n        \"unsloth/Mistral-Small-24B-Base-2501\",\n        \"mistralai/Mistral-Small-24B-Base-2501\",\n        \"unsloth/Mistral-Small-24B-Base-2501-bnb-4bit\",\n    ),\n    \"unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit\" : (\n        \"unsloth/Mistral-Small-24B-Instruct-2501\",\n        \"mistralai/Mistral-Small-24B-Instruct-2501\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-VL-3B-Instruct\",\n        \"Qwen/Qwen2.5-VL-3B-Instruct\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-VL-7B-Instruct\",\n        \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-VL-32B-Instruct\",\n        \"Qwen/Qwen2.5-VL-32B-Instruct\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen2.5-VL-72B-Instruct\",\n        \"Qwen/Qwen2.5-VL-72B-Instruct\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepHermes-3-Llama-3-8B-Preview\",\n        \"agentica-org/DeepScaleR-1.5B-Preview\",\n        \"unsloth/DeepScaleR-1.5B-Preview-bnb-4bit\",\n    ),\n    \"unsloth/OpenThinker-7B-unsloth-bnb-4bit\" : (\n        \"unsloth/OpenThinker-7B\",\n        \"open-thoughts/OpenThinker-7B\",\n        
\"unsloth/OpenThinker-7B-bnb-4bit\",\n    ),\n    \"unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-3.2-2b-instruct\",\n        \"ibm-granite/granite-3.2-2b-instruct\",\n        \"unsloth/granite-3.2-2b-instruct-bnb-4bit\",\n    ),\n    \"unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-3.2-8b-instruct\",\n        \"ibm-granite/granite-3.2-8b-instruct\",\n        \"unsloth/granite-3.2-8b-instruct-bnb-4bit\",\n    ),\n    \"unsloth/QwQ-32B-unsloth-bnb-4bit\" : (\n        \"unsloth/QwQ-32B\",\n        \"Qwen/QwQ-32B\",\n        \"unsloth/QwQ-32B-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-1b-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-1b-it\",\n        \"google/gemma-3-1b-it\",\n        \"unsloth/gemma-3-1b-it-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-4b-it\",\n        \"google/gemma-3-4b-it\",\n        \"unsloth/gemma-3-4b-it-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-12b-it\",\n        \"google/gemma-3-12b-it\",\n        \"unsloth/gemma-3-12b-it-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-27b-it\",\n        \"google/gemma-3-27b-it\",\n        \"unsloth/gemma-3-27b-it-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-1b-pt-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-1b-pt\",\n        \"google/gemma-3-1b-pt\",\n        \"unsloth/gemma-3-1b-pt-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-4b-pt-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-4b-pt\",\n        \"google/gemma-3-4b-pt\",\n        \"unsloth/gemma-3-4b-pt-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-12b-pt-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-12b-pt\",\n        \"google/gemma-3-12b-pt\",\n        \"unsloth/gemma-3-12b-pt-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-27b-pt-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-27b-pt\",\n        \"google/gemma-3-27b-pt\",\n        \"unsloth/gemma-3-27b-pt-bnb-4bit\",\n    ),\n    \"unsloth/reka-flash-3-unsloth-bnb-4bit\" : (\n        \"unsloth/reka-flash-3\",\n        \"RekaAI/reka-flash-3\",\n        \"unsloth/reka-flash-3-bnb-4bit\",\n    ),\n    \"unsloth/c4ai-command-a-03-2025-unsloth-bnb-4bit\" : (\n        \"unsloth/c4ai-command-a-03-2025\",\n        \"CohereForAI/c4ai-command-a-03-2025\",\n        \"unsloth/c4ai-command-a-03-2025-bnb-4bit\",\n    ),\n    \"unsloth/aya-vision-32b-unsloth-bnb-4bit\" : (\n        \"unsloth/aya-vision-32b\",\n        \"CohereForAI/aya-vision-32b\",\n        \"unsloth/aya-vision-32b-bnb-4bit\",\n    ),\n    \"unsloth/aya-vision-8b-unsloth-bnb-4bit\" : (\n        \"unsloth/aya-vision-8b\",\n        \"CohereForAI/aya-vision-8b\",\n        \"unsloth/aya-vision-8b-bnb-4bit\",\n    ),\n    \"unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-vision-3.2-2b\",\n        \"ibm-granite/granite-vision-3.2-2b\",\n        \"unsloth/granite-vision-3.2-2b-bnb-4bit\",\n    ),\n    \"unsloth/OLMo-2-0325-32B-Instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/OLMo-2-0325-32B-Instruct\",\n        \"allenai/OLMo-2-0325-32B-Instruct\",\n        \"unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit\",\n    ),\n    \"unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit\" : (\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503\",\n        \"mistralai/Mistral-Small-3.1-24B-Instruct-2503\",\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit\",\n    ),\n    
\"unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit\" : (\n        \"unsloth/Mistral-Small-3.1-24B-Base-2503\",\n        \"mistralai/Mistral-Small-3.1-24B-Base-2503\",\n        \"unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-0.6B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-0.6B-FP8\",\n            \"unsloth/Qwen3-0.6B-FP8\",\n            \"unsloth/Qwen3-0.6B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-0.6B\",\n            \"Qwen/Qwen3-0.6B\",\n            \"unsloth/Qwen3-0.6B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-1.7B-FP8\",\n            \"unsloth/Qwen3-1.7B-FP8\",\n            \"unsloth/Qwen3-1.7B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-1.7B\",\n            \"Qwen/Qwen3-1.7B\",\n            \"unsloth/Qwen3-1.7B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-4B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-4B-FP8\",\n            \"unsloth/Qwen3-4B-FP8\",\n            \"unsloth/Qwen3-4B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-4B\",\n            \"Qwen/Qwen3-4B\",\n            \"unsloth/Qwen3-4B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-8B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-8B-FP8\",\n            \"unsloth/Qwen3-8B-FP8\",\n            \"unsloth/Qwen3-8B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-8B\",\n            \"Qwen/Qwen3-8B\",\n            \"unsloth/Qwen3-8B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-14B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-14B-FP8\",\n            \"unsloth/Qwen3-14B-FP8\",\n            \"unsloth/Qwen3-14B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-14B\",\n            \"Qwen/Qwen3-14B\",\n            \"unsloth/Qwen3-14B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-32B-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-32B-FP8\",\n            \"unsloth/Qwen3-32B-FP8\",\n            \"unsloth/Qwen3-32B-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-32B\",\n            \"Qwen/Qwen3-32B\",\n            \"unsloth/Qwen3-32B-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-30B-A3B\",\n        \"Qwen/Qwen3-30B-A3B\",\n        \"unsloth/Qwen3-30B-A3B-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-0.6B-Base-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-0.6B-Base\",\n        \"Qwen/Qwen3-0.6B-Base\",\n        \"unsloth/Qwen3-0.6B-Base-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-1.7B-Base-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-1.7B-Base\",\n        \"Qwen/Qwen3-1.7B-Base\",\n        \"unsloth/Qwen3-1.7B-Base-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-4B-Base-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-4B-Base\",\n        \"Qwen/Qwen3-4B-Base\",\n        \"unsloth/Qwen3-4B-Base-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-8B-Base-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-8B-Base\",\n        \"Qwen/Qwen3-8B-Base\",\n        \"unsloth/Qwen3-8B-Base-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-14B-Base-unsloth-bnb-4bit\" : (\n        \"unsloth/Qwen3-14B-Base\",\n        \"Qwen/Qwen3-14B-Base\",\n        \"unsloth/Qwen3-14B-Base-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-30B-A3B-Base-bnb-4bit\" : (\n        \"unsloth/Qwen3-30B-A3B-Base\",\n        \"Qwen/Qwen3-30B-A3B-Base\",\n    ),\n    
\"unsloth/phi-4-reasoning-unsloth-bnb-4bit\" : (\n        \"unsloth/phi-4-reasoning\",\n        \"microsoft/Phi-4-reasoning\",\n        \"unsloth/phi-4-reasoning-bnb-4bit\",\n    ),\n    \"unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit\" : (\n        \"unsloth/phi-4-reasoning-plus\",\n        \"microsoft/Phi-4-reasoning-plus\",\n        \"unsloth/phi-4-reasoning-plus-bnb-4bit\",\n    ),\n    \"unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit\" : (\n        \"unsloth/phi-4-mini-reasoning\",\n        \"microsoft/Phi-4-mini-reasoning\",\n        \"unsloth/phi-4-mini-reasoning-bnb-4bit\",\n    ),\n    \"unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit\" : (\n        \"unsloth/Phi-4-mini-instruct\",\n        \"microsoft/Phi-4-mini-instruct\",\n        \"unsloth/Phi-4-mini-instruct-bnb-4bit\",\n    ),\n    \"unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit\" : (\n        \"unsloth/orpheus-3b-0.1-pretrained\",\n        \"canopylabs/orpheus-3b-0.1-pretrained\",\n        \"unsloth/orpheus-3b-0.1-pretrained-bnb-4bit\",\n    ),\n    \"unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit\" : (\n        \"unsloth/orpheus-3b-0.1-ft\",\n        \"canopylabs/orpheus-3b-0.1-ft\",\n        \"unsloth/orpheus-3b-0.1-ft-bnb-4bit\",\n    ),\n    \"unsloth/csm-1b\" : (\n        \"unsloth/csm-1b\",\n        \"sesame/csm-1b\",\n    ),\n    \"unsloth/whisper-large-v3\" : (\n        \"unsloth/whisper-large-v3\",\n        \"openai/whisper-large-v3\",\n    ),\n    \"unsloth/whisper-large-v3-turbo\" : (\n        \"unsloth/whisper-large-v3-turbo\",\n        \"openai/whisper-large-v3-turbo\",\n    ),\n    \"unsloth/whisper-small\" : (\n        \"unsloth/whisper-small\",\n        \"openai/whisper-small\",\n    ),\n    \"unsloth/CrisperWhisper\" : (\n        \"unsloth/CrisperWhisper\",\n        \"nyrahealth/CrisperWhisper\",\n    ),\n    \"unsloth/Llasa-1B\" : (\n        \"unsloth/Llasa-1B\",\n        \"HKUSTAudio/Llasa-1B\",\n    ),\n    \"unsloth/Spark-TTS-0.5B\" : (\n        \"unsloth/Spark-TTS-0.5B\",\n        \"SparkAudio/Spark-TTS-0.5B\",\n    ),\n    \"unsloth/Llama-OuteTTS-1.0-1B\" : (\n        \"unsloth/Llama-OuteTTS-1.0-1B\",\n        \"OuteAI/Llama-OuteTTS-1.0-1B\",\n    ),\n    \"unsloth/medgemma-4b-it-unsloth-bnb-4bit\" : (\n        \"unsloth/medgemma-4b-it\",\n        \"google/medgemma-4b-it\",\n        \"unsloth/medgemma-4b-it-bnb-4bit\",\n    ),\n    \"unsloth/medgemma-27b-text-it-unsloth-bnb-4bit\" : (\n        \"unsloth/medgemma-27b-text-it\",\n        \"google/medgemma-27b-text-it\",\n        \"unsloth/medgemma-27b-text-it-bnb-4bit\",\n    ),\n    \"unsloth/Devstral-Small-2505-unsloth-bnb-4bit\" : (\n        \"unsloth/Devstral-Small-2505\",\n        \"mistralai/Devstral-Small-2505\",\n        \"unsloth/Devstral-Small-2505-bnb-4bit\",\n    ),\n    \"unsloth/DeepSeek-R1-0528-Qwen3-8B-unsloth-bnb-4bit\" : (\n        \"unsloth/DeepSeek-R1-0528-Qwen3-8B\",\n        \"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B\",\n        \"unsloth/DeepSeek-R1-0528-Qwen3-8B-bnb-4bit\",\n    ),\n    \"unsloth/Magistral-Small-2506-unsloth-bnb-4bit\" : (\n        \"unsloth/Magistral-Small-2506\",\n        \"mistralai/Magistral-Small-2506\",\n        \"unsloth/Magistral-Small-2506-bnb-4bit\",\n    ),\n    \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n            \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-FP8\",\n            \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-FP8\",\n        ),\n        \"16\" : (\n            
\"unsloth/Mistral-Small-3.2-24B-Instruct-2506\",\n            \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n            \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit\",\n        ),\n    },\n    \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3n-E4B-it\",\n        \"google/gemma-3n-E4B-it\",\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3n-E2B-it\",\n        \"google/gemma-3n-E2B-it\",\n        \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3n-E4B-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3n-E4B\",\n        \"google/gemma-3n-E4B\",\n        \"unsloth/gemma-3n-E4B-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3n-E2B-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3n-E2B\",\n        \"google/gemma-3n-E2B\",\n        \"unsloth/gemma-3n-E2B-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/Devstral-Small-2507-unsloth-bnb-4bit\" : (\n        \"unsloth/Devstral-Small-2507\",\n        \"mistralai/Devstral-Small-2507\",\n        \"unsloth/Devstral-Small-2507-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-30B-A3B-Thinking-2507\" : (\n        \"unsloth/Qwen3-30B-A3B-Thinking-2507\",\n        \"Qwen/Qwen3-30B-A3B-Thinking-2507\",\n    ),\n    \"unsloth/Qwen3-30B-A3B-Instruct-2507\" : (\n        \"unsloth/Qwen3-30B-A3B-Instruct-2507\",\n        \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n    ),\n    \"unsloth/Qwen3-Coder-30B-A3B-Instruct\" : (\n        \"unsloth/Qwen3-Coder-30B-A3B-Instruct\",\n        \"Qwen/Qwen3-Coder-30B-A3B-Instruct\",\n    ),\n    \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\" : (\n        \"unsloth/gpt-oss-20b\",\n        \"openai/gpt-oss-20b\",\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\" : (\n        \"unsloth/gpt-oss-120b\",\n        \"openai/gpt-oss-120b\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-4B-Instruct-2507-FP8\",\n            \"unsloth/Qwen3-4B-Instruct-2507-FP8\",\n            \"unsloth/Qwen3-4B-Instruct-2507-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-4B-Instruct-2507\",\n            \"Qwen/Qwen3-4B-Instruct-2507\",\n            \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-4B-Thinking-2507-FP8\",\n            \"unsloth/Qwen3-4B-Thinking-2507-FP8\",\n            \"unsloth/Qwen3-4B-Thinking-2507-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-4B-Thinking-2507\",\n            \"Qwen/Qwen3-4B-Thinking-2507\",\n            \"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit\",\n        ),\n    },\n    \"unsloth/gemma-3-270m-it-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-270m-it\",\n        \"google/gemma-3-270m-it\",\n        \"unsloth/gemma-3-270m-it-bnb-4bit\",\n    ),\n    \"unsloth/gemma-3-270m-unsloth-bnb-4bit\" : (\n        \"unsloth/gemma-3-270m\",\n        \"google/gemma-3-270m\",\n        \"unsloth/gemma-3-270m-bnb-4bit\",\n    ),\n    \"unsloth/Magistral-Small-2507-unsloth-bnb-4bit\" : (\n        \"unsloth/Magistral-Small-2507\",\n        \"mistralai/Magistral-Small-2507\",\n        \"unsloth/Magistral-Small-2507-bnb-4bit\",\n    ),\n    \"unsloth/Magistral-Small-2509-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"mistralai/Magistral-Small-2509\",\n          
  \"unsloth/Magistral-Small-2509-FP8-Dynamic\",\n            \"unsloth/Magistral-Small-2509-FP8-Dynamic\",\n        ),\n        \"16\" : (\n            \"unsloth/Magistral-Small-2509\",\n            \"mistralai/Magistral-Small-2509\",\n            \"unsloth/Magistral-Small-2509-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit\" : (\n        \"unsloth/Apertus-70B-Instruct-2509\",\n        \"swiss-ai/Apertus-70B-2509\",\n        \"unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit\" : (\n        \"unsloth/Apertus-8B-Instruct-2509\",\n        \"swiss-ai/Apertus-8B-2509\",\n        \"unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-micro-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-micro\",\n        \"ibm-granite/granite-4.0-micro\",\n        \"unsloth/granite-4.0-micro-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-micro-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-h-micro\",\n        \"ibm-granite/granite-4.0-h-micro\",\n        \"unsloth/granite-4.0-h-micro-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-micro-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-micro-base\",\n        \"ibm-granite/granite-4.0-micro-base\",\n        \"unsloth/granite-4.0-micro-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-micro-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-h-micro-base\",\n        \"ibm-granite/granite-4.0-h-micro-base\",\n        \"unsloth/granite-4.0-h-micro-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-tiny\" : (\n        \"unsloth/granite-4.0-h-tiny\",\n        \"ibm-granite/granite-4.0-h-tiny\",\n    ),\n    \"unsloth/granite-4.0-h-small\" : (\n        \"unsloth/granite-4.0-h-small\",\n        \"ibm-granite/granite-4.0-h-small\",\n    ),\n    \"unsloth/granite-4.0-h-tiny-base\" : (\n        \"unsloth/granite-4.0-h-tiny-base\",\n        \"ibm-granite/granite-4.0-h-tiny-base\",\n    ),\n    \"unsloth/granite-4.0-h-small-base\" : (\n        \"unsloth/granite-4.0-h-small-base\",\n        \"ibm-granite/granite-4.0-h-small-base\",\n    ),\n    \"unsloth/Qwen3-VL-4B-Thinking-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-4B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-4B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-4B-Thinking-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-4B-Thinking\",\n            \"Qwen/Qwen3-VL-4B-Thinking\",\n            \"unsloth/Qwen3-VL-4B-Thinking-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-8B-Thinking-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-8B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-8B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-8B-Thinking-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-8B-Thinking\",\n            \"Qwen/Qwen3-VL-8B-Thinking\",\n            \"unsloth/Qwen3-VL-8B-Thinking-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-4B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-4B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-4B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-4B-Instruct-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-4B-Instruct\",\n            \"Qwen/Qwen3-VL-4B-Instruct\",\n            \"unsloth/Qwen3-VL-4B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            
\"Qwen/Qwen3-VL-8B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-8B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-8B-Instruct-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-8B-Instruct\",\n            \"Qwen/Qwen3-VL-8B-Instruct\",\n            \"unsloth/Qwen3-VL-8B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-2B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-2B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-2B-Thinking-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-2B-Thinking\",\n            \"Qwen/Qwen3-VL-2B-Thinking\",\n            \"unsloth/Qwen3-VL-2B-Thinking-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-32B-Thinking-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-32B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-32B-Thinking-FP8\",\n            \"unsloth/Qwen3-VL-32B-Thinking-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-32B-Thinking\",\n            \"Qwen/Qwen3-VL-32B-Thinking\",\n            \"unsloth/Qwen3-VL-32B-Thinking-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-2B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-2B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-2B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-2B-Instruct-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-2B-Instruct\",\n            \"Qwen/Qwen3-VL-2B-Instruct\",\n            \"unsloth/Qwen3-VL-2B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"Qwen/Qwen3-VL-32B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-32B-Instruct-FP8\",\n            \"unsloth/Qwen3-VL-32B-Instruct-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Qwen3-VL-32B-Instruct\",\n            \"Qwen/Qwen3-VL-32B-Instruct\",\n            \"unsloth/Qwen3-VL-32B-Instruct-bnb-4bit\",\n        ),\n    },\n    \"unsloth/granite-4.0-350m-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-350m-base\",\n        \"ibm-granite/granite-4.0-350m-base\",\n        \"unsloth/granite-4.0-350m-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-350m-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-350m\",\n        \"ibm-granite/granite-4.0-350m\",\n        \"unsloth/granite-4.0-350m-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-350m-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-h-350m-base\",\n        \"ibm-granite/granite-4.0-h-350m-base\",\n        \"unsloth/granite-4.0-h-350m-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-350m-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-h-350m\",\n        \"ibm-granite/granite-4.0-h-350m\",\n        \"unsloth/granite-4.0-h-350m-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-1b-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-1b-base\",\n        \"ibm-granite/granite-4.0-1b-base\",\n        \"unsloth/granite-4.0-1b-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-1b-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-1b\",\n        \"ibm-granite/granite-4.0-1b\",\n        \"unsloth/granite-4.0-1b-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-1b-base-unsloth-bnb-4bit\" : (\n        \"unsloth/granite-4.0-h-1b-base\",\n        \"ibm-granite/granite-4.0-h-1b-base\",\n        \"unsloth/granite-4.0-h-1b-base-bnb-4bit\",\n    ),\n    \"unsloth/granite-4.0-h-1b-unsloth-bnb-4bit\" : (\n        
\"unsloth/granite-4.0-h-1b\",\n        \"ibm-granite/granite-4.0-h-1b\",\n        \"unsloth/granite-4.0-h-1b-bnb-4bit\",\n    ),\n    \"unsloth/gpt-oss-safeguard-20b\" : (\n        \"unsloth/gpt-oss-safeguard-20b\",\n        \"openai/gpt-oss-safeguard-20b\",\n    ),\n    \"unsloth/gpt-oss-safeguard-120b\" : (\n        \"unsloth/gpt-oss-safeguard-120b\",\n        \"openai/gpt-oss-safeguard-120b\",\n    ),\n    \"unsloth/functiongemma-270m-it-unsloth-bnb-4bit\" : (\n        \"unsloth/functiongemma-270m-it\",\n        \"google/functiongemma-270m-it\",\n        \"unsloth/functiongemma-270m-it-unsloth-bnb-4bit\",\n    ),\n    # Ministral 3 models\n    \"unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"mistralai/Ministral-3-3B-Instruct-2512\",\n            \"unsloth/Ministral-3-3B-Instruct-2512-FP8\",\n            \"unsloth/Ministral-3-3B-Instruct-2512-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Ministral-3-3B-Instruct-2512\",\n            \"mistralai/Ministral-3-3B-Instruct-2512\",\n            \"unsloth/Ministral-3-3B-Instruct-2512-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Ministral-3-3B-Base-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-3B-Base-2512\",\n        \"mistralai/Ministral-3-3B-Base-2512\",\n        \"unsloth/Ministral-3-3B-Base-2512-bnb-4bit\",\n    ),\n    \"unsloth/Ministral-3-3B-Reasoning-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-3B-Reasoning-2512\",\n        \"mistralai/Ministral-3-3B-Reasoning-2512\",\n        \"unsloth/Ministral-3-3B-Reasoning-2512-bnb-4bit\",\n    ),\n    \"unsloth/Ministral-3-8B-Instruct-2512-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"mistralai/Ministral-3-8B-Instruct-2512\",\n            \"unsloth/Ministral-3-8B-Instruct-2512-FP8\",\n            \"unsloth/Ministral-3-8B-Instruct-2512-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Ministral-3-8B-Instruct-2512\",\n            \"mistralai/Ministral-3-8B-Instruct-2512\",\n            \"unsloth/Ministral-3-8B-Instruct-2512-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Ministral-3-8B-Base-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-8B-Base-2512\",\n        \"mistralai/Ministral-3-8B-Base-2512\",\n        \"unsloth/Ministral-3-8B-Base-2512-bnb-4bit\",\n    ),\n    \"unsloth/Ministral-3-8B-Reasoning-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-8B-Reasoning-2512\",\n        \"mistralai/Ministral-3-8B-Reasoning-2512\",\n        \"unsloth/Ministral-3-8B-Reasoning-2512-bnb-4bit\",\n    ),\n    \"unsloth/Ministral-3-14B-Instruct-2512-unsloth-bnb-4bit\" : {\n        \"8\" : (\n            \"mistralai/Ministral-3-14B-Instruct-2512\",\n            \"unsloth/Ministral-3-14B-Instruct-2512-FP8\",\n            \"unsloth/Ministral-3-14B-Instruct-2512-FP8\",\n        ),\n        \"16\" : (\n            \"unsloth/Ministral-3-14B-Instruct-2512\",\n            \"mistralai/Ministral-3-14B-Instruct-2512\",\n            \"unsloth/Ministral-3-14B-Instruct-2512-bnb-4bit\",\n        ),\n    },\n    \"unsloth/Ministral-3-14B-Base-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-14B-Base-2512\",\n        \"mistralai/Ministral-3-14B-Base-2512\",\n        \"unsloth/Ministral-3-14B-Base-2512-bnb-4bit\",\n    ),\n    \"unsloth/Ministral-3-14B-Reasoning-2512-unsloth-bnb-4bit\" : (\n        \"unsloth/Ministral-3-14B-Reasoning-2512\",\n        \"mistralai/Ministral-3-14B-Reasoning-2512\",\n        \"unsloth/Ministral-3-14B-Reasoning-2512-bnb-4bit\",\n    ),\n    
\"unsloth/Kimi-K2-Instruct-BF16\" : (\n        \"unsloth/Kimi-K2-Instruct\",\n    ),\n}\n\nINT_TO_FLOAT_MAPPER  = {}\nFLOAT_TO_INT_MAPPER  = {}\nMAP_TO_UNSLOTH_16bit = {}\nFLOAT_TO_FP8_BLOCK_MAPPER = {}\nFLOAT_TO_FP8_ROW_MAPPER   = {}\n\n\ndef _add_with_lower(mapper, key, value):\n    if key is None:\n        return\n    mapper[key] = value\n    mapper[key.lower()] = value\n\n\ndef _add_lower_only(mapper, key, value):\n    if key is None:\n        return\n    mapper[key.lower()] = value\n\nfor key, values in __INT_TO_FLOAT_MAPPER.items():\n    block, row = None, None\n    if type(values) is dict:\n        assert \"16\" in values\n        float16_values = values[\"16\"]\n        # Float8 and other quantized types\n        if \"8\" in values:\n            float8_values = values[\"8\"]\n            assert len(float8_values) == 3\n            official, block, row = float8_values\n            _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, key, block)\n            _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, key, row)\n            _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, official + \"-dynamic\", block)\n            _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official, row)\n            _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official + \"-dynamic\", row)\n            for k in float8_values + float16_values:\n                _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, k, block)\n                _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, k, row)\n\n            if float8_values[1] is not None and float8_values[1].startswith(\"unsloth\"):\n                for value in float8_values:\n                    if value is not None:\n                        _add_with_lower(MAP_TO_UNSLOTH_16bit, value, float8_values[1])\n\n            for value in float8_values:\n                if value is not None:\n                    FLOAT_TO_INT_MAPPER[value] = key\n                    FLOAT_TO_INT_MAPPER[value.lower()] = key.lower()\n        values = float16_values\n    INT_TO_FLOAT_MAPPER[key] = values[0]\n\n    for value in values:\n        FLOAT_TO_INT_MAPPER[value] = key\n\n    # Map to Unsloth version for 16bit versions\n    if len(values) == 2:\n        if values[0].startswith(\"unsloth\"):\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])\n    elif len(values) == 3:\n        # Dynamic Unsloth quantization\n        if values[0].startswith(\"unsloth\"):\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, values[2], values[0])\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])\n            _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])\n        pass\n\n    # Get lowercased\n    lowered_key = key.lower()\n    INT_TO_FLOAT_MAPPER[lowered_key] = values[0].lower()\n\n    for value in values:\n        FLOAT_TO_INT_MAPPER[value.lower()] = lowered_key\n"
  },
  {
    "path": "unsloth/models/mistral.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import _get_dtype\nfrom unsloth_zoo.hf_utils import dtype_from_config\nfrom ..utils.packing import (\n    get_packed_info_from_kwargs,\n    mask_packed_sequence_boundaries,\n)\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    SDPA,\n    select_attention_backend,\n)\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n)\nfrom transformers.models.mistral.modeling_mistral import (\n    MistralAttention,\n    MistralDecoderLayer,\n    MistralModel,\n    MistralForCausalLM,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.mistral.modeling_mistral import (\n        MistralSdpaAttention,\n        MistralFlashAttention2,\n    )\nexcept:\n    MistralSdpaAttention = MistralAttention\n    MistralFlashAttention2 = MistralAttention\nfrom unsloth_zoo.utils import Version, _get_dtype\n\n\ndef MistralAttention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, Q.device)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    # Extend RoPE dynamically to fit in VRAM\n    self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)\n    cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)\n\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    # Useful for 
LongRoPE\n    Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    sw_cfg = getattr(self.config, \"sliding_window\", None)\n    sw = kv_seq_len if (sw_cfg is None or sw_cfg == \"null\") else sw_cfg\n    window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)\n\n    use_varlen = (\n        seq_info is not None and past_key_value is None and window_size == (-1, -1)\n    )\n    backend = (\n        SDPA if attention_mask is not None else select_attention_backend(use_varlen)\n    )\n    attention_config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\"causal\": True, \"window_size\": window_size},\n        flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"causal\": True,\n            \"softmax_scale\": getattr(self, \"softmax_scale\", None),\n        },\n    )\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\ndef MistralForCausalLM_fast_forward(\n    self,\n    input_ids: torch.LongTensor = None,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    num_logits_to_keep: Optional[int] = 0,\n    logits_to_keep: Optional[int] = 0,\n    *args,\n    **kwargs,\n) -> Union[Tuple, CausalLMOutputWithPast]:\n    if causal_mask is None and past_key_values is None:\n        bsz, q_len = input_ids.shape\n        sliding_window = getattr(self.config, \"sliding_window\", None)\n\n        if HAS_XFORMERS:\n            # Always create causal mask for xformers\n            if (\n                sliding_window is None\n                or sliding_window == \"null\"\n                or sliding_window <= 0\n            ):\n                causal_mask = xformers.attn_bias.LowerTriangularMask()\n            elif q_len <= sliding_window:\n                causal_mask = xformers.attn_bias.LowerTriangularMask()\n            else:\n                causal_mask = xformers.attn_bias.BlockDiagonalCausalMask.from_seqlens(\n                    [q_len] * bsz\n                ).make_local_attention(window_size = sliding_window)\n\n            # If attention_mask exists, it will be handled in the attention forward\n\n        else:\n            # Not using xformers - need to create attention masks\n            if (\n                sliding_window is None\n     
           or sliding_window == \"null\"\n                or sliding_window <= 0\n                or q_len <= sliding_window\n            ):\n                # Fully causal mask\n                causal_mask_values = torch.triu(\n                    torch.full((q_len, q_len), -torch.inf, device = input_ids.device),\n                    diagonal = 1,\n                )\n            else:\n                # Sliding window attention\n                q_indices = torch.arange(q_len, device = input_ids.device).view(-1, 1)\n                k_indices = torch.arange(q_len, device = input_ids.device).view(1, -1)\n\n                causal_bool_mask = k_indices <= q_indices\n                window_bool_mask = (q_indices - k_indices) < sliding_window\n\n                causal_mask_values = torch.where(\n                    causal_bool_mask & window_bool_mask, 0.0, -torch.inf\n                )\n\n            # Combine with existing attention_mask if present\n            if attention_mask is None:\n                attention_mask = causal_mask_values[None, None, :, :].expand(\n                    bsz, 1, q_len, q_len\n                )\n            else:\n                if attention_mask.dim() == 2:\n                    # Convert 0/1 padding mask to additive format: 1->0 (keep), 0->-inf (mask)\n                    padding_mask = torch.where(\n                        attention_mask[:, None, None, :].bool(),\n                        0.0,\n                        -torch.inf,\n                    )\n                    attention_mask = causal_mask_values[None, None, :, :] + padding_mask\n                else:\n                    attention_mask = (\n                        attention_mask + causal_mask_values[None, None, :, :]\n                    )\n\n            attention_mask = attention_mask.to(\n                dtype = _get_dtype(dtype_from_config(self.config))\n            )\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    self.model._has_no_labels = labels is None\n\n    if past_key_values is not None:\n        outputs = LlamaModel_fast_forward_inference(\n            self,\n            input_ids,\n            past_key_values,\n            position_ids = position_ids,\n            attention_mask = attention_mask,\n        )\n    else:\n        outputs = self.model(\n            input_ids = input_ids,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_values = past_key_values,\n            inputs_embeds = inputs_embeds,\n            use_cache = use_cache,\n            output_attentions = output_attentions,\n            output_hidden_states = output_hidden_states,\n            return_dict = return_dict,\n            **kwargs,\n        )\n\n    hidden_states = outputs[0]\n\n    bsz, q_len, hd = hidden_states.shape\n    lm_head = self.lm_head.weight\n    lm_head_device = lm_head.device\n\n    # Move items to same device as lm_head\n    hidden_states = hidden_states.to(lm_head_device)\n    if labels is not None:\n        labels = labels.to(lm_head_device)\n\n 
   # If we are in GRPO mode, return raw hidden states\n    if os.environ.get(\"UNSLOTH_RETURN_HIDDEN_STATES\", \"0\") == \"1\":\n        num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)\n        if num_logits_to_keep != 0:\n            hidden_states = hidden_states[:, -num_logits_to_keep:, :]\n        return CausalLMOutputWithPast(\n            loss = None,\n            logits = hidden_states,\n            past_key_values = outputs.past_key_values,\n            hidden_states = outputs.hidden_states,\n            attentions = outputs.attentions,\n        )\n\n    if bsz == 1 and q_len == 1:\n        logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))\n        logits = logits.unsqueeze(0).unsqueeze(0)\n    elif num_logits_to_keep != 0:\n        logits = self.lm_head(\n            hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)\n        )\n    else:\n        RETURN_LOGITS = os.environ.get(\"UNSLOTH_RETURN_LOGITS\", \"0\") == \"1\"\n        # < 1024 Normal Unsloth uses less VRAM!\n        if bsz * q_len <= 1024 and not RETURN_LOGITS:\n            # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage\n            RETURN_LOGITS = False\n\n        if not RETURN_LOGITS and labels is not None:\n            n_items = kwargs.get(\"num_items_in_batch\", None)\n            if n_items is None:\n                n_items = kwargs.get(\"n_items\", None)\n            logit_softcapping = getattr(self.config, \"final_logit_softcapping\", 0)\n\n            # loss = fused_linear_cross_entropy(\n            #     hidden_states = hidden_states,\n            #     lm_weight = lm_head,\n            #     labels = labels,\n            #     num_items_in_batch = n_items,\n            #     logit_softcapping = logit_softcapping,\n            # )\n            loss = unsloth_fused_ce_loss(\n                trainer = None,\n                hidden_states = hidden_states,\n                lm_head_weight = lm_head,\n                lm_head_bias = None,\n                labels = labels,\n                mask = None,\n                n_items = n_items,\n                scaling = getattr(self, \"accelerator_scaler\", None),\n                target_gb = None,\n                torch_compile = True,\n                logit_softcapping = logit_softcapping,\n            )\n            if not return_dict:\n                output = (logits,) + outputs[1:]\n                return (loss,) + output if loss is not None else output\n\n            output = CausalLMOutputWithPast(\n                loss = loss,\n                logits = EMPTY_LOGITS,\n                past_key_values = outputs.past_key_values,\n                hidden_states = outputs.hidden_states,\n                attentions = outputs.attentions,\n            )\n            return output\n        pass\n        logits = self.lm_head(hidden_states.to(lm_head.dtype))\n    logits = logits.to(_get_dtype(dtype_from_config(self.config)))\n\n    loss = None\n    if labels is not None:\n        shift_logits = logits\n        # if not hasattr(self, \"extra_ignored_labels\"):\n        #     # Fixes https://github.com/unslothai/unsloth/issues/10\n        #     self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = \"cuda:0\")\n        # pass\n        # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))\n        shift_labels = torch.empty_like(labels)\n        shift_labels[..., :-1] = labels[..., 1:]\n        shift_labels[..., -1] = -100\n       
 mask_packed_sequence_boundaries(\n            shift_labels,\n            kwargs.get(\"packed_seq_lengths\"),\n        )\n        n_items = kwargs.get(\"num_items_in_batch\", None)\n        if n_items is None:\n            n_items = kwargs.get(\"n_items\", None)\n        loss = fast_cross_entropy_loss(\n            logits = shift_logits,\n            labels = shift_labels,\n            n_items = n_items,\n        )\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        return (loss,) + output if loss is not None else output\n\n    return CausalLMOutputWithPast(\n        loss = loss,\n        logits = logits,\n        past_key_values = outputs.past_key_values,\n        hidden_states = outputs.hidden_states,\n        attentions = outputs.attentions,\n    )\n\n\n# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.\ndef patch_mistral_nemo_attention(function):\n    function = function.replace(\n        \"(self.head_dim * self.config.num_attention_heads) != self.config.hidden_size\",\n        \"False\",\n    )\n    function = function.replace(\n        \"self.head_dim = self.config.hidden_size // self.config.num_attention_heads\",\n        \"self.head_dim = config.head_dim\",\n    )\n    function = function.replace(\n        \"self.o_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)\",\n        \"self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)\",\n    )\n    return function\n\n\nclass FastMistralModel(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"mistral\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = MistralAttention,\n        )\n        # Just for Mistral Nemo models!\n        if function is not None and init_name is not None:\n            function = patch_mistral_nemo_attention(function)\n            # if True:#init_name is not None:\n            exec(function, globals())\n            MistralAttention.__init__ = eval(init_name)\n        MistralAttention.forward = MistralAttention_fast_forward\n        MistralSdpaAttention.forward = MistralAttention_fast_forward\n        MistralFlashAttention2.forward = MistralAttention_fast_forward\n        MistralDecoderLayer.forward = LlamaDecoderLayer_fast_forward\n        MistralModel.forward = LlamaModel_fast_forward\n        MistralForCausalLM.forward = MistralForCausalLM_fast_forward\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(MistralForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.mistral.modeling_mistral\n\n        transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/mistral-7b-bnb-4bit\",\n        max_seq_length = None,\n        dtype = None,\n        load_in_4bit = 
True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,  # Mistral does not support RoPE scaling\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastMistralModel,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
  },
  {
    "path": "unsloth/models/qwen2.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n)\nfrom transformers.models.qwen2.modeling_qwen2 import (\n    Qwen2Attention,\n    Qwen2DecoderLayer,\n    Qwen2Model,\n    Qwen2ForCausalLM,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.qwen2.modeling_qwen2 import (\n        Qwen2SdpaAttention,\n        Qwen2FlashAttention2,\n    )\nexcept:\n    Qwen2SdpaAttention = Qwen2Attention\n    Qwen2FlashAttention2 = Qwen2Attention\n\n\nclass FastQwen2Model(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"qwen2\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = Qwen2Attention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            Qwen2Attention.__init__ = eval(init_name)\n        Qwen2Attention.forward = LlamaAttention_fast_forward\n        Qwen2SdpaAttention.forward = LlamaAttention_fast_forward\n        Qwen2FlashAttention2.forward = LlamaAttention_fast_forward\n        Qwen2DecoderLayer.forward = LlamaDecoderLayer_fast_forward\n        Qwen2Model.forward = LlamaModel_fast_forward\n        Qwen2ForCausalLM.forward = CausalLM_fast_forward(\n            LlamaModel_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(Qwen2ForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.qwen2.modeling_qwen2\n\n        transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(\n        model_name = \"Qwen/Qwen2-7B\",\n        max_seq_length = 4096,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,  # Qwen2 does not support RoPE scaling\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = 
device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastQwen2Model,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
  },
  {
    "path": "unsloth/models/qwen3.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom unsloth_zoo.utils import Version, _get_dtype\nfrom ..utils.packing import get_packed_info_from_kwargs\nfrom ..utils.attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    run_attention,\n    SDPA,\n    select_attention_backend,\n)\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n    _LlamaModel_fast_forward_inference,\n)\n\ntry:\n    from transformers.models.qwen3.modeling_qwen3 import (\n        Qwen3Attention,\n        Qwen3DecoderLayer,\n        Qwen3Model,\n        Qwen3ForCausalLM,\n    )\nexcept:\n    transformers_version = Version(transformers_version)\n    if not transformers_version >= Version(\n        \"4.50.3\"\n    ):  # TODO: Update when transformers is updated\n        raise ImportError(\n            f\"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\\n\"\n            f\"The minimum required version is 4.50.3.\\n\"\n            f'Try `pip install --upgrade \"transformers>=4.50.3\"`\\n'\n            f\"to obtain the latest transformers build, then restart this session.\"\n        )\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\n\n# For Pytorch 2.1.1\ntry:\n    from transformers.models.qwen3.modeling_qwen3 import (\n        Qwen3SdpaAttention,\n        Qwen3FlashAttention2,\n    )\nexcept:\n    Qwen3SdpaAttention = Qwen3Attention\n    Qwen3FlashAttention2 = Qwen3Attention\n\n\ndef Qwen3Attention_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    # Clear inference\n    if hasattr(self, \"paged_attention\"):\n        del self.paged_attention_K\n        del self.paged_attention_V\n        del self.paged_attention\n        del self.temp_QA\n        del self.temp_KV\n        del self.RH_Q\n        del self.attention\n\n    bsz, q_len, _ = hidden_states.size()\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    assert n_kv_heads * n_groups == n_heads\n\n    Q, K, V = self.apply_qkv(self, hidden_states)\n    Q = Q.view(\n        bsz, q_len, n_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after normalisation\n    K = 
K.view(\n        bsz, q_len, n_kv_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after normalisation\n    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    seq_info = get_packed_info_from_kwargs(kwargs, hidden_states.device)\n\n    # Qwen3 has QKNorm. This seems to be the only difference from Qwen2.\n    # Note that using fast_layernorm_compiled causes issues as the dimensions don't match up.\n    # I tried to add a compiled version of the new norm but the numbers don't match up with Transformers\n    # TODO: Check on the differences here.\n    Q = fast_rms_layernorm(self.q_norm, Q)\n    K = fast_rms_layernorm(self.k_norm, K)\n\n    Q = Q.transpose(1, 2)\n    K = K.transpose(1, 2)\n\n    kv_seq_len = K.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n\n    # Extend RoPE dynamically to fit in VRAM\n    if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:\n        cos, sin = position_embeddings\n    else:\n        rotary_emb = self.rotary_emb\n        rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)\n        cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)\n\n    rope_position_ids = (\n        position_ids if position_ids is not None else kwargs.get(\"position_ids\")\n    )\n    # Useful for LongRoPE\n    Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)\n\n    if past_key_value is not None:\n        K = torch.cat([past_key_value[0], K], dim = 2)\n        V = torch.cat([past_key_value[1], V], dim = 2)\n    past_key_value = (K, V) if use_cache else None\n\n    # Attention module\n    use_varlen = seq_info is not None and past_key_value is None\n    backend = (\n        SDPA if attention_mask is not None else select_attention_backend(use_varlen)\n    )\n    attention_config = AttentionConfig(\n        backend = backend,\n        n_kv_heads = n_kv_heads,\n        n_groups = n_groups,\n        flash_dense_kwargs = {\"causal\": True},\n        flash_varlen_kwargs = {\n            \"dropout_p\": 0.0,\n            \"causal\": True,\n            \"softmax_scale\": getattr(self, \"softmax_scale\", None),\n        },\n    )\n    context = AttentionContext(\n        bsz = bsz,\n        q_len = q_len,\n        kv_seq_len = kv_seq_len,\n        n_heads = n_heads,\n        head_dim = head_dim,\n        requires_grad = hidden_states.requires_grad,\n        seq_info = seq_info,\n        attention_mask = attention_mask,\n        causal_mask = causal_mask,\n    )\n\n    A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)\n\n    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)\n    attn_output = self.apply_o(self, attn_output)\n    attn_weights = None\n    return attn_output, attn_weights, past_key_value\n\n\ntorch_matmul = torch.matmul\n\n\ndef Qwen3Attention_fast_forward_inference(\n    self,\n    hidden_states: torch.Tensor,\n    past_key_value: Optional[Tuple[torch.Tensor]],\n    position_ids,\n    do_prefill = False,\n    attention_mask = None,\n    **kwargs,\n):\n    \"\"\"\n    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406\n    Fast inference using KV cache.\n    QK^T can be computed in 4 chunks\n\n    [Q, q] @ [K, k].T where q, k are the new tokens.\n    [QK^T, Qk^T]\n    [qK^T, qk^T]\n\n    Since the attention mask wipes Qk^T, we just get\n    [QK^T,    0]\n    [qK^T, qk^T]\n\n    Since softmax is row-wise, we get\n    softmax([QK^T,    0])\n    softmax([qK^T, 
qk^T])\n\n    We then multiply by   [V]\n                          [v]\n    softmax([QK^T,    0]) [softmax(QK^T)V] *\n    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]\n\n    But notice * [softmax(QK^T)V] is just the last attention.\n    We just need to compute the last final row.\n\n    This means we can pass in a row of Q, but we need to\n    remember K and V, which are called the KV cache.\n    \"\"\"\n    Xn = hidden_states\n    bsz, _, hd = hidden_states.size()\n    K1, V1 = past_key_value\n    dtype = Xn.dtype\n\n    n_heads = self.config.num_attention_heads\n    n_groups = self.num_key_value_groups\n    n_kv_heads = self.config.num_key_value_heads\n    head_dim = self.head_dim\n    # assert(n_kv_heads * n_groups == n_heads)\n\n    hidden_size = self.config.hidden_size\n    attention_size = n_heads * head_dim\n    seq_len = K1.shape[-2]\n    kv_seq_len = seq_len + 1\n\n    # Prefill phase\n    # if not hasattr(self, \"paged_attention\"):\n    device = hidden_states.device\n    if do_prefill:\n        self.paged_attention = torch.empty(\n            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),\n            dtype = dtype,\n            device = device,\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)\n        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)\n        self.temp_QA = torch.empty(\n            (2, bsz, 1, attention_size), dtype = dtype, device = device\n        )\n        self.temp_KV = torch.empty(\n            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device\n        )\n        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)\n\n        # Mistral Nemo 12b has weird dimensions\n        if attention_size != hidden_size:\n            self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)\n        else:\n            self.temp_O = self.temp_QA[1][:, :, :hidden_size]\n\n        self.attention = torch.empty(\n            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device\n        )\n        self.scalar = 1.0 / math_sqrt(self.head_dim)\n        self.half_head_dim = head_dim // 2\n    elif kv_seq_len >= self.paged_attention.shape[0]:\n        self.paged_attention.resize_(\n            (\n                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,\n                2,\n                bsz,\n                n_kv_heads,\n                head_dim,\n            )\n        )\n        self.paged_attention_K = self.paged_attention[:, 0]\n        self.paged_attention_V = self.paged_attention[:, 1]\n        self.attention.resize_(\n            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)\n        )\n\n    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])\n    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])\n    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])\n    Qn = Qn.view(\n        bsz, 1, n_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after normalisation\n    Kn = Kn.view(\n        bsz, 1, n_kv_heads, head_dim\n    )  # .transpose(1, 2) # we will transpose after normalisation\n    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)\n\n    Qn = fast_rms_layernorm_inference(self.q_norm, Qn)\n    Kn = fast_rms_layernorm_inference(self.k_norm, Kn)\n\n    Qn = Qn.transpose(1, 2)\n    Kn = 
Kn.transpose(1, 2)\n\n    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)\n    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)\n\n    # Need to do it prior 2 steps before hitting full on short KV cache\n    # or else error\n    self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)\n    cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)\n    cos = cos[position_ids].unsqueeze(1)\n    sin = sin[position_ids].unsqueeze(1)\n    h = self.half_head_dim\n\n    RH_Q = self.RH_Q\n    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]\n    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]\n    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])\n    Qn *= cos\n    Qn.addcmul_(RH_Q, sin)\n\n    RH_K = RH_Q[\n        :, :n_kv_heads, :, :\n    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = \"cuda:0\")\n    RH_K[:, :, :, :h] = Kn[:, :, :, h:]\n    RH_K[:, :, :, h:] = Kn[:, :, :, :h]\n    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])\n    Kn *= cos\n    Kn.addcmul_(RH_K, sin)\n\n    # New KV cache\n    # Kn = torch.cat([K1, Kn], dim = 2)\n    # Vn = torch.cat([V1, Vn], dim = 2)\n    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)\n    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)\n    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)\n    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)\n\n    # Handle sliding windows\n    sliding_window = getattr(self.config, \"sliding_window\", None)\n    if sliding_window is not None and kv_seq_len > sliding_window:\n        start = kv_seq_len - sliding_window\n        Knn = Kn[:, :, start:, :]  # .contiguous()\n        Vnn = Vn[:, :, start:, :]  # .contiguous()\n        if attention_mask is not None:\n            attention_mask = attention_mask[..., start:]\n    else:\n        Knn, Vnn = Kn, Vn\n\n    # when qlen==vlen and attn_mask is None, we should use causal attention\n    Q_len = Qn.shape[-2]\n    K_len = Knn.shape[-2]\n    if attention_mask is not None and attention_mask.dim() == 2:\n        attention_mask = attention_mask[:, None, None, :].to(torch.bool)\n    elif (\n        attention_mask is not None\n        and attention_mask.dim() == 4\n        and attention_mask.dtype != torch.bool\n    ):\n        attention_mask = attention_mask.eq(0)\n    if attention_mask is None and Q_len == K_len:\n        is_causal = True\n    else:\n        is_causal = False\n    use_sdpa_gqa = SDPA_HAS_GQA\n    if (\n        use_sdpa_gqa\n        and isinstance(attention_mask, torch.Tensor)\n        and attention_mask.dim() >= 3\n        and attention_mask.shape[0] > 1\n    ):\n        # Avoid SDPA GQA drift for batched masked decode.\n        use_sdpa_gqa = False\n\n    # Grouped query attention\n    _, _, cached_len, _ = Knn.shape\n    if bsz == 1 or ((not use_sdpa_gqa) and n_groups != 1):\n        Knn = Knn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Vnn = Vnn[:, :, None, :, :].expand(\n            bsz, n_kv_heads, n_groups, cached_len, head_dim\n        )\n        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)\n        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)\n\n    # Attention\n    if bsz == 1:\n        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963\n        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows\n        A = torch_matmul(\n            Qn, Knn.transpose(2, 3), out = 
self.attention[:, :, :, :cached_len]\n        )\n        A[:] = torch_nn_functional_softmax(\n            A, dim = -1, dtype = torch.float32\n        )  # .to(A.dtype)\n        A = torch_matmul(A, Vnn, out = Qn)\n    else:\n        if use_sdpa_gqa:\n            A = scaled_dot_product_attention(\n                Qn,\n                Knn,\n                Vnn,\n                attn_mask = attention_mask,\n                is_causal = is_causal,\n                enable_gqa = True,\n            )\n        else:\n            A = scaled_dot_product_attention(\n                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal\n            )\n    A = A.transpose(1, 2)\n    A = A.reshape(bsz, 1, attention_size)\n    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)\n    return A, (Kn, Vn)\n\n\nclass FastQwen3Model(FastLlamaModel):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"Qwen3\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = Qwen3Attention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            Qwen3Attention.__init__ = eval(init_name)\n        Qwen3Attention.forward = Qwen3Attention_fast_forward\n        Qwen3SdpaAttention.forward = Qwen3Attention_fast_forward\n        Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward\n        Qwen3DecoderLayer.forward = LlamaDecoderLayer_fast_forward\n        Qwen3Model.forward = LlamaModel_fast_forward\n        Qwen3ForCausalLM.forward = CausalLM_fast_forward(\n            _LlamaModel_fast_forward_inference(Qwen3Attention_fast_forward_inference)\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(Qwen3ForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\n        import transformers.models.qwen3.modeling_qwen3\n\n        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(  # TODO: Change after release\n        model_name = \"Qwen/Qwen3-7B\",\n        max_seq_length = 4096,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastQwen3Model,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
  },
  {
    "path": "unsloth/models/qwen3_moe.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .llama import *\nimport os\nfrom ._utils import __version__\nfrom .llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n)\nfrom .qwen3 import (\n    Qwen3Attention_fast_forward,\n    FastQwen3Model,\n)\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import (\n    Qwen3MoeAttention,\n    Qwen3MoeSparseMoeBlock,\n    Qwen3MoeMLP,\n    Qwen3MoeDecoderLayer,\n    Qwen3MoeModel,\n    Qwen3MoeForCausalLM,\n)\n\n# For Pytorch 2.1.1\n# TODO: Transformers moved to `attention_interface`. So we might not need these anymore\n# try:\n#     from transformers.models.qwen3_moe.modeling_qwen3_moe import (\n#         Qwen3SdpaAttention,\n#         Qwen3FlashAttention2,\n#     )\n# except:\n#     Qwen3SdpaAttention   = Qwen3Attention\n#     Qwen3FlashAttention2 = Qwen3Attention\n# pass\nfrom unsloth_zoo.utils import Version, _get_dtype\n\n\ntorch_nn_functional_softmax = torch.nn.functional.softmax\n\n\ndef Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):\n    # adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370\n\n    bsz, seq_len, hd = X.shape\n    X = X.view(-1, hd)\n\n    router_logits = fast_linear_forward(\n        self.gate_proj, X, out = temp_gate\n    )  # pretty much the only change from transformers implementation.\n\n    routing_weights = torch_nn_functional_softmax(\n        router_logits, dim = -1, dtype = torch.float32\n    )\n    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)\n    routing_weights /= routing_weights.sum(dim = -1, keepdim = True)\n    # we cast back to the input dtype\n    routing_weights = routing_weights.to(X.dtype)\n    final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device)\n\n    # One hot encode the selected experts to create an expert mask\n    # this will be used to easily index which expert is going to be sollicitated\n    expert_mask = torch.nn.functional.one_hot(\n        selected_experts, num_classes = self.num_experts\n    ).permute(2, 1, 0)\n\n    # Loop over all available experts in the model and perform the computation on each expert\n    for expert_idx in range(self.num_experts):\n        expert_layer = self.experts[expert_idx]\n        idx, top_x = torch.where(expert_mask[expert_idx])\n\n        # Index the correct hidden states and compute the expert hidden state for\n        # the current expert. 
We need to make sure to multiply the output hidden\n        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n        current_state = X[None, top_x].reshape(-1, hd)\n        current_X = (\n            expert_layer(current_state) * routing_weights[top_x, idx, None]\n        )  # Qwen3MoeMLP.forward = fast_swiglu_inference takes care of making this faster. Analogous to Dense models' MLP\n\n        # However `index_add_` only support torch tensors for indexing so we'll use\n        # the `top_x` tensor here.\n        final_X.index_add_(0, top_x, current_X.to(X.dtype))\n    final_X = final_X.reshape(bsz, seq_len, hd)\n    return final_X, router_logits\n\n\ndef Qwen3MoeDecoderLayer_fast_forward(\n    self,\n    hidden_states: torch.Tensor,\n    causal_mask: Optional[BlockDiagonalCausalMask] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    output_router_logits: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    *args,\n    **kwargs,\n):\n    residual = hidden_states\n\n    if use_cache and hasattr(\n        self, \"_flag_for_generation\"\n    ):  # past_key_value is not None:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.input_layernorm, hidden_states\n        )\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n            _flag_for_generation = self._flag_for_generation,\n        )\n        hidden_states += residual\n\n        # MoE Router MLP\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm_inference(\n            self.post_attention_layernorm, hidden_states\n        )\n        hidden_states, router_logits = Qwen3MoeSparseMoeBlock_fast_forward(\n            self.mlp, hidden_states\n        )\n        hidden_states += residual\n    else:\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states = hidden_states,\n            causal_mask = causal_mask,\n            attention_mask = attention_mask,\n            position_ids = position_ids,\n            past_key_value = past_key_value,\n            output_attentions = output_attentions,\n            use_cache = use_cache,\n            padding_mask = padding_mask,\n            position_embeddings = position_embeddings,\n        )\n        hidden_states = residual + hidden_states\n\n        # MoE Router MLP\n        residual = hidden_states\n        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)\n        hidden_states, router_logits = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n    if output_attentions:\n        outputs += 
(self_attn_weights,)\n    if output_router_logits:\n        outputs += (router_logits,)\n    if use_cache:\n        outputs += (present_key_value,)\n    return outputs\n\n\nclass FastQwen3MoeModel(FastQwen3Model):\n    @staticmethod\n    def pre_patch():\n        init_name, function = patch_linear_scaling(\n            model_name = \"Qwen3Moe\",\n            rope_module = LlamaRotaryEmbedding,\n            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,\n            attention_module = Qwen3MoeAttention,\n        )\n        if init_name is not None:\n            exec(function, globals())\n            Qwen3MoeAttention.__init__ = eval(init_name)\n        Qwen3MoeAttention.forward = Qwen3Attention_fast_forward\n        # Qwen3SdpaAttention   .forward = Qwen3Attention_fast_forward\n        # Qwen3FlashAttention2 .forward = Qwen3Attention_fast_forward\n        Qwen3MoeSparseMoeBlock.forward = Qwen3MoeSparseMoeBlock_fast_forward\n        Qwen3MoeMLP.forward = (\n            fast_swiglu_inference  # This is analogous to Dense models' MLP\n        )\n        Qwen3MoeDecoderLayer.forward = Qwen3MoeDecoderLayer_fast_forward\n        Qwen3MoeModel.forward = LlamaModel_fast_forward\n        Qwen3MoeForCausalLM.forward = CausalLM_fast_forward(\n            LlamaModel_fast_forward_inference\n        )\n        PeftModelForCausalLM.forward = PeftModel_fast_forward\n        fix_prepare_inputs_for_generation(Qwen3MoeForCausalLM)\n\n        # Solves https://github.com/unslothai/unsloth/issues/168\n        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.\n        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.\n        # https://github.com/huggingface/transformers/pull/27931\n        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\\\n        import transformers.models.qwen3_moe.modeling_qwen3_moe\n\n        transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRotaryEmbedding = (\n            LlamaRotaryEmbedding\n        )\n        return\n\n    @staticmethod\n    def from_pretrained(  # TODO: Change after release\n        model_name = \"Qwen/Qwen3-7B\",\n        max_seq_length = 4096,\n        dtype = None,\n        load_in_4bit = True,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        model_patcher = None,\n        tokenizer_name = None,\n        trust_remote_code = False,\n        **kwargs,\n    ):\n        return FastLlamaModel.from_pretrained(\n            model_name = model_name,\n            max_seq_length = max_seq_length,\n            dtype = dtype,\n            load_in_4bit = load_in_4bit,\n            token = token,\n            device_map = device_map,\n            rope_scaling = rope_scaling,\n            fix_tokenizer = fix_tokenizer,\n            model_patcher = FastQwen3MoeModel,\n            tokenizer_name = tokenizer_name,\n            trust_remote_code = trust_remote_code,\n            **kwargs,\n        )\n"
  },
  {
    "path": "unsloth/models/rl.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"PatchFastRL\",\n    \"vLLMSamplingParams\",\n]\n\nimport torch\nfrom typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union\nimport inspect\nimport os\nimport re\nfrom unsloth_zoo.compiler import create_new_function\nfrom unsloth_zoo.log import logger\nfrom unsloth_zoo.logging_utils import PatchRLStatistics\nfrom unsloth_zoo.rl_replacements import RL_REPLACEMENTS\nfrom ..device_type import DEVICE_TYPE\nfrom .rl_replacements import (\n    RL_EXTRA_ARGS,\n    RL_FUNCTIONS,\n    RL_PRE_ITEMS,\n    RL_CONFIG_CHANGES,\n    RL_METRICS_CHANGES,\n    RL_ADDITIONAL_FUNCTIONS,\n)\n\ntorch_compile_options = {\n    \"epilogue_fusion\": True,\n    \"max_autotune\": False,  # Disable Triton mm kernels\n    \"shape_padding\": True,\n    \"trace.enabled\": False,\n    \"triton.cudagraphs\": False,\n}\n\n# vLLM compatibility shim (TRL expects GuidedDecodingParams even if vLLM doesn't provide it)\ntry:\n    import vllm.sampling_params as _unsloth_vllm_sp\n\n    if not hasattr(_unsloth_vllm_sp, \"GuidedDecodingParams\"):\n\n        class GuidedDecodingParams:\n            def __init__(self, **kwargs):\n                self.kwargs = kwargs\n\n        _unsloth_vllm_sp.GuidedDecodingParams = GuidedDecodingParams\nexcept Exception:\n    pass\n\nfrom trl import __version__ as trl_version_raw\nfrom importlib.metadata import version as importlib_version\nfrom unsloth_zoo.utils import Version\n\ntry:\n    trl_version = Version(trl_version_raw)\nexcept Exception:\n    try:\n        trl_version = Version(importlib_version(\"trl\"))\n    except Exception:\n        trl_version = Version(\"0.0.0\")\n\n# Get PyTorch version for feature detection\ntry:\n    torch_version = Version(torch.__version__.split(\"+\")[0].split(\"a\")[0].split(\"b\")[0])\nexcept Exception:\n    torch_version = Version(\"0.0.0\")\n\n# Get transformers version for feature detection\ntry:\n    from transformers import __version__ as _transformers_version_raw\n\n    transformers_version = Version(_transformers_version_raw)\nexcept Exception:\n    transformers_version = Version(\"0.0.0\")\n\n\ndef vLLMSamplingParams(**kwargs):\n    from vllm import SamplingParams\n\n    sampling_params = SamplingParams(**kwargs)\n    sampling_params._set_kwargs = kwargs\n    return sampling_params\n\n\ndef PatchRL(FastLanguageModel):\n    try:\n        from trl.models.utils import unwrap_model_for_generation\n    except ImportError:\n        try:\n            from trl.models import unwrap_model_for_generation\n        except ImportError:\n            # Local fallback -- TRL removed or moved this symbol\n            from contextlib import contextmanager as _cm\n\n            @_cm\n            def unwrap_model_for_generation(\n                model, accelerator, gather_deepspeed3_params = True\n            ):\n                unwrapped_model = accelerator.unwrap_model(model)\n         
       is_gc = getattr(unwrapped_model, \"is_gradient_checkpointing\", False)\n                if is_gc:\n                    unwrapped_model.gradient_checkpointing_disable()\n                if (\n                    getattr(accelerator, \"state\", None) is not None\n                    and getattr(accelerator.state, \"deepspeed_plugin\", None) is not None\n                    and accelerator.state.deepspeed_plugin.zero_stage == 3\n                ):\n                    if not gather_deepspeed3_params:\n                        yield accelerator.unwrap_model(model)\n                    else:\n                        import deepspeed\n\n                        with deepspeed.zero.GatheredParameters(model.parameters()):\n                            yield accelerator.unwrap_model(model)\n                else:\n                    yield unwrapped_model\n                if is_gc:\n                    unwrapped_model.gradient_checkpointing_enable()\n\n    from contextlib import contextmanager\n\n    @contextmanager\n    def unsloth_unwrap_model_for_generation(model, *args, **kwargs):\n        with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model:\n            # Put the model in inference mode.\n            FastLanguageModel.for_inference(model)\n\n            # We must use .clone for Unsloth since we force inference_mode\n            # Rather we should have used no_grad\n            original_generate = unwrapped_model.generate\n\n            def generate_with_clone(*args, **kwargs):\n                out = original_generate(*args, **kwargs)\n                if isinstance(out, torch.Tensor):\n                    return out.clone()\n                return out\n\n            unwrapped_model.generate = generate_with_clone\n\n            try:\n                yield unwrapped_model\n            finally:\n                # Restore generate and return\n                unwrapped_model.generate = original_generate\n                FastLanguageModel.for_training(model)\n\n    from transformers import Trainer\n    from transformers.trainer_pt_utils import nested_detach\n\n    @torch.no_grad()\n    def unsloth_prediction_step(\n        self,\n        model,\n        inputs,\n        prediction_loss_only,\n        ignore_keys,\n    ):\n        \"\"\"\n        Perform an evaluation step on `model` using `inputs`.\n        Subclass and override to inject custom behavior.\n        Args:\n            model (`nn.Module`):\n                The model to evaluate.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. 
Check your model's documentation for all accepted arguments.\n            prediction_loss_only (`bool`):\n                Whether or not to return the loss only.\n            ignore_keys (`List[str]`, *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n        Return:\n            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,\n            logits and labels (each being optional).\n        \"\"\"\n        has_labels = (\n            False\n            if len(self.label_names) == 0\n            else all(inputs.get(k) is not None for k in self.label_names)\n        )\n        # For CLIP-like models capable of returning loss values.\n        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`\n        # is `True` in `model.forward`.\n        return_loss = inputs.get(\"return_loss\", None)\n        if return_loss is None:\n            return_loss = self.can_return_loss\n        loss_without_labels = (\n            True if len(self.label_names) == 0 and return_loss else False\n        )\n\n        inputs = self._prepare_inputs(inputs)\n        if ignore_keys is None:\n            if hasattr(self.model, \"config\"):\n                ignore_keys = getattr(\n                    self.model.config, \"keys_to_ignore_at_inference\", []\n                )\n            else:\n                ignore_keys = []\n\n        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.\n        if has_labels or loss_without_labels:\n            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))\n            if len(labels) == 1:\n                labels = labels[0]\n        else:\n            labels = None\n\n        os.environ[\"UNSLOTH_RETURN_LOGITS\"] = \"1\"\n        with torch.no_grad():\n            if has_labels or loss_without_labels:\n                with self.compute_loss_context_manager():\n                    loss, outputs = self.compute_loss(\n                        model, inputs, return_outputs = True\n                    )\n                loss = loss.mean().detach()\n\n                if isinstance(outputs, dict):\n                    logits = tuple(\n                        v for k, v in outputs.items() if k not in ignore_keys + [\"loss\"]\n                    )\n                else:\n                    logits = outputs[1:]\n            else:\n                loss = None\n                with self.compute_loss_context_manager():\n                    tokenized_output = self.processing_class(\n                        inputs[\"prompt\"],\n                        padding = True,\n                        truncation = True,\n                        return_tensors = \"pt\",\n                    ).to(model.device)\n                    outputs = model(**tokenized_output)\n                if isinstance(outputs, dict):\n                    logits = tuple(\n                        v for k, v in outputs.items() if k not in ignore_keys\n                    )\n                else:\n                    logits = outputs\n                # TODO: this needs to be fixed and made cleaner later.\n                if self.args.past_index >= 0:\n                    self._past = outputs[self.args.past_index - 1]\n        os.environ[\"UNSLOTH_RETURN_LOGITS\"] = \"0\"\n        if prediction_loss_only:\n            return (loss, 
None, None)\n\n        logits = nested_detach(logits)\n        if len(logits) == 1:\n            logits = logits[0]\n\n        return (loss, logits, labels)\n\n    import trl.trainer\n\n    trainers = dir(trl.trainer)\n    trainers = [x for x in trainers if x.endswith(\"_trainer\")]\n    unwrap = \"unwrap_model_for_generation\"\n    for trainer in trainers:\n        try:\n            current_trainer = getattr(trl.trainer, trainer)\n        except:\n            continue\n        if hasattr(current_trainer, unwrap):\n            try:\n                setattr(current_trainer, unwrap, unsloth_unwrap_model_for_generation)\n            except:\n                continue\n    Trainer.prediction_step = unsloth_prediction_step\n\n\ngrpo_selective_log_softmax = RL_REPLACEMENTS[\"grpo_selective_log_softmax\"]\nselective_log_softmax = RL_REPLACEMENTS[\"selective_log_softmax\"]\ncalculate_pad_tokens_in_prompt = RL_REPLACEMENTS[\"calculate_pad_tokens_in_prompt\"]\ncreate_completion_attention_mask = RL_REPLACEMENTS[\"create_completion_attention_mask\"]\nleft_pack_padding = RL_REPLACEMENTS[\"left_pack_padding\"]\nalign_logprobs_with_mask = RL_REPLACEMENTS[\"align_logprobs_with_mask\"]\nautotune_batch_and_chunks = RL_REPLACEMENTS[\"grpo_autotune_batch_and_chunks\"]\nsanitize_logprob = RL_REPLACEMENTS[\"sanitize_logprob\"]\n\nRLTrainer_replacement = '''\nimport os\nimport math\nimport logging\nfrom typing import *\nfrom dataclasses import dataclass, field\nfrom packaging.version import Version\nimport torch\nimport numpy as np\nfrom contextlib import nullcontext\nfrom torch.nn import functional as F\nimport inspect\nfrom transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling\nfrom transformers.training_args import ParallelMode\nfrom unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize\n\n# Wrap trainer with padding to right and enable training mode\n# Also patches W&B since multiple runs must use wandb.finish()\nimport functools\nfrom types import MethodType\ntry:\n    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers\nexcept:\n    def reset_unsloth_gradient_checkpointing_buffers(): pass\ndef prepare_for_training_mode(f):\n    @functools.wraps(f)\n    def wrapper(self, *args, **kwargs):\n        # Enable training mode\n        _was_training = None\n        # Get gradient checkpointing setting from training arguments\n        use_gc = getattr(self.args, 'gradient_checkpointing', True)\n        if hasattr(self, 'model') and hasattr(self.model, \"training\"):\n            _was_training = self.model.training\n        if hasattr(self, 'model') and hasattr(self.model, \"for_training\"):\n            self.model.for_training(use_gradient_checkpointing=use_gc)\n        output = f(self, *args, **kwargs)\n        # Restore previous mode when possible\n        if hasattr(self, 'model') and hasattr(self.model, \"for_inference\"):\n            if _was_training is False:\n                self.model.for_inference()\n            elif _was_training is True and hasattr(self.model, \"for_training\"):\n                self.model.for_training(use_gradient_checkpointing=use_gc)\n        # Reset gradient checkpointing buffers to free memory while staying ready for next run\n        try:\n            reset_unsloth_gradient_checkpointing_buffers()\n        except:\n            pass\n        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run\n        try:\n            import wandb\n 
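           # wandb may not be installed or may have no active run; the bare\n            # except below deliberately swallows any failure.\n 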
           wandb.finish()\n        except:\n            pass\n        return output\n    return wrapper\npass\n\ntorch_compile_options = {{\n    \"epilogue_fusion\"   : True,\n    \"max_autotune\"      : False,\n    \"shape_padding\"     : True,\n    \"trace.enabled\"     : False,\n    \"triton.cudagraphs\" : False,\n}}\n\n{grpo_selective_log_softmax_code}\n{selective_log_softmax_code}\n{calculate_pad_tokens_in_prompt_code}\n{create_completion_attention_mask_code}\n{left_pack_padding_code}\n{align_logprobs_with_mask_code}\n{autotune_batch_and_chunks_code}\n{sanitize_logprob_code}\n\n{RL_pre}\n\n@dataclass\nclass Unsloth{RLConfig_name}({RLConfig_name}):\n    \"\"\"\n    {__RLConfig_doc__}\n    \"\"\"\n    vllm_sampling_params: Optional[Any] = field(\n        default = None,\n        metadata = {{'help': 'vLLM SamplingParams'}},\n    )\n    unsloth_num_chunks : Optional[int] = field(\n        default = -1,\n        metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},\n    )\n    unsloth_logit_chunk_multiplier : Optional[int] = field(\n            default = None,\n            metadata = {{'help': 'Multiplier for chunked logit computations.'}},\n        )\n    unsloth_grpo_mini_batch : Optional[int] = field(\n        default = None,\n        metadata = {{'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}},\n    )\n    {max_seq_length_pre}\n    def __init__({RLConfig_arguments},\n        vllm_sampling_params = None,\n        unsloth_num_chunks = -1,\n        unsloth_logit_chunk_multiplier = None,\n        unsloth_grpo_mini_batch = None,\n        {max_seq_length_call}\n        **kwargs,\n    ):\n{RLConfig_extra_args}\n        super().__init__({RLConfig_call_args}{RLConfig_kwargs})\n        self.vllm_sampling_params = vllm_sampling_params\n        self.unsloth_num_chunks = unsloth_num_chunks\n        if unsloth_grpo_mini_batch is not None:\n            if self.generation_batch_size >= unsloth_grpo_mini_batch:\n                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch\n            else:\n                raise ValueError(\n                    f\"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, \"\n                    f\"which is self.per_device_train_batch_size * gradient_accumulation_steps.\"\n                )\n        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier\n        {max_seq_length_post}\n{RLConfig_post}\npass\n\n{RLTrainer_extras}\n\nclass Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}):\n    \"\"\"\n    {__RLTrainer_doc__}\n    \"\"\"\n    def __init__({RLTrainer_arguments},\n        **kwargs\n    ):\n        if args is None: args = Unsloth{RLConfig_name}()\n{RLTrainer_extra_args}\n        # [TODO] Fix up DataParallel multiplying batch sizes\n        # [TODO] DDP works, but DP seems to not work? 
[TODO]\n        if getattr(args, \"parallel_mode\", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:\n            if getattr(args, \"_n_gpu\", 1) != 1:\n                args._n_gpu = 1\n        if \"model\" in locals() and hasattr(model, \"for_training\"):\n            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))\n        super().__init__({RLTrainer_call_args}{RLTrainer_kwargs})\n        if \"model\" in locals() and hasattr(model, \"for_inference\"):\n            model.for_inference()\n{RLTrainer_post}\npass\n'''\n\n\ndef _wrap_grpo_generate_and_score(trainer_cls):\n    if not hasattr(trainer_cls, \"_generate_and_score_completions\"):\n        return\n    original = trainer_cls._generate_and_score_completions\n    if getattr(original, \"_unsloth_restore_training_wrapped\", False):\n        return\n\n    def wrapped(self, *args, **kwargs):\n        was_training = getattr(getattr(self, \"model\", None), \"training\", None)\n        try:\n            return original(self, *args, **kwargs)\n        finally:\n            if (\n                was_training is False\n                and hasattr(self, \"model\")\n                and hasattr(self.model, \"for_inference\")\n            ):\n                try:\n                    self.model.for_inference()\n                except Exception:\n                    pass\n\n    wrapped._unsloth_restore_training_wrapped = True\n    trainer_cls._generate_and_score_completions = wrapped\n\n\ndef _patch_trl_rl_trainers(trainer_file = \"grpo_trainer\"):\n    # Patch for vLLM and Unsloth PEFT\n    import trl\n    import trl.trainer\n\n    try:\n        trainer = eval(f\"trl.trainer.{trainer_file}\")\n    except Exception as error:\n        logger.info(f\"Unsloth: Could not import trl.trainer.{trainer_file}: {error}\")\n        return\n\n    # Get SFTTrainer and SFTConfig names\n    name = [\n        x\n        for x in dir(trainer)\n        if x.endswith(\"Trainer\")\n        and x != \"Trainer\"\n        and not x.startswith(\"_\")\n        and trainer_file.split(\"_\")[0] in x.lower()\n    ]\n    config = [\n        x\n        for x in dir(trainer)\n        if x.endswith(\"Config\")\n        and x != \"Config\"\n        and not x.startswith(\"_\")\n        and trainer_file.split(\"_\")[0] in x.lower()\n    ]\n    if len(name) != 1:\n        logger.info(\n            f\"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. 
Found: {name}\"\n        )\n        return\n    if len(config) != 1:\n        # TRL 0.26+: Config may be in a separate *_config.py module\n        config_module_name = trainer_file.replace(\"_trainer\", \"_config\")\n        try:\n            config_mod = eval(f\"trl.trainer.{config_module_name}\")\n            config = [\n                x\n                for x in dir(config_mod)\n                if x.endswith(\"Config\")\n                and x != \"Config\"\n                and not x.startswith(\"_\")\n                and trainer_file.split(\"_\")[0] in x.lower()\n            ]\n        except Exception:\n            pass\n    if len(config) != 1 and len(name) == 1:\n        # Thin wrapper fallback: walk the Trainer's MRO to find Config\n        # in the real implementation module (e.g., trl.experimental.bco)\n        try:\n            _temp_cls = eval(f\"trl.trainer.{trainer_file}.{name[0]}\")\n            for _parent in _temp_cls.__mro__[1:]:\n                if _parent is object:\n                    continue\n                _parent_mod = inspect.getmodule(_parent)\n                if (\n                    _parent_mod is None\n                    or _parent_mod.__name__ == f\"trl.trainer.{trainer_file}\"\n                ):\n                    continue\n                config = [\n                    x\n                    for x in dir(_parent_mod)\n                    if x.endswith(\"Config\")\n                    and x != \"Config\"\n                    and not x.startswith(\"_\")\n                    and trainer_file.split(\"_\")[0] in x.lower()\n                ]\n                if len(config) == 1:\n                    break\n        except Exception:\n            pass\n    if len(config) != 1:\n        logger.info(\n            f\"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. 
Found: {config}\"\n        )\n        return\n\n    # Get SFTTrainer, SFTConfig\n    RLTrainer_name = name[0]\n    RLConfig_name = config[0]\n    try:\n        RLTrainer = eval(f\"trl.trainer.{trainer_file}.{RLTrainer_name}\")\n    except Exception as e:\n        logger.info(\n            f\"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}: {e}\"\n        )\n        return\n    _config_resolved_module = None\n    try:\n        RLConfig = eval(f\"trl.trainer.{trainer_file}.{RLConfig_name}\")\n    except Exception:\n        # TRL 0.26+: Config may be in a separate *_config.py module\n        try:\n            config_module_name = trainer_file.replace(\"_trainer\", \"_config\")\n            RLConfig = eval(f\"trl.trainer.{config_module_name}.{RLConfig_name}\")\n        except Exception:\n            # Thin wrapper fallback: load Config from parent trainer's module\n            _config_loaded = False\n            try:\n                _temp_cls = eval(f\"trl.trainer.{trainer_file}.{name[0]}\")\n                for _parent in _temp_cls.__mro__[1:]:\n                    if _parent is object:\n                        continue\n                    _parent_mod = inspect.getmodule(_parent)\n                    if (\n                        _parent_mod is None\n                        or _parent_mod.__name__ == f\"trl.trainer.{trainer_file}\"\n                    ):\n                        continue\n                    if hasattr(_parent_mod, RLConfig_name):\n                        RLConfig = getattr(_parent_mod, RLConfig_name)\n                        _config_resolved_module = _parent_mod\n                        _config_loaded = True\n                        break\n            except Exception:\n                pass\n            if not _config_loaded:\n                logger.info(f\"Unsloth: Could not load {RLConfig_name}\")\n                return\n\n    # Check name\n    if RLTrainer.__name__.startswith(\"Unsloth\"):\n        print(f\"Unsloth: {RLTrainer.__name__} is already patched.\")\n        return\n    if RLConfig.__name__.startswith(\"Unsloth\"):\n        print(f\"Unsloth: {RLConfig.__name__} is already patched.\")\n        return\n\n    # TRL 0.26+: Resolve thin wrappers to their experimental parent class.\n    # Thin wrappers are deprecation shims in trl.trainer that just forward\n    # *args/**kwargs to the real implementation in trl.experimental.\n    # Only resolve if a parent class actually lives in a trl.experimental module.\n    _trainer_resolved_module = None\n    try:\n        _trainer_src = inspect.getsource(RLTrainer)\n        _trainer_module = inspect.getmodule(RLTrainer)\n        _trainer_module_src = (\n            inspect.getsource(_trainer_module) if _trainer_module else \"\"\n        )\n        if (\n            \"trl.experimental\" in _trainer_src\n            or \"trl.experimental\" in _trainer_module_src\n        ):\n            for _parent in RLTrainer.__mro__[1:]:\n                if _parent is object:\n                    continue\n                _parent_mod = inspect.getmodule(_parent)\n                if _parent_mod is None:\n                    continue\n                # Only resolve to a parent that lives in trl.experimental\n                if \"trl.experimental\" in _parent_mod.__name__:\n                    RLTrainer = _parent\n                    _trainer_resolved_module = _parent_mod\n                    break\n    except Exception:\n        pass\n\n    try:\n        _config_src = inspect.getsource(RLConfig)\n        
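# As with the trainer class above, resolve thin deprecation wrappers: if the\n        # config's source references trl.experimental, walk its MRO and use the\n        # parent class that actually lives there.\n        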
_config_module = inspect.getmodule(RLConfig)\n        _config_module_src = inspect.getsource(_config_module) if _config_module else \"\"\n        if (\n            \"trl.experimental\" in _config_src\n            or \"trl.experimental\" in _config_module_src\n        ):\n            for _parent in RLConfig.__mro__[1:]:\n                if _parent is object:\n                    continue\n                _parent_mod = inspect.getmodule(_parent)\n                if _parent_mod is None:\n                    continue\n                # Only resolve to a parent that lives in trl.experimental\n                if \"trl.experimental\" in _parent_mod.__name__:\n                    RLConfig = _parent\n                    break\n    except Exception:\n        pass\n\n    # Get old source\n    old_RLTrainer_source = inspect.getsource(RLTrainer)\n    old_RLConfig_source = inspect.getsource(RLConfig)\n\n    if _trainer_resolved_module is not None:\n        all_imports = dir(_trainer_resolved_module)\n    elif _config_resolved_module is not None:\n        all_imports = dir(_config_resolved_module)\n    else:\n        all_imports = dir(trainer)\n    # Fix _deprecate_arguments not getting imported so stop __ but not _\n    imports = [x for x in all_imports if not x.startswith(\"__\")]\n\n    # Get default arguments\n    EMPTY = inspect.Parameter.empty\n    processed = []\n    for RLobject in [RLTrainer, RLConfig]:\n        parameters = inspect.signature(RLobject.__init__).parameters\n        types = (\n            bool,\n            type(None),\n            int,\n            float,\n            str,\n        )\n        arguments = [\"self\"]\n        call_args = []\n        for k, v in parameters.items():\n            if k == \"self\":\n                continue\n            v = v.default\n            if v == \"\\n\":\n                v = re.escape(\"\\n\")\n            if v is EMPTY:\n                arguments.append(k)\n            elif type(v) is str:\n                arguments.append(f\"{k} = '{v}'\")\n            elif type(v) in types:\n                arguments.append(f\"{k} = {v}\")\n            else:\n                continue\n            call_args.append(f\"{k} = {k}\")\n        arguments = f\"\\n{' '*8}\" + f\",\\n{' '*8}\".join(arguments)\n        call_args = f\"\\n{' '*12}\" + f\",\\n{' '*12}\".join(call_args)\n        processed.append(\n            (\n                arguments,\n                call_args,\n            )\n        )\n\n    # Process RLTrainer first\n    arguments, call_args = processed[0]\n    RLTrainer_post = \"\"\n\n    # Add tokenizer if not seen\n    if \"tokenizer\" not in parameters and \"processing_class\" in parameters:\n        arguments += f\",\\n{' '*8}tokenizer = None\"\n        call_args = call_args.replace(\n            \"processing_class = processing_class\",\n            \"processing_class = tokenizer if tokenizer is not None else processing_class\",\n        )\n\n    # Edit bf16, fp16 by checking model's dtype/torch_dtype directly\n    extra_args = \"\"\n    if \"args\" in call_args and \"model\" in call_args:\n        mixed_precision = (\n            \"use_bf16 = getattr(args, 'bf16', False)\\n\"\n            \"if type(use_bf16) is not bool: use_bf16 = False\\n\"\n            \"use_fp16 = getattr(args, 'fp16', False)\\n\"\n            \"if type(use_fp16) is not bool: use_fp16 = False\\n\"\n            \"force_float32 = False\\n\"\n            \"full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'\\n\"\n            \"if not 
full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):\\n\"\n            \"    print('Unsloth: Switching to float32 training since model cannot work with float16')\\n\"\n            \"    force_float32 = True\\n\"\n            \"mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\\n\"\n            \"dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)\\n\"\n            \"if dtype is None: dtype = model.get_input_embeddings().weight.dtype\\n\"\n            \"from unsloth_zoo.utils import _get_dtype\\n\"\n            \"dtype = _get_dtype(dtype)\\n\"\n            \"float16 = dtype == torch.float16\\n\"\n            \"if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\\n\"\n            \"if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\\n\"\n            \"if force_float32:\\n\"\n            \"    # Forced float32 training\\n\"\n            \"    args.fp16 = False\\n\"\n            \"    args.bf16 = False\\n\"\n            \"    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\\n\"\n            \"    if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'\\n\"\n            \"    # args.mixed_precision is a new argument which needs to be set now\\n\"\n            \"elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\\n\"\n            \"    # Mixed precision training\\n\"\n            \"    args.fp16 = float16\\n\"\n            \"    args.bf16 = not float16\\n\"\n            \"    os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\\n\"\n            \"    if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'\\n\"\n            \"    # args.mixed_precision is a new argument which needs to be set now\\n\"\n            \"elif mixed_precision_dtype == 'bfloat16':\\n\"\n            \"    # Both False since bfloat16 full finetuning doesn't do any autocasting.\\n\"\n            \"    args.fp16 = False\\n\"\n            \"    args.bf16 = False\\n\"\n            \"    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\\n\"\n            \"    if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'\\n\"\n            \"    # args.mixed_precision is a new argument which needs to be set now\\n\"\n            \"\\n\"\n        )\n        extra_args += mixed_precision\n\n    # Check if per_device_eval_batch_size (default 8) bigger than bsz\n    # Also use FP16 / BF16 evaluation\n    if \"args\" in call_args:\n        # Check eval_dataset first\n        if \"eval_dataset\" in call_args:\n            check_eval_dataset = (\n                \"if getattr(args, 'eval_dataset', None) is not None and \"\n                \"getattr(args, 'eval_strategy', 'no') == 'no':\\n\"\n                \"    args.eval_strategy = 'steps'\\n\"\n                \"    if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\\n\"\n            )\n            extra_args += check_eval_dataset\n\n        # Check if gradient accumulation bug fix is applied\n        check_ga = (\n            \"ga_steps = getattr(args, 'gradient_accumulation_steps', None)\\n\"\n            \"if ga_steps is not None and ga_steps > 1:\\n\"\n            \"    from transformers import __version__ as 
transformers_version\\n\"\n            \"    if Version(transformers_version) <= Version('4.45.2'):\\n\"\n            \"        print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\\\n'\\n\"\n            \"              '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\\n\"\n        )\n        extra_args += check_ga\n\n        eval_changes = (\n            \"if getattr(args, 'eval_strategy', 'no') != 'no':\\n\"\n            \"    eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\\n\"\n            \"    if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\\n\"\n            \"    if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\\n\"\n            \"fp16_full_eval = getattr(args, 'fp16_full_eval', False)\\n\"\n            \"if type(fp16_full_eval) is not bool: fp16_full_eval = False\\n\"\n            \"bf16_full_eval = getattr(args, 'bf16_full_eval', False)\\n\"\n            \"if type(bf16_full_eval) is not bool: bf16_full_eval = False\\n\"\n            \"if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\\n\"\n            \"if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\\n\"\n            \"if force_float32:\\n\"\n            \"    args.bf16_full_eval = False\\n\"\n            \"    args.fp16_full_eval = False\\n\"\n            \"elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\\n\"\n            \"    args.bf16_full_eval = True\\n\"\n            \"    args.fp16_full_eval = False\\n\"\n            \"elif not bf16_full_eval and not fp16_full_eval:\\n\"\n            \"    args.bf16_full_eval = args.bf16\\n\"\n            \"    args.fp16_full_eval = args.fp16\\n\"\n        )\n        extra_args += eval_changes\n\n    # Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used\n    if \"model\" in call_args:\n        logits_check = (\n            \"_output_logits = False\\n\"\n            \"if locals().get('compute_metrics', None) is not None: _output_logits = True\\n\"\n            \"if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\\n\"\n            \"if _output_logits:\\n\"\n            \"    os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n\"\n        )\n        extra_args += logits_check\n        warnings_issued_check = (\n            \"if model is not None:\\n\"\n            \"    _warnings_issued = getattr(model, 'warnings_issued', None)\\n\"\n            \"    if _warnings_issued is None:\\n\"\n            \"        model.warnings_issued = {}\\n\"\n            \"    elif not isinstance(_warnings_issued, dict):\\n\"\n            \"        try:\\n\"\n            \"            model.warnings_issued = dict(_warnings_issued)\\n\"\n            \"        except Exception:\\n\"\n            \"            model.warnings_issued = {}\\n\"\n        )\n        extra_args += warnings_issued_check\n\n    # Check max_seq_length\n    if \"model\" in call_args:\n        length_check = (\n            \"if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\\n\"\n            \"    pass\\n\"\n            \"else:\\n\"\n            \"    model_max_seq_length = getattr(model, 'max_seq_length', None)\\n\"\n            \"    args_max_seq_length  = 
getattr(args,  'max_seq_length', None)\\n\"\n            \"    if args_max_seq_length is None and model_max_seq_length is not None:\\n\"\n            \"        max_seq_length = model.max_seq_length\\n\"\n            \"        if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\\n\"\n            \"    elif args_max_seq_length is not None and model_max_seq_length is not None:\\n\"\n            \"        if args_max_seq_length > model_max_seq_length:\\n\"\n            \"            print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '\\n\"\n            \"                   'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\\n\"\n            \"            args.max_seq_length = model_max_seq_length\\n\"\n        )\n        extra_args += length_check\n\n        # At this point max_seq_length might be set, but trl is moving to max_length\n        if trainer_file == \"sft_trainer\":\n            max_length_check = (\n                \"if 'max_length' not in locals() and not hasattr(args, 'max_length'):\\n\"\n                \"    pass\\n\"\n                \"else:\\n\"\n                \"    if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:\\n\"\n                \"        if hasattr(args, 'max_length'):\\n\"\n                \"            args.max_length = args.max_seq_length\\n\"\n                \"            max_length = args.max_length\\n\"\n                \"    else:\\n\"\n                \"        model_max_length = getattr(model, 'max_seq_length', None)\\n\"\n                \"        if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\\n\"\n                \"        if model_max_length is not None:\\n\"\n                \"            args.max_length = model_max_length\\n\"\n                \"            max_length = args.max_length\\n\"\n                \"        elif hasattr(args, 'max_length') and args.max_length is not None:\\n\"\n                \"            max_length = args.max_length\\n\"\n                \"            # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set\\n\"\n                \"            setattr(model, 'max_seq_length', max_length)\\n\"\n                \"        else:\\n\"\n                \"            print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. 
We will set it to 1024.')\\n\"\n                \"            args.max_length = 1024\\n\"\n            )\n            extra_args += max_length_check\n\n    # Enable for training and move padding side of tokenizer to right\n    if \"model\" in call_args:\n        training_check = (\n            \"if model is not None and hasattr(model, 'for_training'):\\n\"\n            \"    model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))\\n\"\n            \"if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\\n\"\n            \"if 'processing_class' in locals():\\n\"\n            \"    if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\\n\"\n            \"    if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): \"\n            \"processing_class.tokenizer.padding_side = 'right'\\n\"\n        )\n        extra_args += training_check\n\n    # Check data collator if it's correct!\n    if \"data_collator\" in call_args and \"train_dataset\" in call_args:\n        data_collator_check = (\n            \"__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\\n\"\n            \"from unsloth_zoo.vision_utils import UnslothVisionDataCollator\\n\"\n            \"if not isinstance(data_collator, UnslothVisionDataCollator):\\n\"\n            \"    if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\\n\"\n            \"        data_collator = TransformersDataCollatorForLanguageModeling(\\n\"\n            \"            __tokenizer,\\n\"\n            \"            mlm = False,\\n\"\n            \"            mlm_probability = 0.0,\\n\"\n            \"            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\\n\"\n            \"        )\\n\"\n            \"    elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\\n\"\n            \"        data_collator = DataCollatorForSeq2Seq(\\n\"\n            \"            __tokenizer,\\n\"\n            \"            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\\n\"\n            \"        )\\n\"\n            \"else:\\n\"\n            \"    if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\\n\"\n            \"    if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\\n\"\n            \"    if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\\n\"\n        )\n        extra_args += data_collator_check\n\n        # Also check if .pad exists -> if not, and is VLM, then change it!\n        pad_check = (\n            \"if not isinstance(data_collator, UnslothVisionDataCollator):\\n\"\n            \"    if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\\n\"\n            \"        if isinstance(data_collator, DataCollatorForSeq2Seq):\\n\"\n            \"            data_collator = DataCollatorForSeq2Seq(\\n\"\n            \"                __tokenizer.tokenizer,\\n\"\n            \"                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\\n\"\n            \"            )\\n\"\n            \"        else:\\n\"\n            \"            data_collator = TransformersDataCollatorForLanguageModeling(\\n\"\n            \"                __tokenizer.tokenizer,\\n\"\n            \"                mlm = False,\\n\"\n            \"   
             mlm_probability = 0.0,\\n\"\n            \"                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\\n\"\n            \"            )\\n\"\n        )\n        extra_args += pad_check\n\n    # Check NEFTune\n    if \"model\" in call_args:\n        neftune_check = (\n            \"if hasattr(self, 'neftune_hook_handle'):\\n\"\n            \"    self.neftune_hook_handle.remove()\\n\"\n            \"    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\\n\"\n            \"if getattr(args, 'neftune_noise_alpha', None) is not None:\\n\"\n            \"    model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\\n\"\n            \"pass\\n\"\n        )\n        RLTrainer_post += neftune_check\n\n    # Add accelerator scaler to model\n    if \"model\" in call_args:\n        accelerator_check = (\n            \"if hasattr(self, 'accelerator'):\\n\"\n            \"    scaler = self.accelerator.scaler\\n\"\n            \"    current_model = model\\n\"\n            \"    while hasattr(current_model, 'model'):\\n\"\n            \"        current_model.accelerator_scaler = scaler\\n\"\n            \"        current_model = current_model.model\\n\"\n            \"    current_model.accelerator_scaler = scaler\\n\"\n            \"pass\\n\"\n        )\n        RLTrainer_post += accelerator_check\n\n    # Add enabling and disabling training modes\n    if \"model\" in call_args:\n        training_check = (\n            \"if hasattr(self, 'train'):\\n\"\n            \"    self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)\\n\"\n            \"pass\\n\"\n        )\n        RLTrainer_post += training_check\n\n    # Sync chat_template from processing_class to vLLM's tokenizer\n    # This fixes base models that have custom chat templates applied after loading\n    if \"model\" in call_args:\n        vllm_chat_template_sync = (\n            \"if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):\\n\"\n            \"    _vllm_tok = self.llm.get_tokenizer()\\n\"\n            \"    _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)\\n\"\n            \"    if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:\\n\"\n            \"        _vllm_tok.chat_template = _pc.chat_template\\n\"\n            \"pass\\n\"\n        )\n        RLTrainer_post += vllm_chat_template_sync\n\n    # Edit optional metrics\n    other_metrics_processor = \"\"\n    if trainer_file in RL_METRICS_CHANGES:\n        process_extra_args = RL_METRICS_CHANGES[trainer_file]\n        for process_extra_arg in process_extra_args:\n            other_metrics_processor += process_extra_arg(\n                old_RLTrainer_source, old_RLConfig_source\n            )\n\n    # Add statistics as well!\n    extra_args += (\n        \"other_metrics = []\\n\"\n        f\"{other_metrics_processor}\\n\"\n        \"from unsloth_zoo.logging_utils import PatchRLStatistics\\n\"\n        f\"PatchRLStatistics('{trainer_file}', other_metrics)\\n\"\n    )\n\n    # Patch optional args\n    if trainer_file in RL_EXTRA_ARGS:\n        process_extra_args = RL_EXTRA_ARGS[trainer_file]\n        for process_extra_arg in process_extra_args:\n            extra_args += process_extra_arg(call_args, extra_args)\n\n    # Create RLTrainer args\n    extra_args = extra_args.split(\"\\n\")\n    extra_args = \"\\n\".join(\" \" * 8 + x 
for x in extra_args)\n    RLTrainer_post = RLTrainer_post.split(\"\\n\")\n    RLTrainer_post = \"\\n\".join(\" \" * 8 + x for x in RLTrainer_post)\n    RLTrainer_arguments = arguments\n    RLTrainer_extra_args = extra_args\n    RLTrainer_call_args = call_args\n\n    # Fix RLConfig next\n    arguments, call_args = processed[1]\n    extra_args = \"\"\n\n    # Edit GA / bsz and weight_decay\n    replacements = {\n        \"output_dir\": None,\n        \"logging_nan_inf_filter\": False,\n        \"per_device_train_batch_size\": 4,\n        \"gradient_accumulation_steps\": 2,\n        \"weight_decay\": 0.01,\n        \"seed\": 3407,\n        \"optim\": \"adamw_8bit\",\n        \"learning_rate\": 5e-05,\n        \"per_device_eval_batch_size\": 4,\n        \"eval_accumulation_steps\": 2,\n        \"torch_empty_cache_steps\": 250,\n        \"logging_steps\": 1,\n        \"max_seq_length\": None,\n        \"num_generations\": 8,\n        # \"steps_per_generation\"          : 1, # Otherwise defaults to ga_steps which is wrong\n        # \"generation_batch_size\"         : None, # Useless. If steps_per_generation set, generation_batch_size clashes\n        \"top_k\": None,\n        \"vllm_mode\": \"colocate\",\n        \"generation_kwargs\": {},\n        \"bf16\": False,\n        \"fp16\": False,\n        \"report_to\": \"none\",\n        \"include_tokens_per_second\": False,\n        \"include_num_input_tokens_seen\": False,\n        \"auto_find_batch_size\": False,  # Auto /2 batch size - too many people complained so removing\n        \"dataloader_pin_memory\": True,\n        \"padding_free\": None,  # None = user didn't set it, allows auto-enable detection\n        # Might fail so disable for now\n        # \"dataloader_persistent_workers\" : True, # Keeps dataloader in RAM\n        # \"dataloader_prefetch_factor\"    : 2,\n        # \"dataloader_num_workers\"        : 2, # Default is 0 means 1\n    }\n    # warmup_ratio deprecated in transformers >= 5.0; warmup_steps accepts float\n    if transformers_version >= Version(\"5.0.0\"):\n        replacements[\"warmup_steps\"] = 0.1\n    else:\n        replacements[\"warmup_ratio\"] = 0.1\n\n    for k, v in replacements.items():\n        x = f\"{k}( = [^,\\n]{{1,}})?,\\n\"\n        y = f\"'{v}'\" if type(v) is str else f\"{v}\"\n        y = f\"{k} = {y},\\n\"\n        arguments = re.sub(x, y, arguments)\n\n    # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00!\n    # https://github.com/huggingface/trl/pull/3516\n    # https://verl.readthedocs.io/en/latest/examples/config.html\n    if trainer_file == \"grpo_trainer\":\n        replacements = {\n            \"loss_type\": \"bnpo\",  # Default GRPO paper\n            \"beta\": 0.001,  # Recommended as seen in verl\n            \"auto_find_batch_size\": False,  # Cannot work on GRPO\n            # [TODO] See https://fengyao.notion.site/off-policy-rl\n            # https://github.com/huggingface/trl/pull/3867 (August 7th)\n            \"vllm_importance_sampling_correction\": False,\n        }\n        for k, v in replacements.items():\n            x = f\"{k}( = [^,\\n]{{1,}})?,\\n\"\n            y = f\"'{v}'\" if type(v) is str else f\"{v}\"\n            y = f\"{k} = {y},\\n\"\n            arguments = re.sub(x, y, arguments)\n\n    # Warn on too large or too small learning rate\n    if \"learning_rate\" in call_args:\n        learning_rate_check = (\n            \"if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! 
\"\n            \"Consider increasing it, otherwise gradient updates will be close to 0!')\\n\"\n            \"if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! \"\n            \"Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\\n\"\n        )\n        extra_args += learning_rate_check\n\n    # Fix num_train_epochs = None causing TypeError in Trainer.__init__\n    # Trainer does `args.num_train_epochs > 0` which fails when None\n    if \"num_train_epochs\" in call_args:\n        num_train_epochs_check = (\n            \"if num_train_epochs is None:\\n\"\n            \"    num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override\\n\"\n        )\n        extra_args += num_train_epochs_check\n\n    # Check if max_seq_length is NOT defined (max_length is now default)\n    if \"max_seq_length\" not in call_args and \"max_length\" in call_args:\n        max_seq_length_pre = \"\"\"max_seq_length : Optional[int] = field(\n        default = None,\n        metadata = {'help': 'Maximum sequence length to truncate to.'},\n    )\"\"\"\n        max_seq_length_call = \"max_seq_length = None,\"\n        max_seq_length_post = \"self.max_seq_length = max_seq_length\"\n    else:\n        max_seq_length_pre = \"\"\n        max_seq_length_call = \"\"\n        max_seq_length_post = \"\"\n\n    # Add output_dir saving\n    if \"output_dir\" in call_args:\n        # Default checks\n        saving_check = (\n            \"if output_dir is None and save_strategy == 'steps' and save_steps == 500:\\n\"\n            \"    output_dir = 'unsloth_training_checkpoints'\\n\"\n            \"    save_strategy = 'no'\\n\"\n        )\n        extra_args += saving_check\n\n    # Edit dataset_num_proc\n    if \"dataset_num_proc\" in call_args:\n        num_proc_check = (\n            \"import multiprocessing as _mp\\n\"\n            \"if dataset_num_proc is None:\\n\"\n            \"    if _mp.get_start_method() != 'fork':\\n\"\n            \"        dataset_num_proc = None\\n\"\n            \"    else:\\n\"\n            \"        import psutil\\n\"\n            \"        dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)\\n\"\n            \"        memory_gb_left = psutil.virtual_memory().available / (1024**3)\\n\"\n            \"        if memory_gb_left <= 2: dataset_num_proc = 1\\n\"\n            \"        else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))\\n\"\n        )\n        extra_args += num_proc_check\n\n    # Add padding if flex attention is added\n    if \"pad_to_multiple_of\" in call_args:\n        pad_to_multiple_of = (\n            \"if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':\\n\"\n            \"    from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION\\n\"\n            \"    if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:\\n\"\n            \"        from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE\\n\"\n            \"        pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE\\n\"\n            \"\\n\"\n        )\n        extra_args += pad_to_multiple_of\n\n    # Check for loss_type = dr_grpo and scale_rewards for GRPO\n    if \"loss_type\" in call_args and \"scale_rewards\" in call_args:\n        # See https://github.com/huggingface/trl/issues/3130#issuecomment-2746947835\n        # DAPO uses per token loss so BNPO loss used\n        check_dr_grpo = (\n            \"if loss_type.lower() == 'dr_grpo':\\n\"\n            \"    loss_type = 
'dr_grpo'\\n\"\n            \"elif loss_type.lower() == 'dapo':\\n\"\n            \"    loss_type = 'dapo'\\n\"\n            \"if loss_type.lower() == 'dr_grpo':\\n\"\n            \"    if scale_rewards == None:\\n\"\n            \"        scale_rewards = True\\n\"\n            \"    elif scale_rewards == True:\\n\"\n            \"        print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\\n\"\n            \"        scale_rewards = False\\n\"\n            \"elif loss_type.lower() == 'dapo':\\n\"\n            \"    if mask_truncated_completions != True:\\n\"\n            \"        print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True` - we will set it.')\\n\"\n            \"    if epsilon_high != 0.28:\\n\"\n            \"        print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28` - we will set it.')\\n\"\n            \"    if beta != 0.0:\\n\"\n            \"        print(f'[WARNING] Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term - You have set it to {beta}.')\\n\"\n            \"    mask_truncated_completions = True\\n\"\n            \"    epsilon_high = 0.28\\n\"\n            \"\\n\"\n        )\n        extra_args += check_dr_grpo\n\n    # Check GRPO num_generations mismatch\n    if (\n        \"per_device_train_batch_size\" in call_args\n        and \"num_generations\" in call_args\n        and \"steps_per_generation\" in call_args\n        and \"generation_batch_size\" in call_args\n    ):\n        # if world size is not set by accelerate or torchrun at this point it will be 1\n        check_num_generations = (\n            \"if steps_per_generation is None and generation_batch_size is None:\\n\"\n            \"    ga = gradient_accumulation_steps\\n\"\n            \"    world_size = int(os.environ.get('WORLD_SIZE', '1'))\\n\"\n            \"    if (ga * world_size * per_device_train_batch_size) % num_generations != 0:\\n\"\n            \"        print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\\\\n\"\n            \"We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\\n\"\n            \"        per_device_train_batch_size = num_generations\\n\"\n            \"\\n\"\n        )\n        extra_args += check_num_generations\n    elif \"per_device_train_batch_size\" in call_args and \"num_generations\" in call_args:\n        if \"steps_per_generation\" not in call_args:\n            print(f\"Unsloth: Could not find `steps_per_generation` in {trainer_file}\")\n        if \"generation_batch_size\" not in call_args:\n            print(f\"Unsloth: Could not find `generation_batch_size` in {trainer_file}\")\n\n        check_num_generations = (\n            \"if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\\n\"\n            \"    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\\\n\"\n            \"We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\\n\"\n            \"    per_device_train_batch_size = num_generations\\n\"\n            \"\\n\"\n        )\n        extra_args += check_num_generations\n\n    # Check temperature must not be <= 0. 
Also stop if >= 10\n    if \"temperature\" in call_args:\n        check_temperature = (\n            \"if temperature <= 0:\\n\"\n            \"    raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\\n\"\n            \"elif temperature >= 10:\\n\"\n            \"    raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\\n\"\n            \"\\n\"\n        )\n        extra_args += check_temperature\n\n    # Edit config with anything extra\n    if trainer_file in RL_CONFIG_CHANGES:\n        process_extra_args = RL_CONFIG_CHANGES[trainer_file]\n        for process_extra_arg in process_extra_args:\n            extra_args += process_extra_arg(old_RLTrainer_source, old_RLConfig_source)\n\n    # Create RLConfig args\n    extra_args = extra_args.split(\"\\n\")\n    extra_args = \"\\n\".join(\" \" * 8 + x for x in extra_args)\n    RLConfig_arguments = arguments\n    RLConfig_extra_args = extra_args\n    RLConfig_call_args = call_args\n\n    # TRL 0.27.0+ forces use_reentrant=False in gradient_checkpointing_kwargs.\n    # Unsloth gradient checkpointing requires use_reentrant=True, so we remove\n    # the setting after super().__init__() when it gets auto-applied.\n    RLConfig_post = \"\"\n    if trl_version >= Version(\"0.27.0\"):\n        RLConfig_post = (\n            \"        # Unsloth: Remove use_reentrant=False forced by TRL 0.27.0+\\n\"\n            \"        if getattr(self, 'gradient_checkpointing_kwargs', None) is not None:\\n\"\n            \"            if 'use_reentrant' in self.gradient_checkpointing_kwargs:\\n\"\n            \"                del self.gradient_checkpointing_kwargs['use_reentrant']\\n\"\n        )\n\n    # Patch vLLM and other functions\n    RLTrainer_extras = patch_functions(\n        RLTrainer, trainer_file, RLTrainer_name, all_imports, imports\n    )\n    if RLTrainer_extras is None:\n        RLTrainer_extras = f\"_Unsloth{RLTrainer_name} = {RLTrainer_name}\"\n\n    # Create full module\n    exec(f\"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)\")\n    __RLTrainer_doc__ = eval(f\"trl.trainer.{RLTrainer_name}\").__doc__\n    if __RLTrainer_doc__ is None:\n        __RLTrainer_doc__ = \"\"\n    __RLConfig_doc__ = eval(f\"trl.trainer.{RLConfig_name}\").__doc__\n    if __RLConfig_doc__ is None:\n        __RLConfig_doc__ = \"\"\n\n    # Get all pre-modules\n    if trainer_file in RL_PRE_ITEMS:\n        RL_pre = \"\\n\".join(RL_PRE_ITEMS[trainer_file])\n    else:\n        RL_pre = \"\"\n\n    # Check if SamplingParams is in there\n    if \"SamplingParams\" in old_RLTrainer_source:\n        RL_pre = RL_pre + \"\\n\" + inspect.getsource(vLLMSamplingParams)\n\n    # Selective log softmax and other functions\n    selective_log_softmax_code = inspect.getsource(selective_log_softmax)\n    grpo_selective_log_softmax_code = inspect.getsource(grpo_selective_log_softmax)\n    calculate_pad_tokens_in_prompt_code = inspect.getsource(\n        calculate_pad_tokens_in_prompt\n    )\n    create_completion_attention_mask_code = inspect.getsource(\n        create_completion_attention_mask\n    )\n    left_pack_padding_code = inspect.getsource(left_pack_padding)\n    align_logprobs_with_mask_code = inspect.getsource(align_logprobs_with_mask)\n    autotune_batch_and_chunks_code = inspect.getsource(autotune_batch_and_chunks)\n    sanitize_logprob_code = inspect.getsource(sanitize_logprob)\n    # Get final source code\n    RLTrainer_source = 
RLTrainer_replacement.format(\n        RLTrainer_name = RLTrainer_name,\n        __RLTrainer_doc__ = __RLTrainer_doc__,\n        RLTrainer_arguments = RLTrainer_arguments,\n        RLTrainer_extra_args = RLTrainer_extra_args,\n        RLTrainer_call_args = RLTrainer_call_args,\n        RLTrainer_kwargs = \",**kwargs\"[1 if RLTrainer_call_args.endswith(\",\") else 0 :],\n        RLConfig_name = RLConfig_name,\n        __RLConfig_doc__ = __RLConfig_doc__,\n        RLConfig_arguments = RLConfig_arguments,\n        RLConfig_extra_args = RLConfig_extra_args,\n        RLConfig_call_args = RLConfig_call_args,\n        RLConfig_kwargs = \",**kwargs\"[1 if RLConfig_call_args.endswith(\",\") else 0 :],\n        RLConfig_post = RLConfig_post,\n        RLTrainer_extras = RLTrainer_extras,\n        RLTrainer_post = RLTrainer_post,\n        RL_pre = RL_pre,\n        max_seq_length_pre = max_seq_length_pre,\n        max_seq_length_call = max_seq_length_call,\n        max_seq_length_post = max_seq_length_post,\n        selective_log_softmax_code = selective_log_softmax_code,\n        grpo_selective_log_softmax_code = grpo_selective_log_softmax_code,\n        calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,\n        create_completion_attention_mask_code = create_completion_attention_mask_code,\n        autotune_batch_and_chunks_code = autotune_batch_and_chunks_code,\n        left_pack_padding_code = left_pack_padding_code,\n        align_logprobs_with_mask_code = align_logprobs_with_mask_code,\n        sanitize_logprob_code = sanitize_logprob_code,\n    )\n\n    if RLTrainer_name == \"GRPOTrainer\":\n        # Base torch_compile_options shared by all device types\n        base_options = \"\"\"torch_compile_options = {\n            \"epilogue_fusion\"   : True,\n            \"max_autotune\"      : False,\n            \"shape_padding\"     : True,\n            \"trace.enabled\"     : False,\"\"\"\n\n        # Generate torch_compile_options based on device type\n        if DEVICE_TYPE == \"cuda\":\n            # CUDA-specific options (added to base options)\n            cuda_options = \"\"\"\n            \"triton.enable_persistent_tma_matmul\": torch.cuda.get_device_capability()[0] >= 9,\"\"\"\n            # cutlass options were added in PyTorch 2.8.0\n            if torch_version >= Version(\"2.8.0\"):\n                cuda_options += \"\"\"\n            \"cuda.cutlass_epilogue_fusion_enabled\": torch.cuda.get_device_capability()[0] >= 9,\n            \"cuda.cutlass_tma_only\": torch.cuda.get_device_capability()[0] >= 9,\"\"\"\n            cuda_options += \"\"\"\n            \"cuda.compile_opt_level\"              : \"-O2\",\n            \"cuda.enable_cuda_lto\"                : True,\n        }\"\"\"\n            new_options = base_options + cuda_options\n        else:\n            # XPU, HIP, and other device types use base options only\n            new_options = (\n                base_options\n                + \"\"\"\n        }\"\"\"\n            )\n\n        pattern = r\"torch_compile_options\\s*=\\s*\\{[^}]*\\}\"\n\n        RLTrainer_source = re.sub(\n            pattern, new_options, RLTrainer_source, flags = re.DOTALL\n        )\n\n        if trl_version >= Version(\"0.27.0\"):\n            peft_pattern = (\n                r\"\\s*if is_peft_available\\(\\) and is_peft_model\\(model\\) and args\\.beta != 0\\.0:\"\n                r\".*?\"\n                r\"param\\.data = param\\.data\\.to\\(torch\\.bfloat16\\)\"\n            )\n\n            replacement_comment = \"\\n    
    # PEFT initialization logic removed via script for trl >= 0.27.0\\n\"\n\n            RLTrainer_source = re.sub(\n                peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL\n            )\n\n        elif trl_version >= Version(\"0.26.0\"):\n            peft_block_pattern = (\n                r\"\\s*if is_peft_available\\(\\) and isinstance\\(model, PeftModel\\) and peft_config is not None:\"\n                r\".*?\"\n                r\"param\\.data = param\\.data\\.to\\(torch\\.bfloat16\\)\"\n            )\n\n            RLTrainer_source = re.sub(\n                peft_block_pattern,\n                \"\\n        # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\\n\",\n                RLTrainer_source,\n                flags = re.DOTALL,\n            )\n\n    # Remove TRL's unconditional bfloat16 cast of trainable params (added in\n    # TRL 0.26.0). TRL hardcodes bfloat16 for QLoRA per the original paper's\n    # recommendation, but this is wrong: it ignores the user's requested dtype\n    # and breaks GradScaler when training with fp16=True. Unsloth already\n    # handles adapter dtype correctly via patch_model_and_tokenizer, so the\n    # entire block is unnecessary. For GRPOTrainer the enclosing peft init\n    # block is already removed above, making this a no-op for GRPO.\n    RLTrainer_source = RLTrainer_source.replace(\n        'if getattr(model, \"is_loaded_in_4bit\", False) or getattr(model, \"is_loaded_in_8bit\", False):',\n        \"if False:\",\n    )\n\n    if RLTrainer_name == \"SFTTrainer\":\n        original_text = 'self._signature_columns = [\"input_ids\", \"attention_mask\", \"completion_mask\"]'\n        new_text = 'self._signature_columns = [\"input_ids\", \"attention_mask\", \"completion_mask\",\"labels\"]'\n        RLTrainer_source = RLTrainer_source.replace(original_text, new_text)\n\n        # Do NOT override _is_vlm -- let TRL detect VLM models naturally.\n        # In TRL 0.27.1+, forcing _is_vlm=False causes a ValueError when\n        # vision datasets are used with VLM models.\n        #\n        # However, some notebooks pass a bare tokenizer (processor.tokenizer) as\n        # processing_class. 
TRL then sets _is_vlm=False even for VLM models.\n        # Add a model-architecture-based override before the validation check.\n        _vlm_check_original = (\n            '        self._is_vision_dataset = \"image\" in dataset_sample or \"images\" in dataset_sample\\n'\n            \"        if self._is_vision_dataset and not self._is_vlm:\"\n        )\n        _vlm_check_patched = (\n            '        self._is_vision_dataset = \"image\" in dataset_sample or \"images\" in dataset_sample\\n'\n            \"        # Unsloth: override _is_vlm for VLM models that pass a bare tokenizer\\n\"\n            \"        if not self._is_vlm and self._is_vision_dataset:\\n\"\n            \"            _m = model\\n\"\n            '            if hasattr(_m, \"model\"): _m = _m.model\\n'\n            '            if hasattr(getattr(_m, \"config\", None), \"vision_config\") or \\\\\\n'\n            '               _m.__class__.__name__.endswith(\"ForConditionalGeneration\"):\\n'\n            \"                self._is_vlm = True\\n\"\n            \"        if self._is_vision_dataset and not self._is_vlm:\"\n        )\n        if _vlm_check_original in RLTrainer_source:\n            RLTrainer_source = RLTrainer_source.replace(\n                _vlm_check_original, _vlm_check_patched\n            )\n\n        # Fix TRL 0.22.x: VLM models with text-only datasets.\n        # TRL 0.22.x checks _is_vlm (model type) not _is_vision_dataset (dataset\n        # content, added in 0.25.1+). When _is_vlm=True, signature columns are\n        # vision-only [\"messages\",\"prompt\",\"completion\",\"images\"], which have zero\n        # overlap with tokenized text columns. Fix: merge both column sets into the\n        # VLM branch. Extra columns not in the dataset are harmlessly ignored by\n        # _remove_unused_columns (it only raises when zero columns match).\n        _sig_vlm_old = (\n            'self._signature_columns = [\"messages\", \"prompt\", \"completion\", \"images\"]'\n        )\n        _sig_vlm_new = (\n            'self._signature_columns = [\"messages\", \"prompt\", \"completion\", \"images\",'\n            ' \"input_ids\", \"labels\", \"attention_mask\", \"seq_lengths\", \"completion_mask\", \"assistant_masks\"]'\n        )\n        RLTrainer_source = RLTrainer_source.replace(_sig_vlm_old, _sig_vlm_new)\n\n        # Inject model reference before _prepare_dataset for dynamic\n        # token_type_ids detection in sft_prepare_dataset\n        _prep_pattern = r\"([ \\t]*)train_dataset = self\\._prepare_dataset\\(\"\n        _prep_replacement = r\"\\1self._unsloth_model_ref = model\\n\\1train_dataset = self._prepare_dataset(\"\n        RLTrainer_source = re.sub(\n            _prep_pattern, _prep_replacement, RLTrainer_source, count = 1\n        )\n\n    # Silence TRL's noisy batch_size=1 + padding-free warning (handles both\n    # the original \"anihilate\" typo and the corrected \"annihilate\" spelling)\n    for _typo in (\"anihilate\", \"annihilate\"):\n        _idx = RLTrainer_source.find(_typo)\n        if _idx == -1:\n            continue\n        # Walk backwards to find \"if args.per_device_train_batch_size\"\n        _block_start = RLTrainer_source.rfind(\n            \"if args.per_device_train_batch_size == 1\", 0, _idx\n        )\n        if _block_start == -1:\n            continue\n        # Walk backwards to the newline before the if\n        _line_start = RLTrainer_source.rfind(\"\\n\", 0, _block_start)\n        # Walk forwards past the closing paren to the end of the block\n        
_close = RLTrainer_source.find(\")\", _idx)\n        if _close == -1:\n            continue\n        _block_end = RLTrainer_source.find(\"\\n\", _close)\n        if _block_end == -1:\n            continue\n        RLTrainer_source = (\n            RLTrainer_source[:_line_start] + RLTrainer_source[_block_end:]\n        )\n        break\n\n    # Remove multiple doc strings\n    if __RLConfig_doc__ != \"\" and RLTrainer_source.count(__RLTrainer_doc__) == 2:\n        RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, \"\", 1)\n\n    # Remove multiple newlines\n    RLTrainer_source = re.sub(r\"[\\n]{3,}\", \"\\n\", RLTrainer_source)\n\n    # Create new function\n    _resolved_module = _trainer_resolved_module or _config_resolved_module\n    _model_location = (\n        _resolved_module.__name__\n        if _resolved_module is not None\n        else f\"trl.trainer.{trainer_file}\"\n    )\n    created_module = create_new_function(\n        f\"Unsloth{RLTrainer_name}\",\n        RLTrainer_source,\n        _model_location,\n        imports,\n        overwrite = False,\n    )\n\n    # Patch Trainer\n    exec(\n        f\"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}\",\n        locals(),\n        globals(),\n    )\n    exec(\n        f\"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}\",\n        locals(),\n        globals(),\n    )\n    exec(\n        f\"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}\",\n        locals(),\n        globals(),\n    )\n\n    # Patch Config\n    exec(\n        f\"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}\",\n        locals(),\n        globals(),\n    )\n    exec(\n        f\"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}\",\n        locals(),\n        globals(),\n    )\n    exec(\n        f\"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}\",\n        locals(),\n        globals(),\n    )\n\n    if trainer_file == \"grpo_trainer\":\n        try:\n            _wrap_grpo_generate_and_score(\n                getattr(created_module, f\"Unsloth{RLTrainer_name}\")\n            )\n        except Exception as e:\n            logger.info(\n                f\"Unsloth: Could not wrap _generate_and_score_completions for {RLTrainer_name}: {e}\"\n            )\n\n\ndef patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports):\n    init = inspect.getsource(RLTrainer.__init__)\n    old_init = init\n\n    # Remove brackets in comments since it interferes ie (...)\n    comments = re.findall(r\"\\#[^\\n]{1,}\\n\", init)\n    bracketed_comments = [x for x in comments if \"(\" in x or \")\" in x]\n    # Replace with [...] 
instead\n    for bracketed_comment in bracketed_comments:\n        init = init.replace(\n            bracketed_comment,\n            bracketed_comment.replace(\"(\", \"[\").replace(\")\", \"]\"),\n        )\n\n    # Remove peft_config\n    init = init.replace(\"elif peft_config is None:\", \"elif False:\")\n    init = init.replace(\"elif peft_config is not None:\", \"elif False:\")\n    init = init.replace(\"if peft_config is None:\", \"if False:\")\n    init = init.replace(\"if peft_config is not None:\", \"if False:\")\n    init = init.replace(\"get_peft_model(model, peft_config)\", \"model\")\n    # New TRL 0.20.0\n    init = init.replace(\n        \"if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):\",\n        \"if False:\",\n    )\n    # New TRL 0.20.0\n    init = init.replace(\n        \"model = self._prepare_peft_model(model, peft_config, args)\\n\", \"pass\\n\"\n    )\n    # TRL 0.22.0+ uses prepare_peft_model as a standalone function\n    init = init.replace(\"model = prepare_peft_model(model, peft_config, args)\", \"pass\")\n\n    # Skip add_adapter(\"ref\") for reference model computation\n    # Unsloth: We comment out the \"ref\" adapter creation because:\n    # 1. We want to use the original BASE MODEL as the reference model, not the SFT/LoRA model\n    # 2. PEFT doesn't allow multiple adapters when target_parameters is used (MoE models)\n    # When \"ref\" is not in peft_config, GRPO/RLOO fallback uses disable_adapter()\n    # which gives the base model logits - exactly what we want\n    add_adapter_block_pattern = (\n        r\"([ \\t]*)\"  # Capture leading indentation\n        r\"if\\s+is_peft_available\\(\\)\\s+and\\s+is_peft_model\\(model\\)\\s+and\\s+args\\.beta\\s*!=\\s*0\\.0\\s*:\"\n        r\"(.*?)\"  # Match the entire block until ref_param.data.copy_\n        r\"ref_param\\.data\\.copy_\\(param\\.data\\)\"\n    )\n\n    def comment_out_block(match):\n        \"\"\"Comment out each line in the matched block, preserving indentation.\"\"\"\n        full_match = match.group(0)\n        indent = match.group(1)\n        lines = full_match.split(\"\\n\")\n        commented_lines = []\n        # Add explanation comment first\n        commented_lines.append(\n            f\"{indent}# Unsloth: Commented out - use base model as reference, not SFT/LoRA model\"\n        )\n        # Comment out each line - insert # after leading whitespace to preserve indentation\n        for line in lines:\n            if line.strip():\n                stripped = line.lstrip()\n                leading_ws = line[: len(line) - len(stripped)]\n                commented_lines.append(f\"{leading_ws}# {stripped}\")\n            else:\n                commented_lines.append(line)\n        return \"\\n\".join(commented_lines)\n\n    init = re.sub(add_adapter_block_pattern, comment_out_block, init, flags = re.DOTALL)\n\n    # Set use_vllm if not set\n    if \"args.use_vllm\" in init and \"model\" in init and \"args\" in init:\n        # .*? matches first match. .+? 
matches final match.\n        replacer = re.findall(\n            r\"def __init__\\(.*?\\).*?\\:\\n\",\n            init,\n            flags = re.MULTILINE | re.DOTALL,\n        )\n        if len(replacer) != 0:\n            replacer = replacer[0]\n            vllm_setter = (\n                \"\\n\"\n                + \" \" * 8\n                + \"if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):\\n\"\n                + \" \" * 12\n                + \"if (getattr(args, 'use_vllm', False) == False):\\n\"\n                + \" \" * 16\n                + \"args.use_vllm = True\\n\"\n            )\n            # \" \" * 16 + \"args.vllm_importance_sampling_correction = True\\n\" + \\\n            # \" \" * 16 + \"args.vllm_importance_sampling_cap = 2.0\\n\"\n\n            if \"grpo\" in trainer_file and trl_version >= Version(\"0.18.0\"):\n                # If model has vllm_engine, then use vllm in colocate mode. Donot wait for server\n                vllm_setter += \" \" * 12 + \"args.vllm_mode='colocate'\\n\"\n                if trl_version >= Version(\"0.23.0\"):\n                    # We need to set this flag for sleep mode auto working with trl update\n                    vllm_setter += (\n                        \" \" * 12\n                        + \"if os.environ.get('UNSLOTH_VLLM_STANDBY', '0') == '1':\\n\"\n                        + \" \" * 16\n                        + \"args.vllm_enable_sleep_mode=True\\n\"\n                    )\n\n            init = init.replace(replacer, replacer + vllm_setter)\n\n    # breakpoint()\n\n    vllm_part = re.findall(\n        r\"(\\n[\\s]{8}\" r\"if (self|args)\\.use_vllm\\:.*?\" r\"\\n[\\s]{8}\" \"else:\\n)\",\n        init,\n        flags = re.MULTILINE | re.DOTALL,\n    )\n\n    if len(vllm_part) == 1:\n        vllm_part, args = vllm_part[0][0], vllm_part[0][1]\n        # Strip all comments\n        new_vllm_part = re.sub(\n            r\"^\\s*\\#[^\\n]*\\n?\", \"\", vllm_part, flags = re.MULTILINE\n        )  # to also remove whole comment line instead of just starting at #\n        new_vllm_part = re.sub(\n            r\"\\s*\\#.*$\", \"\", new_vllm_part, flags = re.MULTILINE\n        )  # remove comments that occur after code\n\n        # Get SamplingParams\n        sampling_params = re.findall(\n            r\"\\n[\\s]{4,}(self\\.[^\\s]{1,}[\\s]{0,}\\=[\\s]{0,}\" r\"SamplingParams\\(.+?\\))\",\n            new_vllm_part,\n            flags = re.MULTILINE | re.DOTALL,\n        )\n\n        if len(sampling_params) == 1:\n            sampling_params = sampling_params[0]\n            # Fix guided_decoding\n            sampling_params = sampling_params.replace(\n                \"guided_decoding=guided_decoding,\",\n                \"guided_decoding=\"\n                'GuidedDecodingParams(backend=\"outlines\", regex=args.vllm_guided_decoding_regex) '\n                'if getattr(args, \"vllm_guided_decoding_regex\", None) is not None else None,',\n            )\n            # Replace with our vLLM engine\n            sampling_params = (\n                \" \" * 12\n                + \"self.llm = model.vllm_engine; self._last_loaded_step = 0; \"\n                + sampling_params\n            )  # Add spaces\n\n            # count the indentation of last line of sampling_params.\n            splitted_sampling_params = sampling_params.split(\"\\n\")\n            if len(splitted_sampling_params) >= 2:\n                last_line = splitted_sampling_params[-1]\n                last_prev_line = splitted_sampling_params[-2]\n           
     last_prev_indentation = len(last_prev_line) - len(\n                    last_prev_line.lstrip()\n                )\n                last_indentation = len(last_line) - len(last_line.lstrip())\n\n                # Add extra arguments to SamplingParams\n                extra = \"**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})\"\n                # Backwards replace\n                to_replace = (\n                    \",\\n\"\n                    + \" \" * last_prev_indentation\n                    + extra\n                    + \",\\n\"\n                    + \" \" * last_indentation\n                    + \")\"\n                )\n                sampling_params = to_replace.join(sampling_params.rsplit(\")\", 1))\n                # Strip multiple commas\n                sampling_params = re.sub(r\"[\\,][\\s]{0,}\\,\", \",\", sampling_params)\n\n                new_vllm_part = (\n                    f\"\\n{' '*8}if {args}.use_vllm:\\n{sampling_params}\"\n                    f\"\\n{' '*8}else:\\n\"\n                )\n\n        if trl_version >= Version(\"0.18.0\"):\n            # Replace LLM init with already existing vLLM engine for colocate mode\n            vllm_llm_init_pattern = r\"self\\.llm\\s*=\\s*LLM\\(.*?\\)*\\)\\s*?\\n(?!,)\"\n            vllm_llm_replacement = \"self.llm = model.vllm_engine\\n\"\n            new_vllm_part = re.sub(\n                vllm_llm_init_pattern,\n                vllm_llm_replacement,\n                new_vllm_part,\n                flags = re.DOTALL,  # Ensure . matches newlines [[5]]\n            )\n\n        init = init.replace(vllm_part, new_vllm_part)\n\n    # Search for vLLM calling in all child functions\n    functions = dir(RLTrainer)\n    RLTrainer_source = inspect.getsource(RLTrainer)\n    functions = [x for x in functions if f\"def {x}\" in RLTrainer_source]\n\n    changed = {\n        \"__init__\": (\n            old_init,\n            init,\n        )\n    }\n    edit_functions = RL_FUNCTIONS.get(trainer_file, [])\n\n    for function in functions:\n        if not hasattr(RLTrainer, function):\n            continue\n        if function in changed:\n            original_source, source = changed[function]\n        else:\n            fx = getattr(RLTrainer, function)\n            try:\n                source = inspect.getsource(fx)\n            except:\n                continue\n            original_source = source\n\n        # Check for function\n        for edit_function in edit_functions:\n            source = edit_function(function, source)\n\n        \"\"\"\n        import torch\n        X = torch.ones((2, 2048, 201088), dtype = torch.bfloat16, device = \"cuda\")\n        X[torch.randperm(2, dtype = torch.int64, device = X.device)]\n\n        will error out in torch 2.8 AcceleratorError: CUDA error: invalid configuration argument\n        \"\"\"\n        source = re.sub(\n            r\"(\\n[\\s]{4,})generation_batch = shuffle_sequence_dict\\(generation_batch\\)\\n\",\n            r\"\\n\\1try: generation_batch = shuffle_sequence_dict(generation_batch)\\n\\1except: pass\\n\",\n            source,\n        )\n\n        # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model\n        source = re.sub(\n            r\"(\\n[\\s]{4,}).+?model_executor\\.driver_worker.+?\\n\",\n            r\"\\n\\1pass\\n\",\n            source,\n        )\n\n        # llm_model.load_weights(model.state_dict().items())\n        source = re.sub(\n            
r\"(\\n[\\s]{4,}).+?load_weights\\(.+?\\n\",\n            r\"\\n\\1pass\\n\",\n            source,\n        )\n\n        # .state_dict()\n        source = re.sub(\n            r\"\\.state_dict\\(\\)\",\n            r\"\",\n            source,\n        )\n\n        # Replace self.llm.generate and self.llm.chat\n        if \"CUDA_VISIBLE_DEVICES\" in os.environ:\n            lora_name = (\n                trainer_file\n                + \"_lora_model_' + \"\n                + \"(os.environ.get('CUDA_VISIBLE_DEVICES', '0').replace(',',''))\"\n            )\n        else:\n            lora_name = trainer_file + \"_lora_model'\"\n        source = re.sub(\n            r\"(self\\.llm\\.(?:generate|chat)\\([^\\)]{1,})\\)\",\n            r\"\\1, lora_request = self.model.load_lora('\"\n            + lora_name\n            + r\", load_tensors = True))\",\n            source,\n        )\n        # All these are to fix multiple commas before lora_request (in case the original code ends with something like \",)\")\n        # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1388 for eg has such an ending\n        source = re.sub(r\"\\,[\\s]{1,}\\,[\\s]{0,}lora_request\", \", lora_request\", source)\n        source = re.sub(r\"[\\s]{1,}\\,[\\s]{0,}lora_request\", \", lora_request\", source)\n        source = re.sub(r\"[\\,]{1,}[\\s]{0,}lora_request\", \", lora_request\", source)\n        # Prefer using unsloth's sampling params and fallback to trl's if not found\n        # We'll enable this later separately when combining both this and GRPOConfig params\n        # source = re.sub(\n        #     r\"sampling_params\\s*=\\s*sampling_params\",\n        #     r\"sampling_params = getattr(self.args, 'vllm_sampling_params', sampling_params)\",\n        #     source\n        # )\n        # Fix later versions of SamplingParams via grpo_update_SamplingParams\n        source = source.replace(\n            \"sampling_params = SamplingParams(**generation_kwargs)\",\n            \"sampling_params = SamplingParams(\"\n            \"**grpo_update_SamplingParams(\"\n            \"SamplingParams, generation_kwargs, \"\n            \"getattr(self.args, 'vllm_sampling_params', None)\"\n            \")\"\n            \")\",\n        )\n\n        # Skip if no changes done\n        if source == original_source:\n            continue\n\n        # Find all imports\n        imports += [x for x in all_imports if not x.startswith(\"_\") and x in source]\n\n        changed[function] = (\n            original_source,\n            source,\n        )\n\n    # Import all functions\n    imports = list(set(imports))\n\n    # Patch all functions\n    for function in changed:\n        old, new = changed[function]\n        RLTrainer_source = RLTrainer_source.replace(old, new)\n\n    RLTrainer_source = RLTrainer_source.replace(\n        f\"class {RLTrainer_name}\", f\"class _Unsloth{RLTrainer_name}\", 1\n    )\n    return RLTrainer_source\n\n\ndef patch_trl_rl_trainers():\n    # Patch all TRL modules if they have vLLM or PEFT\n    import trl.trainer\n\n    all_trainers = dir(trl.trainer)\n    all_trainers = [\n        x\n        for x in all_trainers\n        if x.islower() and x.endswith(\"_trainer\") and x != \"base_trainer\"\n    ]\n    for trainer in all_trainers:\n        try:\n            _patch_trl_rl_trainers(trainer)\n        except Exception as e:\n            logger.warning_once(f\"Unsloth: Could not patch trl.trainer.{trainer}: {e}\")\n    return\n\n\ndef patch_trl_openenv():\n    for function in 
RL_ADDITIONAL_FUNCTIONS[\"openenv\"]:\n        logger.info(f\"Unsloth: Patching trl openenv with function: {function.__name__}\")\n        function()  # Call the function to apply the patch\n    return\n\n\ndef patch_trl_vllm_generation():\n    # trl moved vllm stuff to trl/generation/vllm_generation.py\n    # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference\n    # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause\n    for function in RL_ADDITIONAL_FUNCTIONS[\"vllm_generation\"]:\n        logger.info(\n            f\"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}\"\n        )\n        function()\n    return\n\n\ndef patch_trl_vllm_generation():\n    # trl moved vllm stuff to trl/generation/vllm_generation.py\n    # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference\n    # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause\n    for function in RL_ADDITIONAL_FUNCTIONS[\"vllm_generation\"]:\n        logger.info(\n            f\"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}\"\n        )\n        function()\n    return\n\n\ndef PatchFastRL(algorithm = None, FastLanguageModel = None):\n    if FastLanguageModel is not None:\n        PatchRL(FastLanguageModel)\n    patch_trl_rl_trainers()\n    patch_trl_openenv()\n    patch_trl_vllm_generation()\n    if type(algorithm) is str and algorithm.islower():\n        PatchRLStatistics(algorithm)\n"
  },
  {
    "path": "unsloth/models/rl_replacements.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"RL_EXTRA_ARGS\",\n    \"RL_FUNCTIONS\",\n    \"RL_PRE_ITEMS\",\n    \"RL_CONFIG_CHANGES\",\n    \"RL_METRICS_CHANGES\",\n]\n\nimport os\nimport re\nimport torch\nimport inspect\nimport linecache\nfrom collections import defaultdict\nfrom unsloth_zoo.rl_replacements import (\n    RL_REPLACEMENTS,\n    left_pack_padding,\n    chunked_selective_log_softmax,\n)\nfrom unsloth_zoo.utils import Version\nfrom trl import __version__ as trl_version_raw\nfrom importlib.metadata import version as importlib_version\nfrom unsloth_zoo.log import logger\nfrom unsloth_zoo.device_type import device_synchronize\nimport importlib.util\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n)\nimport textwrap\nfrom ._utils import _get_inference_mode_context_manager\n\nRL_EXTRA_ARGS = defaultdict(list)\nRL_FUNCTIONS = defaultdict(list)\nRL_PRE_ITEMS = defaultdict(list)\nRL_CONFIG_CHANGES = defaultdict(list)\nRL_METRICS_CHANGES = defaultdict(list)\nRL_ADDITIONAL_FUNCTIONS = defaultdict(list)\n\ntorch_compile_options = {\n    \"epilogue_fusion\": True,\n    \"max_autotune\": False,  # I saw speedups, but not sure if this has issues in collab\n    \"shape_padding\": True,\n    \"trace.enabled\": False,\n    \"triton.cudagraphs\": False,\n}\n\ntry:\n    trl_version = Version(trl_version_raw)\nexcept Exception:\n    try:\n        trl_version = Version(importlib_version(\"trl\"))\n    except Exception:\n        trl_version = Version(\"0.0.0\")\n\n\n# Check untrained tokens\ndef sft_trainer_fix_untrained_tokens(call_args, extra_args):\n    if \"model\" in call_args and \"train_dataset\" in call_args:\n        fix_tokenizer = (\n            \"IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\\\n')\\n\"\n            \"from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\\n\"\n            \"from unsloth_zoo.training_utils  import fix_zero_training_loss\\n\"\n            \"if 'tokenizer' not in locals(): tokenizer = processing_class\\n\"\n            \"fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\\n\"\n            \"fix_zero_training_loss(model, tokenizer, train_dataset)\\n\"\n        )\n        return fix_tokenizer\n    return \"\"\n\n\nRL_EXTRA_ARGS[\"sft_trainer\"].append(sft_trainer_fix_untrained_tokens)\n\n\n# Fix top_k for GRPO vLLM.\n# https://github.com/huggingface/trl/pull/4695 with this change trl added top_k in GRPOConfig and defaults to 0\n# We don't want that since vllm's all include top_k is -1 and 0 returns an error on SamplingParams creation.\ndef grpo_config_fix_vllm_top_k(old_RLTrainer_source, old_RLConfig_source):\n    return \"if use_vllm and (top_k is None or top_k == 0): top_k = 
-1\\n\"\n\n\nRL_CONFIG_CHANGES[\"grpo_trainer\"].append(grpo_config_fix_vllm_top_k)\n\n\n# Remove DPO columns which might randomnly be tokenized\ndef dpo_trainer_fix_columns(call_args, extra_args):\n    if \"model\" in call_args and \"train_dataset\" in call_args:\n        fix_dpo = (\n            \"if hasattr(train_dataset, 'column_names'):\\n\"\n            \"    column_names = set(train_dataset.column_names)\\n\"\n            \"    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\\n\"\n            \"             'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\\n\"\n            \"             'prompt_input_ids', 'prompt_attention_mask']\\n\"\n            \"    if all(x in column_names for x in check):\\n\"\n            \"        train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\\n\"\n            \"    del check, column_names\\n\"\n        )\n        return fix_dpo\n    return \"\"\n\n\nRL_EXTRA_ARGS[\"dpo_trainer\"].append(dpo_trainer_fix_columns)\n\n\n# Fix tokenizer double BOS\ndef sft_trainer_prepare_dataset(function_name, function):\n    if (\n        function_name != \"_prepare_non_packed_dataloader\"\n        and function_name != \"_prepare_dataset\"\n    ):\n        return function\n\n    fast_sft_prepare_dataset = RL_REPLACEMENTS.get(\"sft_prepare_dataset\", None)\n    if fast_sft_prepare_dataset is not None:\n        params = inspect.signature(fast_sft_prepare_dataset).parameters.keys()\n        params = \".*?\".join(params)\n        matched = re.match(\n            r\"[\\s]{0,}def _prepare_dataset\\(.*?\" + params + r\".*?\\)\",\n            function,\n            flags = re.MULTILINE | re.DOTALL,\n        )\n        if matched:\n            # Use fast version!\n            function = inspect.getsource(fast_sft_prepare_dataset)\n            function = function.split(\"\\n\")\n            function = \"\\n\".join(\" \" * 4 + x for x in function)\n            function = function.replace(\n                \"def sft_prepare_dataset\", \"def _prepare_dataset\"\n            )\n            return function\n\n    check_text = (\n        \"if 'skip_prepare_dataset' in locals() and skip_prepare_dataset:\\n\"\n        \"    return dataset\\n\"\n        \"if 'tokenizer'          not in locals(): tokenizer = processing_class\\n\"\n        \"if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\\n\"\n        \"if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\\n\"\n        \"if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\\n\"\n        \"test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\\n\"\n        \"chat_template = getattr(tokenizer, 'chat_template', None)\\n\"\n        \"chat_template = '' if chat_template is None else chat_template\\n\"\n        \"has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) \"\n        \"if getattr(tokenizer, 'bos_token', None) is not None else False\\n\"\n        \"if 'add_special_tokens' not in locals() and has_bos_token_already:\\n\"\n        \"    from functools import partial\\n\"\n        \"    tokenizer_call = tokenizer.__call__\\n\"\n        \"    tokenizer.__call__ = 
partial(tokenizer_call, add_special_tokens = False)\\n\"\n        \"    processing_class = tokenizer\\n\"\n        \"else:\\n\"\n        \"    tokenizer_call = None\\n\"\n        \"    add_special_tokens = False if has_bos_token_already else locals().get('add_special_tokens', False)\\n\"\n    )\n\n    check_text = check_text.split(\"\\n\")\n    check_text = \"\\n\".join(\" \" * 8 + x for x in check_text)\n    check_text = check_text.rstrip() + \"\\n\"\n\n    # .*? matches first match. .+? matches final match.\n    replacer = re.findall(\n        r\"def \" + function_name + r\"\\(.*?\\).*?\\:\\n\",\n        function,\n        flags = re.MULTILINE | re.DOTALL,\n    )\n    if len(replacer) != 0:\n        replacer = replacer[0]\n        function = function.replace(replacer, replacer + check_text)\n\n    # Return tokenizer's original state\n    return_state = (\n        \"if tokenizer_call is not None: tokenizer.__call__ = tokenizer_call\\n\"\n    )\n    function = re.sub(\n        r\"\\n([ ]{4,})(return .*?[\\s]{0,})$\",\n        rf\"\\1{return_state}\\1\\2\",\n        function,\n    )\n    return function\n\n\nRL_FUNCTIONS[\"sft_trainer\"].append(sft_trainer_prepare_dataset)\n\n\n# Ignore mean_token_accuracy since it needs logits\n# We override it directly with our version\ndef sft_trainer_compute_loss(function_name, function):\n    if function_name != \"compute_loss\":\n        return function\n\n    def compute_loss(\n        self, model, inputs, return_outputs = False, num_items_in_batch = None\n    ):\n        outputs = super().compute_loss(\n            model,\n            inputs,\n            return_outputs = return_outputs,\n            num_items_in_batch = num_items_in_batch,\n        )\n        return outputs\n\n    function = inspect.getsource(compute_loss)\n    return function\n\n\nRL_FUNCTIONS[\"sft_trainer\"].append(sft_trainer_compute_loss)\n\n\n# Fix bare pop(\"push_to_hub_token\") in compiled SFT/IterativeSFT trainer __init__\n# On transformers 5.0+, to_dict() no longer includes push_to_hub_token, so bare pop KeyErrors\ndef sft_trainer_push_to_hub_token(function_name, function):\n    if function_name != \"__init__\":\n        return function\n    return function.replace(\n        'dict_args.pop(\"push_to_hub_token\")', 'dict_args.pop(\"push_to_hub_token\", None)'\n    )\n\n\nRL_FUNCTIONS[\"sft_trainer\"].append(sft_trainer_push_to_hub_token)\n\n\n# Autocast precision for GRPO\ndef grpo_trainer__prepare_inputs(function_name, function):\n    if function_name != \"_prepare_inputs\":\n        return function\n\n    # Add mixed precision training\n    function = function.replace(\n        \"with torch.inference_mode():\",\n        \"with torch.inference_mode(), \"\n        \"torch.amp.autocast(device_type = 'cuda', \"\n        \"dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) \"\n        \"if not torch.is_autocast_enabled('cuda') else nullcontext())\"\n        \"if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):\",\n    )\n    function = function.replace(\n        \"self.accelerator.unwrap_model(self.model)\",\n        \"self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)\",\n    )\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__prepare_inputs)\n\n\n# Remove collective RPC of reload weights from generate\n# trl added reload weights (potentially for quantized models), we don't need it for our use case (LoRA primarily)\n# 
https://github.com/huggingface/trl/commit/7856d3b1f6518601732f489883b341bb6dd36434#diff-964e6fd373aa93037604064cb2b822d7f8e2735e33f791065acf2c4c3552d393R1168-R1169\ndef grpo_trainer__generate_single_turn(function_name, function):\n    if function_name != \"_generate_single_turn\":\n        return function\n\n    # Remove the reload_weights collective RPC call from the generate function's source\n    # function = function.replace('self.llm.collective_rpc(\"reload_weights\")', \"\")\n    # The regex below does the same thing but is more flexible and can handle single or double quotes\n    # This is for older versions.\n    function = re.sub(\n        r\"self\\.llm\\.collective_rpc\\(\\s*(['\\\"])reload_weights\\1\\s*\\)\",\n        \"\",\n        function,\n    )\n\n    # Current TRL versions call vllm_generation.sync_weights() every step.\n    # When Unsloth fast inference LoRA is active, weights are already shared.\n    sync_weights_block = re.compile(\n        r\"(?P<indent>[ \\t]*)with profiling_context\\(self,\\s*(['\\\"])sync_weights\\2\\s*\\):\\n\"\n        r\"(?P=indent)[ \\t]+self\\.vllm_generation\\.sync_weights\\(\\)\\n\",\n        re.MULTILINE,\n    )\n\n    def remove_sync_weights_block(match):\n        indent = match.group(\"indent\")\n        return (\n            f\"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\\n\"\n            f\"{indent}# Skipping per-step vLLM sync_weights().\\n\"\n        )\n\n    function = sync_weights_block.sub(remove_sync_weights_block, function)\n\n    # TRL 0.24.0-0.25.1 truncation regression fix\n    #\n    # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens():\n    #   - Tokenizes first without truncation\n    #   - Then truncates keeping the RIGHTMOST tokens (preserves assistant turn)\n    #   - Protects special tokens (image_token, vision_start/end) from removal\n    #\n    # TRL 0.24.0-0.25.1 removed this and passed kwargs directly to the tokenizer:\n    #   max_length=self.max_prompt_length, truncation=True, add_special_tokens=False\n    # This causes issues because tokenizer truncation doesn't protect special tokens\n    # and may not preserve the end of the prompt properly.\n    #\n    # TRL 0.26.2+ removed these kwargs entirely (no tokenizer-level truncation).\n    #\n    # Fix: Remove these kwargs so TRL 0.24.0-0.25.1 behaves like 0.26.2+ (no truncation).\n    # This is a no-op for versions that don't have these kwargs (0.22.2-0.23.1, 0.26.2+).\n    for pattern in [\n        r'[\"\\']?max_length[\"\\']?\\s*[:=]\\s*self\\.max_prompt_length\\s*,\\s*\\n?',\n        r'[\"\\']?truncation[\"\\']?\\s*[:=]\\s*True\\s*,\\s*\\n?',\n        r'[\"\\']?add_special_tokens[\"\\']?\\s*[:=]\\s*False\\s*,\\s*\\n?',\n    ]:\n        function = re.sub(pattern, \"\", function)\n\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__generate_single_turn)\n\n\n# Fix incorrect special tokens handling and truncation in older TRL versions\ndef grpo_trainer__generate_and_score_completions(function_name, function):\n    if function_name != \"_generate_and_score_completions\":\n        return function\n\n    # TRL 0.19.0 did skip_special_tokens = True which should be False\n    function = function.replace(\n        \"prompt_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False\",\n        \"prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False\",\n    )\n\n    # Left pad prompt before calculation old and ref hidden states\n    line_to_replace = 'batch_size = 
self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size'\n\n    # The new multi-line string that will replace the line above\n    replacement_lines = \"\"\"\n        max_left_pad = None\n        batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size\n        try:\n            # TRL 0.23.1 and below path\n            if not has_images:\n                # Left pad prompt before calculation old and ref hidden states\n                left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)\n                max_left_pad = torch.max(left_pad_tokens_per_prompt).item()\n        except:\n            # TRL 0.24.0 and below path\n            if images is None:\n                # Left pad prompt before calculation old and ref hidden states\n                left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)\n                max_left_pad = torch.max(left_pad_tokens_per_prompt).item()\n        self.model.for_training()\"\"\"\n\n    function = function.replace(line_to_replace, replacement_lines)\n\n    pattern_to_find = re.compile(\n        r\"^\\s*if self\\.args\\.gradient_accumulation_steps % generate_every != 0 or \\(\\s*\"\n        r\"self\\.use_vllm and self\\.vllm_importance_sampling_correction\\s*\"\n        r\"\\):\",\n        re.MULTILINE,\n    )\n\n    replacement_text = \"\"\"\n            if self.args.gradient_accumulation_steps % generate_every != 0 or (\n                self.use_vllm\n            ):\"\"\"\n    # Use re.sub() to perform the replacement\n    function, num_replacements = pattern_to_find.subn(replacement_text, function)\n\n    pattern_to_find = re.compile(\n        r\"(^\\s*)all_logprobs = \\[\"  # Capture indentation (group 1)\n        r\".*?\"  # Match everything inside non-greedily\n        r\"for output in outputs\\.outputs\\s*\"\n        r\"\\]\",\n        re.DOTALL | re.MULTILINE,\n    )\n\n    # sanitize_logprob is injected as a module-level function via RLTrainer_replacement\n    # template in rl.py (from RL_REPLACEMENTS), so just reference it directly here.\n    replacement_text = (\n        r\"\\1all_logprobs = [\\n\"\n        r\"\\1    [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]\\n\"\n        r\"\\1    for outputs in all_outputs\\n\"\n        r\"\\1    for output in outputs.outputs\\n\"\n        r\"\\1]\"\n    )\n\n    function, num_replacements = pattern_to_find.subn(replacement_text, function)\n\n    # Always between max_prompt_length and use_vllm\n    found = re.findall(\n        r\"\\n(([ ]{8,})if self\\.max_prompt_length is not None:.*?\"\n        r\"\\2if self\\.use_vllm:)\",\n        function,\n        flags = re.DOTALL | re.MULTILINE,\n    )\n    if len(found) != 0:\n        replace_part, spacing = found[0]\n        removed_comments = re.sub(r\"\\#[^\\n]{1,}\", \"\", replace_part)\n        splits = removed_comments.split(\"\\n\")\n        if (\n            sum(re.match(rf\"{spacing}[^\\s]\", x) is not None for x in splits) == 2\n            and len(spacing) >= 8\n        ):\n            new_replacement = f\"\"\"\\n{spacing}if self.max_prompt_length is not None:\n            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.\n            # Then we decode those tokens back into text. 
We manually remove leading pad tokens from the decoded text,\n            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).\n            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]\n            protected = [token for token in protected if token is not None]\n            prompt_ids, prompt_mask = truncate_with_protected_tokens(\n                prompt_ids, prompt_mask, self.max_prompt_length, protected\n            )\n\n            prompts_text = [re.sub(rf\"^({{re.escape(self.pad_token)}})+\", \"\", text) for text in prompts_text]\n\n            # The chat template inserts a single image token into the prompt text. However, when this text is later\n            # tokenized, the single image token string is expanded into multiple image token IDs, depending on the\n            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We\n            # collapse them back into a single token string to match the original template.\n            if self.image_token is not None:\n                prompts_text = [\n                    re.sub(rf\"({{re.escape(self.image_token)}})+\", self.image_token, text) for text in prompts_text\n                ]\n        # Generate completions using either vLLM or regular generation\n        if self.use_vllm:\"\"\"\n            function = function.replace(replace_part, new_replacement)\n\n    # Important note: we disable TRL's importance sampling logic\n    # It is disabled because the LLM path moves left padding to the right.\n    # We must adjust the vLLM sampling_logprob tensor in Unsloth to account for this.\n    string_to_find = \"if self.use_vllm and self.vllm_importance_sampling_correction:\"\n\n    replacement_string = (\n        \"if False and self.use_vllm and self.vllm_importance_sampling_correction:\"\n    )\n\n    function = function.replace(string_to_find, replacement_string)\n\n    string_to_find = \"\"\"        if \"image_sizes\" in prompt_inputs:\n            output[\"image_sizes\"] = prompt_inputs[\"image_sizes\"]\"\"\"\n\n    replacement_string = \"\"\"        if \"image_sizes\" in prompt_inputs:\n            output[\"image_sizes\"] = prompt_inputs[\"image_sizes\"]\n        if max_left_pad is not None:\n            output[\"max_left_pad\"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1)\n        try:\n            if self.use_vllm and getattr(self, \"vllm_importance_sampling_correction\", False):\n                output[\"sampling_per_token_logps\"] = sampling_per_token_logps\n        except NameError:\n            output[\"sampling_per_token_logps\"] = None\"\"\"\n\n    function = function.replace(string_to_find, replacement_string)\n\n    # TRL 0.24.0+ extracts prompts = [x[\"prompt\"] for x in inputs], losing metadata\n    # like reasoning_effort. 
Inject code to store per-sample chat_template_kwargs on self.\n    _metadata_extraction = (\n        \"\\n\"\n        \"        # Unsloth: Extract per-sample chat_template_kwargs before metadata is lost\\n\"\n        \"        _ct_ = getattr(self.processing_class, 'chat_template', None) or ''\\n\"\n        \"        _sk_ = {'prompt', 'chosen', 'rejected', 'completion', 'messages', 'label',\\n\"\n        \"                'images', 'image', 'videos', 'video', 'audios', 'audio'}\\n\"\n        \"        self._unsloth_batch_chat_kwargs = []\\n\"\n        \"        for _inp_ in inputs:\\n\"\n        \"            _kw_ = {}\\n\"\n        \"            if isinstance(_inp_, dict):\\n\"\n        \"                for _k_ in _inp_.keys() - _sk_:\\n\"\n        \"                    if _k_ in _ct_ and isinstance(_inp_[_k_], str):\\n\"\n        \"                        _kw_[_k_] = _inp_[_k_]\\n\"\n        \"            self._unsloth_batch_chat_kwargs.append(_kw_)\\n\"\n    )\n    # Insert after: prompts = [x[\"prompt\"] for x in inputs]\n    _target_line = 'prompts = [x[\"prompt\"] for x in inputs]'\n    if _target_line in function:\n        function = function.replace(\n            _target_line,\n            _target_line + _metadata_extraction,\n        )\n\n    # This path is for TRL 0.24.0 images is a variable exclusive to this version\n    string_to_find = \"\"\"        if images is not None:\n            output[\"num_images\"] = num_images\"\"\"\n\n    replacement_string = \"\"\"        if images is not None:\n            output[\"num_images\"] = num_images\n        if max_left_pad is not None:\n            output[\"max_left_pad\"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1)\n        try:\n            if self.use_vllm and getattr(self, \"vllm_importance_sampling_correction\", False):\n                output[\"sampling_per_token_logps\"] = sampling_per_token_logps\n        except NameError:\n            output[\"sampling_per_token_logps\"] = None\"\"\"\n\n    function = function.replace(string_to_find, replacement_string)\n\n    if trl_version >= Version(\"0.24.0\"):\n        # We replace the call using 'completions' with one using 'completions_text'\n        string_to_find = \"        rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)\"\n        replacement_string = (\n            \"        if images is not None:\\n\"\n            \"            rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\\n\"\n            \"        else:\\n\"\n            \"            rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)\"\n        )\n        function = function.replace(string_to_find, replacement_string)\n\n    if \"wake_up()\" not in function:\n        # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.\n        # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709\n\n        pattern = re.compile(r\".*self\\.llm\\.generate\\(.*\\).*\", re.MULTILINE)\n        matches = list(pattern.finditer(function))\n        patched = function\n\n        # Generally there's only one match. 
But this is just to make sure we don't miss any.\n        for match in reversed(matches):\n            line = match.group(0)\n            indent_match = re.match(r\"(\\s*)\", line)\n            indent = indent_match.group(1) if indent_match else \"\"\n\n            wrapped = (\n                f\"{indent}if hasattr(self, 'llm'):\\n\"\n                f\"{indent}    if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\\n\"\n                f\"{indent}        self.llm.wake_up()\\n\"\n                f\"{line}\\n\\n\"\n                f\"{indent}if hasattr(self, 'llm'):\\n\"\n                f\"{indent}    if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\\n\"\n                f\"{indent}        self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\\n\"\n            )\n\n            patched = patched[: match.start()] + wrapped + patched[match.end() :]\n\n        function = patched\n\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__generate_and_score_completions)\n\n\n# Fix {\"reasoning_effort\" : \"high\"} not applied\ndef grpo_trainer_fix_maybe_apply_chat_template(function_name, function):\n    spaces = function.find(\"def \")\n    if spaces % 4 != 0:\n        return function\n    spaces += 4\n    replacement = \"\"\"\n        _chat_template_ = getattr(self.processing_class, \"chat_template\", None)\n        if _chat_template_ is None: _chat_template_ = \"\"\n        _supported_keys_ = set((\"prompt\", \"chosen\", \"rejected\", \"completion\", \"messages\", \"label\"))\n        _batch_chat_kwargs_ = getattr(self, \"_unsloth_batch_chat_kwargs\", None)\n\n        prompts_text = []\n        for _idx_, _example_ in enumerate(__INPUTS__REPLACEMENT__):\n            _tokenizer_kwargs_ = {}\n            if type(_example_) is not dict:\n                _example_ = {\"prompt\": _example_}\n            _left_keys_ = _example_.keys() - _supported_keys_\n            for k in _left_keys_:\n                if k in _chat_template_:\n                    v = _example_[k]\n                    if type(v) is str:\n                        _tokenizer_kwargs_[k] = v\n            if _batch_chat_kwargs_ is not None and _idx_ < len(_batch_chat_kwargs_):\n                for _bk_, _bv_ in _batch_chat_kwargs_[_idx_].items():\n                    if _bk_ not in _tokenizer_kwargs_:\n                        _tokenizer_kwargs_[_bk_] = _bv_\n            _x_ = maybe_apply_chat_template(_example_, self.processing_class, **_tokenizer_kwargs_)[\"prompt\"]\n            prompts_text.append(_x_)\n    \"\"\"\n    replacement = textwrap.dedent(replacement).strip()\n    replacement = textwrap.indent(replacement, spaces * \" \")\n    replacement = f\"\\n{replacement}\\n\"\n    what = 'prompts_text = [maybe_apply_chat_template(example, self.processing_class)[\"prompt\"] for example in inputs]'\n    function = function.replace(\n        what, replacement.replace(\"__INPUTS__REPLACEMENT__\", \"inputs\")\n    )\n\n    \"\"\"prompts_text = [\n        maybe_apply_chat_template({\"prompt\": prompt}, self.processing_class)[\"prompt\"] for prompt in prompts\n    ]\"\"\"\n    function = re.sub(\n        r\"prompts_text = \\[\"\n        r\"[\\s]{0,}\"\n        r\"maybe_apply_chat_template\\(\\{[\\\"\\']prompt[\\\"\\'][\\s]{0,}\\:[\\s]{0,}prompt[\\s]{0,}\\}[\\s]{0,}\\,[\\s]{0,}self\\.processing_class\\)\"\n        r\"\\[[\\\"\\']prompt[\\\"\\']\\] for prompt in prompts\"\n        r\"[\\s]{0,}\"\n        r\"\\]\",\n        
replacement.replace(\"__INPUTS__REPLACEMENT__\", \"prompts\"),\n        function,\n    )\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer_fix_maybe_apply_chat_template)\n\n\n# Remove _move_model_to_vllm\ndef grpo_trainer__move_model_to_vllm(function_name, function):\n    if function_name != \"_move_model_to_vllm\":\n        return function\n\n    def _move_model_to_vllm(self, *args, **kwargs):\n        return None\n\n    function = inspect.getsource(_move_model_to_vllm)\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__move_model_to_vllm)\n\n\n# Edit _get_per_token_logps to handle mixed precision\ndef grpo_trainer__get_per_token_logps(function_name, function):\n    if function_name != \"_get_per_token_logps\":\n        return function\n\n    def _get_per_token_logps(\n        self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False\n    ):\n        if True:  # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':\n            return None  # Unsloth efficient GRPO\n        # Otherwise, calculate normally:\n        if not hasattr(self, \"_autocast_dtype\"):\n            self._autocast_dtype = (\n                torch.float16\n                if os.environ.get(\"ACCELERATE_MIXED_PRECISION\", \"fp16\") == \"fp16\"\n                else torch.bfloat16\n            )\n            if os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n                self._autocast_dtype = torch.float16\n\n        os.environ[\"UNSLOTH_RETURN_HIDDEN_STATES\"] = \"1\"\n        with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):\n            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded\n            logits = model(\n                input_ids = input_ids,\n                attention_mask = attention_mask,\n                logits_to_keep = logits_to_keep + 1,\n            ).logits\n            # logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred\n            return logits\n            # input_ids = input_ids[:, -logits_to_keep:]\n            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.\n            # See https://github.com/huggingface/trl/issues/2770\n            # logits = logits[:, -logits_to_keep:]\n            # return logits\n            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details\n            # logits = logits / self.temperature\n            # logps = selective_log_softmax(logits, input_ids)\n\n            # row_indices, col_indices = torch.where(logps < -20)\n\n            # # Method 1: Check if tensors have elements\n            # if len(row_indices) > 0 and len(col_indices) > 0:\n            #     breakpoint()  # Breakpoint triggered here\n            #     print(\"Found high values!\")\n            # return  logps #  compute logprobs for the input tokens\n\n    function = inspect.getsource(_get_per_token_logps)\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__get_per_token_logps)\n\n\ndef grpo_trainer__get_per_token_logps_and_entropies(function_name, function):\n    if function_name != \"_get_per_token_logps_and_entropies\":\n        return function\n\n    # Just copy over from _get_per_token_logps replacement function above. 
For now this returns None anyway\n    def _get_per_token_logps_and_entropies(\n        self,\n        model,\n        input_ids,\n        attention_mask,\n        logits_to_keep,\n        batch_size = None,\n        compute_entropy = False,\n        compute_efficient = False,\n        *args,\n        **kwargs,\n    ):\n        # All Unsloth code here in this function is licensed under AGPL3\n        # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':\n        #     return None, None  # logps, entropies Unsloth efficient GRPO\n        if compute_efficient:\n            return None, None\n        else:\n            if not hasattr(self, \"_autocast_dtype\"):\n                self._autocast_dtype = (\n                    torch.float16\n                    if os.environ.get(\"ACCELERATE_MIXED_PRECISION\", \"fp16\") == \"fp16\"\n                    else torch.bfloat16\n                )\n                if os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n                    self._autocast_dtype = torch.float16\n\n            pixel_values, image_grid_thw = (\n                kwargs.get(\"pixel_values\", None),\n                kwargs.get(\"image_grid_thw\", None),\n            )\n            pixel_attention_mask, image_sizes = (\n                kwargs.get(\"pixel_attention_mask\", None),\n                kwargs.get(\"image_sizes\", None),\n            )\n\n            unwrapped_model = self.accelerator.unwrap_model(\n                model, keep_fp32_wrapper = False\n            )\n\n            lm_head = self.model.get_output_embeddings().weight\n\n            dtype_bytes = (\n                16 if self._autocast_dtype in [torch.float16, torch.bfloat16] else 32\n            )\n            total_rows = input_ids.shape[0]\n            seq_len = input_ids.shape[1]\n            hidden_dim = lm_head.shape[1]\n            vocab_dim = lm_head.shape[0]\n\n            if self.args.unsloth_grpo_mini_batch is None:\n                B, multiplier = autotune_batch_and_chunks(\n                    total_rows,\n                    seq_len,\n                    hidden_dim,\n                    vocab_dim,\n                    dtype_bytes,\n                    self.args.unsloth_logit_chunk_multiplier,\n                )\n                B = total_rows // B\n            else:\n                B = self.args.unsloth_grpo_mini_batch\n\n                if self.args.unsloth_logit_chunk_multiplier is None:\n                    multiplier = max(4, seq_len // 4096)\n                else:\n                    multiplier = self.args.unsloth_logit_chunk_multiplier\n\n            all_logprobs_list = []\n            if pixel_values is None:\n                left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(\n                    input_ids, logits_to_keep, self.processing_class.pad_token_id\n                )\n                max_left_pad = torch.max(left_pad_tokens_per_prompt).item()\n                input_ids = left_pack_padding(\n                    input_ids, self.processing_class.pad_token_id\n                )\n                attention_mask = input_ids != self.processing_class.pad_token_id\n                attention_mask = attention_mask.to(attention_mask.dtype)\n            else:\n                max_left_pad = 0\n\n            # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)\n            attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)\n\n            def chunk_optional(tensor, chunks):\n                if tensor is None:\n                    
return [None] * chunks\n                return torch.chunk(tensor, chunks = chunks, dim = 0)\n\n            import math\n\n            total_samples = input_ids.shape[0]\n            batch_size = math.ceil(total_samples / B)\n\n            input_ids_chunks = []\n            attention_mask_chunks = []\n            pixel_values_chunks = []\n            image_grid_thw_chunks = []\n            pixel_attention_mask_chunks = []\n\n            current_pixel_idx = 0\n            # TRL 0.23.0 batching logic\n            for start in range(0, total_samples, batch_size):\n                end = start + batch_size\n\n                input_ids_chunks.append(input_ids[start:end])\n                attention_mask_chunks.append(attention_mask[start:end])\n\n                if image_grid_thw is not None and pixel_values is not None:\n                    grid_slice = image_grid_thw[start:end]\n                    image_grid_thw_chunks.append(grid_slice)\n\n                    batch_pixel_count = grid_slice.prod(dim = -1).sum().item()\n\n                    start_pixel_idx = current_pixel_idx\n                    end_pixel_idx = current_pixel_idx + batch_pixel_count\n\n                    pixel_values_chunks.append(\n                        pixel_values[start_pixel_idx:end_pixel_idx]\n                    )\n\n                    if pixel_attention_mask is not None:\n                        pixel_attention_mask_chunks.append(\n                            pixel_attention_mask[start_pixel_idx:end_pixel_idx]\n                        )\n                    else:\n                        pixel_attention_mask_chunks.append(None)\n\n                    current_pixel_idx = end_pixel_idx\n\n                else:\n                    pixel_values_chunks.append(None)\n                    image_grid_thw_chunks.append(None)\n                    pixel_attention_mask_chunks.append(None)\n\n            if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):\n                image_sizes_chunks = [[size] for size in image_sizes]\n            else:\n                image_sizes_chunks = chunk_optional(image_sizes, B)\n\n            temperature = self.temperature\n            logit_softcapping = getattr(model.config, \"final_logit_softcapping\", 0)\n            if logit_softcapping is None:\n                logit_softcapping = 0\n            logit_scale_multiply = getattr(model.config, \"logit_scale\", 0)\n            if logit_scale_multiply is None:\n                logit_scale_multiply = 0\n            logit_scale_divide = getattr(model.config, \"logits_scaling\", 0)\n            if logit_scale_divide is None:\n                logit_scale_divide = 0\n\n            zipped_inputs = zip(\n                input_ids_chunks,\n                attention_mask_chunks,\n                pixel_values_chunks,\n                image_grid_thw_chunks,\n                pixel_attention_mask_chunks,\n                image_sizes_chunks,\n            )\n            os.environ[\"UNSLOTH_RETURN_HIDDEN_STATES\"] = \"1\"\n\n            with _get_inference_mode_context_manager(model):\n                for (\n                    input_ids_chunk,\n                    attention_mask_chunk,\n                    pixel_values_chunk,\n                    image_grid_thw_chunk,\n                    pixel_attention_mask_chunk,\n                    image_sizes_chunk,\n                ) in zipped_inputs:\n                    with torch.amp.autocast(\n                        device_type = \"cuda\", dtype = self._autocast_dtype\n                    
):\n                        if pixel_values is None:\n                            logits_chunk = unwrapped_model(\n                                input_ids = input_ids_chunk,\n                                attention_mask = attention_mask_chunk,\n                                pixel_values = pixel_values_chunk,\n                                image_grid_thw = image_grid_thw_chunk,\n                                pixel_attention_mask = pixel_attention_mask_chunk,\n                                image_sizes = image_sizes_chunk,\n                            ).logits\n\n                            completion_input_ids_chunk = input_ids_chunk[\n                                :, -(logits_to_keep + max_left_pad) :\n                            ]\n                            logits_chunk = logits_chunk[\n                                :, -(logits_to_keep + max_left_pad + 1) :, :\n                            ]\n                            logits_chunk = logits_chunk[:, :-1, :]\n                            logprobs_chunk = (\n                                chunked_hidden_states_selective_log_softmax(\n                                    logits_chunk,\n                                    lm_head,\n                                    completion_input_ids_chunk,\n                                    chunks = input_ids_chunk.shape[0] * multiplier,\n                                    logit_scale_multiply = logit_scale_multiply,\n                                    logit_scale_divide = logit_scale_divide,\n                                    logit_softcapping = logit_softcapping,\n                                    temperature = temperature,\n                                )\n                            )\n                        else:\n                            # Essentially, for VLMs we do not go via the optimized path in models/,\n                            # so we don't encounter the Flash Attn left-padding issue.\n                            logits_chunk = unwrapped_model(\n                                input_ids = input_ids_chunk,\n                                attention_mask = attention_mask_chunk,\n                                pixel_values = pixel_values_chunk,\n                                image_grid_thw = image_grid_thw_chunk,\n                                pixel_attention_mask = pixel_attention_mask_chunk,\n                                image_sizes = image_sizes_chunk,\n                                logits_to_keep = logits_to_keep + 1,\n                            ).logits\n\n                            logits_chunk = logits_chunk[:, :-1, :]\n                            completion_input_ids_chunk = input_ids_chunk[\n                                :, -logits_to_keep:\n                            ]\n                            # Guard: check if model returned hidden states or logits\n                            if logits_chunk.shape[-1] == lm_head.shape[1]:\n                                logprobs_chunk = (\n                                    chunked_hidden_states_selective_log_softmax(\n                                        logits_chunk,\n                                        lm_head,\n                                        completion_input_ids_chunk,\n                                        chunks = input_ids_chunk.shape[0] * multiplier,\n                                        logit_scale_multiply = logit_scale_multiply,\n                                        logit_scale_divide = logit_scale_divide,\n                                        
logit_softcapping = logit_softcapping,\n                                        temperature = temperature,\n                                    )\n                                )\n                            else:\n                                # Model returned logits directly - scaling/softcapping already applied by model forward\n                                logprobs_chunk = chunked_selective_log_softmax(\n                                    logits_chunk,\n                                    completion_input_ids_chunk,\n                                    temperature,\n                                )\n                    # This is needed to avoid race conditions with GPT OSS offload_embbed=True\n                    # However, it seems that this line does not slow down or disrupt models.\n                    device_synchronize()\n                    all_logprobs_list.append(logprobs_chunk)\n                logprobs = torch.cat(all_logprobs_list, dim = 0)\n                entropies = None\n\n            os.environ[\"UNSLOTH_RETURN_HIDDEN_STATES\"] = \"0\"\n\n            return logprobs.detach(), entropies  # logps, entropies\n            # input_ids = input_ids[:, -logits_to_keep:]\n            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.\n            # See https://github.com/huggingface/trl/issues/2770\n            # logits = logits[:, -logits_to_keep:]\n            # return logits\n            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details\n            # logits = logits / self.temperature\n            # logps = selective_log_softmax(logits, input_ids)\n\n            # row_indices, col_indices = torch.where(logps < -20)\n\n            # # Method 1: Check if tensors have elements\n            # if len(row_indices) > 0 and len(col_indices) > 0:\n            #     breakpoint()  # Breakpoint triggered here\n            #     print(\"Found high values!\")\n            # return  logps #  compute logprobs for the input tokens\n\n    function = inspect.getsource(_get_per_token_logps_and_entropies)\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer__get_per_token_logps_and_entropies)\n\ngrpo_compute_loss = RL_REPLACEMENTS[\"grpo_compute_loss\"]\ngrpo_compute_loss_slow = RL_REPLACEMENTS[\"grpo_compute_loss_slow\"]\nUnslothEfficientGRPO = RL_REPLACEMENTS[\"UnslothEfficientGRPO\"]\ngrpo_accumulated_loss = RL_REPLACEMENTS[\"grpo_accumulated_loss\"]\ngrpo_update_SamplingParams = RL_REPLACEMENTS[\"grpo_update_SamplingParams\"]\nRL_PRE_ITEMS[\"grpo_trainer\"].append(inspect.getsource(grpo_compute_loss))\nRL_PRE_ITEMS[\"grpo_trainer\"].append(inspect.getsource(UnslothEfficientGRPO))\nRL_PRE_ITEMS[\"grpo_trainer\"].append(inspect.getsource(grpo_accumulated_loss))\nRL_PRE_ITEMS[\"grpo_trainer\"].append(grpo_compute_loss_slow)\nRL_PRE_ITEMS[\"grpo_trainer\"].append(inspect.getsource(grpo_update_SamplingParams))\nRL_PRE_ITEMS[\"grpo_trainer\"].append(\n    inspect.getsource(_get_inference_mode_context_manager)\n)\n\n\n# Edit _get_per_token_logps to handle mixed precision\ndef grpo_trainer_compute_loss(function_name, function):\n    if function_name != \"compute_loss\":\n        return function\n\n    def compute_loss(\n        self, model, inputs, return_outputs = False, num_items_in_batch = None\n    ):\n        if return_outputs:\n            raise ValueError(\"The GRPOTrainer does not support returning outputs\")\n        # Compute the per-token 
log probabilities for the model\n\n        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n        completion_ids, completion_mask = (\n            inputs[\"completion_ids\"],\n            inputs[\"completion_mask\"],\n        )\n        pixel_values, image_grid_thw = (\n            inputs.get(\"pixel_values\", None),\n            inputs.get(\"image_grid_thw\", None),\n        )\n        pixel_attention_mask, image_sizes = (\n            inputs.get(\"pixel_attention_mask\", None),\n            inputs.get(\"image_sizes\", None),\n        )\n        num_items_in_batch = inputs.get(\"num_items_in_batch\", None)\n        sampling_per_token_logps = inputs.get(\"sampling_per_token_logps\", None)\n        current_gradient_accumulation_steps = self.current_gradient_accumulation_steps\n        num_processes = self.accelerator.num_processes\n\n        input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)\n        bsz, qlen = input_ids.shape\n        attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)\n        # attention_mask = None\n        logits_to_keep = completion_ids.size(\n            1\n        )  # we only need to compute the logits for the completion tokens\n        _input_ids = input_ids\n        _logits_to_keep = logits_to_keep\n\n        get_logps_func = (\n            lambda model,\n            input_ids,\n            attention_mask,\n            logits_to_keep,\n            batch_size = None,\n            compute_entropy = False,\n            compute_efficient = False: self._get_per_token_logps(\n                model, input_ids, attention_mask, logits_to_keep, compute_efficient\n            )\n            if hasattr(self, \"_get_per_token_logps\")\n            else self._get_per_token_logps_and_entropies(\n                model,\n                input_ids,\n                attention_mask,\n                logits_to_keep,\n                batch_size,\n                compute_entropy,\n                compute_efficient,\n            )[0]\n        )  # logps\n\n        per_token_logps = get_logps_func(\n            model, input_ids, attention_mask, logits_to_keep, compute_efficient = True\n        )\n        # Compute the KL divergence between the model and the reference model\n        # _prepare_inputs doesn't return reference log probs anymore. 
We need to calculate it ourselves.\n        # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328\n        # if self.beta != 0.0:\n        #     with torch.inference_mode(), model.disable_adapter():\n        #         ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)\n        # else:\n        #     ref_per_token_logps = None\n        ref_logps = inputs.get(\"ref_per_token_logps\", None)\n        # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1\n        # x - x.detach() allows for preserving gradients from x\n        advantages = inputs[\"advantages\"]\n        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)\n        # per_token_loss = -(per_token_loss - self.beta * per_token_kl)\n        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()\n        old_logps = inputs.get(\"old_per_token_logps\", None)\n\n        input_ids = input_ids[:, -logits_to_keep:]\n\n        # Get logit softcapping and logit scale\n        logit_softcapping = getattr(model.config, \"final_logit_softcapping\", 0)  # Gemma\n        if logit_softcapping is None:\n            logit_softcapping = 0\n        logit_scale_multiply = getattr(model.config, \"logit_scale\", 0)  # Cohere\n        if logit_scale_multiply is None:\n            logit_scale_multiply = 0\n        logit_scale_divide = getattr(model.config, \"logits_scaling\", 0)  # Granite\n        if logit_scale_divide is None:\n            logit_scale_divide = 0\n\n        max_left_pad = inputs.get(\"max_left_pad\", 0)\n        if per_token_logps is not None:\n            (\n                loss,\n                completion_length,\n                mean_kl,\n                delta,\n                flat_is_ratio,\n                coef_1,\n                completion_mask,\n            ) = grpo_compute_loss_slow(\n                ref_logps,\n                per_token_logps,\n                old_logps,\n                input_ids,\n                completion_mask,\n                self.beta,\n                advantages,\n                pixel_values = pixel_values,\n                image_grid_thw = image_grid_thw,\n                loss_type = self.args.loss_type,\n                importance_sampling_level = self.importance_sampling_level,\n                epsilon_low = self.epsilon_low,\n                epsilon_high = self.epsilon_high,\n                max_completion_length = self.args.max_completion_length,\n                delta = self.args.delta,\n                temperature = self.args.temperature,\n                max_left_pad = max_left_pad,\n                logit_softcapping = logit_softcapping,\n                logit_scale_multiply = logit_scale_multiply,\n                logit_scale_divide = logit_scale_divide,\n                num_items_in_batch = num_items_in_batch,\n                current_gradient_accumulation_steps = current_gradient_accumulation_steps,\n                num_processes = num_processes,\n                sampling_per_token_logps = sampling_per_token_logps,\n            )\n        else:\n            if hasattr(self.args, \"loss_type\"):\n                (\n                    loss,\n                    completion_length,\n                    mean_kl,\n                    delta,\n                    flat_is_ratio,\n                    coef_1,\n                    
completion_mask,\n                ) = grpo_accumulated_loss(\n                    trainer = self,\n                    input_ids = _input_ids,\n                    pixel_values = pixel_values,\n                    image_grid_thw = image_grid_thw,\n                    logits_to_keep = logits_to_keep,\n                    completion_mask = completion_mask,\n                    advantages = advantages,\n                    old_logps = old_logps,\n                    ref_logps = ref_logps,\n                    n_chunks = self.args.unsloth_num_chunks,\n                    loss_type = self.args.loss_type,\n                    importance_sampling_level = self.importance_sampling_level,\n                    epsilon_low = self.epsilon_low,\n                    epsilon_high = self.epsilon_high,\n                    max_completion_length = self.args.max_completion_length,\n                    delta = self.args.delta,\n                    temperature = self.args.temperature,\n                    max_left_pad = max_left_pad,\n                    logit_softcapping = logit_softcapping,\n                    logit_scale_multiply = logit_scale_multiply,\n                    logit_scale_divide = logit_scale_divide,\n                    attention_mask = attention_mask,\n                    num_items_in_batch = num_items_in_batch,\n                    current_gradient_accumulation_steps = current_gradient_accumulation_steps,\n                    num_processes = num_processes,\n                    sampling_per_token_logps = sampling_per_token_logps,\n                )\n            else:\n                # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17\n                loss, completion_length, mean_kl, coef_1, completion_mask = (\n                    grpo_accumulated_loss(\n                        trainer = self,\n                        input_ids = _input_ids,\n                        logits_to_keep = logits_to_keep,\n                        completion_mask = completion_mask,\n                        advantages = advantages,\n                        old_logps = old_logps,\n                        ref_logps = ref_logps,\n                        n_chunks = self.args.unsloth_num_chunks,\n                        temperature = self.args.temperature,\n                        logit_softcapping = logit_softcapping,\n                        logit_scale_multiply = logit_scale_multiply,\n                        logit_scale_divide = logit_scale_divide,\n                        attention_mask = attention_mask,\n                    )\n                )\n        if \"train\" in self._metrics:\n            mode = \"eval\" if self.control.should_evaluate else \"train\"\n            self._metrics[mode][\"completion_length\"].append(completion_length.item())\n            self._metrics[mode][\"kl\"].append(mean_kl.item())\n        else:\n            self._metrics[\"completion_length\"].append(completion_length.item())\n            self._metrics[\"kl\"].append(mean_kl.item())\n\n        if (\n            self.use_vllm\n            and delta is not None\n            and getattr(self, \"vllm_importance_sampling_correction\", False)\n        ):\n            mean_delta = (\n                torch.mean(delta)\n                if delta.numel() > 0\n                else torch.tensor(0.0, device = self.model.device)\n            )\n            max_delta = (\n                torch.max(delta)\n                if delta.numel() > 0\n                else torch.tensor(0.0, device = self.model.device)\n            )\n        
    self._metrics[mode][\"sampling/sampling_logp_difference/mean\"].append(\n                self.accelerator.gather(mean_delta).mean().item()\n            )\n            self._metrics[mode][\"sampling/sampling_logp_difference/max\"].append(\n                self.accelerator.gather(max_delta).max().item()\n            )\n\n            min_importance_sampling_ratio = (\n                torch.min(flat_is_ratio)\n                if flat_is_ratio.numel() > 0\n                else torch.tensor(0.0, device = self.model.device)\n            )\n            mean_importance_sampling_ratio = (\n                torch.mean(flat_is_ratio)\n                if flat_is_ratio.numel() > 0\n                else torch.tensor(0.0, device = self.model.device)\n            )\n            max_importance_sampling_ratio = (\n                torch.max(flat_is_ratio)\n                if flat_is_ratio.numel() > 0\n                else torch.tensor(0.0, device = self.model.device)\n            )\n            self._metrics[mode][\"sampling/importance_sampling_ratio/min\"].append(\n                self.accelerator.gather(min_importance_sampling_ratio)\n                .nan_to_num(nan = float(\"inf\"))\n                .min()\n                .item()\n            )\n            self._metrics[mode][\"sampling/importance_sampling_ratio/mean\"].append(\n                self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()\n            )\n            self._metrics[mode][\"sampling/importance_sampling_ratio/max\"].append(\n                self.accelerator.gather(max_importance_sampling_ratio)\n                .nan_to_num(nan = float(\"-inf\"))\n                .max()\n                .item()\n            )\n\n        completion_token_count = completion_mask.sum().clamp(min = 1.0)\n\n        def masked_batch_mean(x):\n            if x.shape[1] == 1:  # when importance_sampling_level == \"sequence\"\n                return x.mean()\n            else:\n                return (x * completion_mask).sum() / completion_token_count\n\n        if advantages.dim() == 1:\n            advantages = advantages.unsqueeze(1)\n\n        if self.loss_type in [\"grpo\", \"bnpo\", \"dr_grpo\", \"dapo\"]:\n            # Compute the clipped probability ratios\n            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)\n            is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)\n            is_region_clipped = is_low_clipped | is_high_clipped\n\n            low_clip = masked_batch_mean(is_low_clipped.float())\n            high_clip = masked_batch_mean(is_high_clipped.float())\n            clip_ratio = masked_batch_mean(is_region_clipped.float())\n\n            gathered_low_clip = self.accelerator.gather(low_clip)\n            self._metrics[mode][\"clip_ratio/low_mean\"].append(\n                gathered_low_clip.nanmean().item()\n            )\n            self._metrics[mode][\"clip_ratio/low_min\"].append(\n                nanmin(gathered_low_clip).item()\n            )\n            gathered_high_clip = self.accelerator.gather(high_clip)\n            self._metrics[mode][\"clip_ratio/high_mean\"].append(\n                gathered_high_clip.nanmean().item()\n            )\n            self._metrics[mode][\"clip_ratio/high_max\"].append(\n                nanmax(gathered_high_clip).item()\n            )\n            gathered_clip_ratio = self.accelerator.gather(clip_ratio)\n            self._metrics[mode][\"clip_ratio/region_mean\"].append(\n                
gathered_clip_ratio.nanmean().item()\n            )\n        elif self.loss_type == \"cispo\":\n            is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)\n            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())\n            gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)\n            self._metrics[mode][\"cispo_clip_ratio\"].append(\n                gathered_cispo_clip_ratio.nanmean().item()\n            )\n\n        return loss\n\n    function = inspect.getsource(compute_loss)\n    return function\n\n\nRL_FUNCTIONS[\"grpo_trainer\"].append(grpo_trainer_compute_loss)\n\n\n# Fix KTO shape mismatch when Unsloth model forward truncates input_ids\n# but labels aren't truncated. TRL 0.27.2+ _process_tokens only truncates\n# completions, not prompts -- so prompts exceeding max_seq_length cause the\n# model to produce shorter logits than the labels expect.\ndef kto_trainer_get_batch_logps(function_name, function):\n    if function_name != \"get_batch_logps\":\n        return function\n    # The raise is inside an if block inside the method, so we need\n    # to preserve the exact indentation of the raise statement.\n    old = 'raise ValueError(\"Logits (batch and sequence length dim) and labels must have the same shape.\")'\n    new = (\n        \"# Unsloth: auto-truncate to shorter sequence length (model may have truncated input_ids)\\n\"\n        \"            _min_len = min(logits.shape[1], labels.shape[1])\\n\"\n        \"            logits = logits[:, :_min_len, :]\\n\"\n        \"            labels = labels[:, :_min_len]\"\n    )\n    function = function.replace(old, new)\n    return function\n\n\nRL_FUNCTIONS[\"kto_trainer\"].append(kto_trainer_get_batch_logps)\n\n\n# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356\n# TRL warns if batch size is not a multiple of num_generations -> fix this.\ndef grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source):\n    if \"divisible by the number of generations\" not in RLTrainer_source:\n        # in later trl versions this doesn't exist anymore\n        return \"\"\n    if \"num_generations\" not in RLConfig_source:\n        return \"\"\n\n    check_batch_size = (\n        \"div = per_device_train_batch_size // num_generations\\n\"\n        \"if div * num_generations != per_device_train_batch_size:\\n\"\n        \"    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\\\n\"\n        \"We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\\n\"\n        \"    per_device_train_batch_size = num_generations\\n\"\n    )\n    return check_batch_size\n\n\nRL_CONFIG_CHANGES[\"grpo_trainer\"].append(grpo_trainer_fix_batch_size)\n\n\n# Add other reward function names\ndef grpo_trainer_metrics(RLTrainer_source, RLConfig_source):\n    if \"reward_funcs\" not in RLTrainer_source:\n        return \"\"\n\n    # For new TRL we have /mean and /std\n    use_mean = \"rewards/{reward_func_name}/mean\" in RLTrainer_source\n    use_std = \"rewards/{reward_func_name}/std\" in RLTrainer_source\n    if not use_mean:\n        use_normal = \"rewards/{reward_func_name}\" in RLTrainer_source\n    else:\n        use_normal = False\n\n    log_metrics = (\n        \"if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\\n\"\n        \"else: _reward_funcs = reward_funcs\\n\"\n        \"for reward_func in _reward_funcs:\\n\"\n        \"    
try:\\n\"\n        \"        reward_func_name = reward_func.__name__\\n\"\n        f\"        if {use_mean}:\\n\"\n        \"            other_metrics.append(f'rewards/{reward_func_name}/mean')\\n\"\n        f\"        if {use_std}:\\n\"\n        \"            other_metrics.append(f'rewards/{reward_func_name}/std')\\n\"\n        f\"        if {use_normal}:\\n\"\n        \"            other_metrics.append(f'rewards/{reward_func_name}')\\n\"\n        \"    except: pass\\n\"\n    )\n    return log_metrics\n\n\nRL_METRICS_CHANGES[\"grpo_trainer\"].append(grpo_trainer_metrics)\n\n\ndef openenv_vllm_reload_weights():\n    # This function patches the trl openenv generate_rollout_completions function to:\n    # 1. Remove the reload_weights call (unsloth handles weight reloading)\n    # 2. Fix wake_up call to be compatible with unsloth (remove tags to wake everything)\n    #\n    # The issue: TRL's wake_up(tags=[\"kv_cache\"]) only wakes kv_cache, leaving is_sleeping=True\n    # at the executor level. This causes unsloth's patched generate to try waking up again,\n    # resulting in double create_and_map on already-mapped handles.\n    #\n    # The fix: Use wake_up() with no tags, which wakes everything. Unsloth's patched\n    # CuMemAllocator.wake_up skips weights anyway, so this is safe.\n    if importlib.util.find_spec(\"trl\") is None:\n        return\n    if Version(importlib_version(\"trl\")) < Version(\"0.26.0\"):\n        return\n\n    try:\n        import trl.experimental.openenv.utils as openenv_utils\n        import trl.experimental.openenv as openenv\n    except (ImportError, NameError, Exception) as e:\n        logger.info(f\"Unsloth: Failed to import trl openenv: {e}\")\n        logger.info(\n            \"Unsloth: trl.experimental.openenv not available — skipping RL openenv patches.\"\n        )\n        return\n\n    # trl 0.28 changed the function name yet again! Thanks trl :)\n    patch_target_name = \"_generate_rollout_completions_colocate\"\n    if hasattr(openenv_utils, patch_target_name):\n        patch_target = getattr(openenv_utils, patch_target_name)\n    else:\n        # Older TRL versions may keep sleep/wake logic in the public dispatcher.\n        patch_target_name = \"generate_rollout_completions\"\n        patch_target = getattr(openenv_utils, patch_target_name)\n\n    src = inspect.getsource(patch_target)\n    src = textwrap.dedent(src)\n    original_src = src\n\n    # Remove the reload_weights call - unsloth handles this differently\n    src = re.sub(r'.*\\.collective_rpc\\(\\s*([\\'\"])reload_weights\\1\\s*\\).*\\n?', \"\", src)\n\n    # Change wake_up(tags=[\"kv_cache\"]) to wake_up() - wake everything to set is_sleeping=False\n    # This prevents double wake_up issues. 
Unsloth's allocator skips weights anyway.\n    src = re.sub(r\"\\.wake_up\\(tags=\\[.*?\\]\\)\", \".wake_up()\", src)\n\n    if original_src == src:\n        logger.warning(\"Unsloth: Warning - regex did not match, patch may have failed\")\n        return\n\n    # Execute and explicitly assign to module\n    local_ns = {}\n    exec(compile(src, \"<unsloth>\", \"exec\"), openenv_utils.__dict__, local_ns)\n    patched_func = local_ns[patch_target_name]\n\n    # Patch the target function in utils; if dispatcher was patched also update parent module alias.\n    setattr(openenv_utils, patch_target_name, patched_func)\n    if patch_target_name == \"generate_rollout_completions\":\n        openenv.generate_rollout_completions = patched_func\n    logger.info(f\"Unsloth: Patched trl openenv {patch_target_name}\")\n\n\nRL_ADDITIONAL_FUNCTIONS[\"openenv\"].append(openenv_vllm_reload_weights)\n\n\ndef vllm_generation_init_patch():\n    # trl moved vllm stuff to trl/generation/vllm_generation.py\n    # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference\n    # Edit the TRL source directly and install the patched function in the TRL module.\n    # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16\n    # This exists in trl versions 0.28.0 and above\n\n    if importlib.util.find_spec(\"trl\") is None:\n        return\n    if Version(importlib_version(\"trl\")) < Version(\"0.28.0\"):\n        return\n\n    try:\n        import trl.generation.vllm_generation as vllm_generation\n    except (ImportError, NameError, Exception) as e:\n        logger.info(f\"Unsloth: Failed to import trl.generation.vllm_generation: {e}\")\n        return\n\n    def patch_vllm_generation_method(method_name, transform, marker, filename_suffix):\n        method = getattr(vllm_generation.VLLMGeneration, method_name, None)\n        if method is None:\n            logger.info(f\"Unsloth: Could not find VLLMGeneration.{method_name}\")\n            return False\n\n        try:\n            src = inspect.getsource(method)\n        except Exception as e:\n            logger.info(\n                f\"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}\"\n            )\n            return False\n\n        src = textwrap.dedent(src)\n        if marker in src:\n            return True\n\n        src = transform(src)\n        filename = f\"<unsloth_trl_vllm_generation_{filename_suffix}_patch>\"\n        source_lines = [line + \"\\n\" for line in src.splitlines()]\n        linecache.cache[filename] = (\n            len(src),\n            None,\n            source_lines,\n            filename,\n        )\n\n        local_ns = {}\n        exec(compile(src, filename, \"exec\"), vllm_generation.__dict__, local_ns)\n        setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name])\n        return True\n\n    # Patch init to remove vLLM.LLM instantiation\n    def patch_init_vllm(src):\n        pattern = re.compile(\n            r\"(?P<llm_block>^(?P<indent>[ \\t]*)self\\.llm\\s*=\\s*LLM\\s*\\(\\n(?:.*\\n)*?^(?P=indent)\\))\",\n            re.MULTILINE,\n        )\n\n        def replace_llm_block(match):\n            indent = match.group(\"indent\")\n            llm_block = textwrap.dedent(match.group(\"llm_block\"))\n            return (\n                f\"{indent}if hasattr(model, 'vllm_engine'):\\n\"\n                f\"{indent}    # Unsloth already inits vLLM in fast inference mode. 
Do not redo :)\\n\"\n                f\"{indent}    self.llm = model.vllm_engine\\n\"\n                f\"{indent}    self.unsloth_fast_inference_lora = True\\n\"\n                f\"{indent}else:\\n\" + textwrap.indent(llm_block, indent + \"    \")\n            )\n\n        patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1)\n        if num_replacements == 0:\n            raise RuntimeError(\n                \"Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed\"\n            )\n        return patched_src\n\n    # has some sync_weights or reload rpc calls.\n    # we patched the grpo_trainer to strip them for prev versions\n    # Ref: grpo_trainer__generate_single_turn above around L270-280\n    def patch_sync_weights(src):\n        pattern = re.compile(\n            r\"^(?P<def_line>def sync_weights\\(self\\):\\n)(?P<body>(?:.*\\n)*)\",\n            re.MULTILINE,\n        )\n\n        def replace_sync_weights(match):\n            body = match.group(\"body\")\n            guard = (\n                \"    if getattr(self, 'unsloth_fast_inference_lora', False):\\n\"\n                \"        # Unsloth fast inference LoRA shares weights with vLLM already.\\n\"\n                \"        return\\n\\n\"\n            )\n            return match.group(\"def_line\") + guard + body\n\n        patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1)\n        if num_replacements == 0:\n            raise RuntimeError(\n                \"Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed\"\n            )\n        return patched_src\n\n    def patch_generate(src):\n        pattern = re.compile(\n            r\"^(?P<indent>[ \\t]*)self\\.llm\\.collective_rpc\\(\\s*(['\\\"])reload_weights\\2\\s*\\)\\s*$\",\n            re.MULTILINE,\n        )\n\n        def replace_reload_weights(match):\n            indent = match.group(\"indent\")\n            return f'{indent}pass  # self.llm.collective_rpc(\"reload_weights\")'\n\n        patched_src, num_replacements = pattern.subn(\n            replace_reload_weights, src, count = 1\n        )\n        if num_replacements == 0:\n            raise RuntimeError(\n                \"Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed\"\n            )\n        return patched_src\n\n    try:\n        init_patched = patch_vllm_generation_method(\n            \"_init_vllm\",\n            patch_init_vllm,\n            \"self.unsloth_fast_inference_lora = True\",\n            \"init_vllm\",\n        )\n        sync_patched = patch_vllm_generation_method(\n            \"sync_weights\",\n            patch_sync_weights,\n            \"if getattr(self, 'unsloth_fast_inference_lora', False):\",\n            \"sync_weights\",\n        )\n        generate_patched = patch_vllm_generation_method(\n            \"generate\",\n            patch_generate,\n            'pass  # self.llm.collective_rpc(\"reload_weights\")',\n            \"generate\",\n        )\n    except RuntimeError as e:\n        logger.warning(str(e))\n        return\n\n    if init_patched:\n        logger.info(\"Unsloth: Patched trl VLLMGeneration._init_vllm\")\n    if sync_patched:\n        logger.info(\"Unsloth: Patched trl VLLMGeneration.sync_weights\")\n    if generate_patched:\n        logger.info(\"Unsloth: Patched trl VLLMGeneration.generate\")\n\n\nRL_ADDITIONAL_FUNCTIONS[\"vllm_generation\"].append(vllm_generation_init_patch)\n"
  },
  {
    "path": "unsloth/models/sentence_transformer.py",
    "content": "# Copyright 2025 electroglyph. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\n\nfrom .loader import FastModel, DISABLE_SDPA_MODEL_NAMES\nfrom ._utils import SUPPORTS_BFLOAT16\nimport inspect\nimport json\nimport os\nimport types\nfrom huggingface_hub import hf_hub_download\nfrom typing import Optional\nimport torch\nfrom transformers.modeling_outputs import BaseModelOutput\nfrom collections import OrderedDict\nfrom transformers.models.distilbert import modeling_distilbert\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa\nimport transformers\nfrom packaging.version import Version\nimport re\nfrom transformers import AutoModel, AutoConfig\nfrom transformers.models.auto.auto_factory import _get_model_class\nimport tempfile\nfrom huggingface_hub import HfApi, get_token\nfrom ..save import unsloth_save_pretrained_torchao, unsloth_save_pretrained_gguf\nimport contextlib\nimport shutil\n\n\ndef _save_pretrained_torchao(\n    self,\n    save_directory,\n    tokenizer = None,\n    torchao_config = None,\n    push_to_hub = False,\n    token = None,\n):\n    self.save_pretrained(save_directory)\n\n    # grab inner model\n    inner_model = self[0].auto_model\n    if hasattr(inner_model, \"_orig_mod\"):\n        inner_model = inner_model._orig_mod\n\n    # merge LoRA first\n    if hasattr(inner_model, \"merge_and_unload\"):\n        inner_model = inner_model.merge_and_unload()\n\n    # confirm Transformer path\n    transformer_path = \"0_Transformer\"\n    modules_path = os.path.join(save_directory, \"modules.json\")\n    if os.path.exists(modules_path):\n        try:\n            with open(modules_path, \"r\") as f:\n                modules = json.load(f)\n            for m in modules:\n                if m.get(\"type\", \"\").endswith(\"Transformer\"):\n                    transformer_path = m.get(\"path\", \"\")\n                    break\n        except:\n            pass\n\n    transformer_dir = os.path.join(save_directory, transformer_path)\n    transformer_dir = os.path.abspath(transformer_dir)\n\n    if tokenizer is None:\n        tokenizer = self.tokenizer\n\n    @contextlib.contextmanager\n    def patch_unsloth_save():\n        original_causal = transformers.AutoModelForCausalLM\n        original_rmtree = shutil.rmtree\n        # unsloth_save_pretrained_torchao expects AutoModelForCausalLM\n        transformers.AutoModelForCausalLM = transformers.AutoModel\n        # prevent unsloth from deleting the unquantized model directory\n        shutil.rmtree = lambda *args, **kwargs: None\n        try:\n            yield\n        finally:\n            # unpatch\n            transformers.AutoModelForCausalLM = original_causal\n            shutil.rmtree = original_rmtree\n\n    with patch_unsloth_save():\n        unsloth_save_pretrained_torchao(\n            inner_model,\n            transformer_dir,\n            tokenizer = tokenizer,\n            torchao_config = torchao_config,\n            push_to_hub 
= push_to_hub,\n            token = token,\n        )\n\n    # avoid `0_Transformer-torchao`, it was either this or fix modules.json\n    torchao_dir = transformer_dir + \"-torchao\"\n    if os.path.exists(torchao_dir):\n        if not os.path.exists(transformer_dir):\n            os.makedirs(transformer_dir, exist_ok = True)\n\n        # move contents\n        for item in os.listdir(torchao_dir):\n            s = os.path.join(torchao_dir, item)\n            d = os.path.join(transformer_dir, item)\n            if os.path.isdir(s):\n                shutil.copytree(s, d, dirs_exist_ok = True)\n            else:\n                shutil.copy2(s, d)\n\n        # remove torchao dir\n        shutil.rmtree(torchao_dir)\n\n        # remove conflicting safetensors if we brought in bin\n        if os.path.exists(os.path.join(transformer_dir, \"pytorch_model.bin\")):\n            safetensors_path = os.path.join(transformer_dir, \"model.safetensors\")\n            if os.path.exists(safetensors_path):\n                try:\n                    os.remove(safetensors_path)\n                except:\n                    pass\n\n    try:\n        FastSentenceTransformer._add_unsloth_branding(save_directory)\n    except:\n        pass\n\n\n# Thanks Etherl:\ndef _save_pretrained_gguf(\n    self,\n    save_directory,\n    tokenizer = None,\n    quantization_method = \"fast_quantized\",\n    first_conversion = None,\n    push_to_hub = False,\n    token = None,\n    max_shard_size = \"5GB\",\n    temporary_location = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage = 0.85,\n    **kwargs,\n):\n    \"\"\"\n    Saves the SentenceTransformer model to GGUF format by saving the inner transformer model,\n    converting it, and placing the resulting GGUF files in the save directory.\n    \"\"\"\n    # 1. Save standard SentenceTransformer structure (configs, modules.json, etc.)\n    self.save_pretrained(save_directory)\n\n    # 2. Extract inner transformer model\n    inner_model = self[0].auto_model\n    if hasattr(inner_model, \"_orig_mod\"):\n        inner_model = inner_model._orig_mod\n\n    # If it's a PEFT model, unsloth_save_pretrained_gguf handles merging,\n    # but we pass the inner model wrapper.\n\n    # 3. Identify where the transformer weights are stored\n    transformer_path = \"0_Transformer\"\n    modules_path = os.path.join(save_directory, \"modules.json\")\n    if os.path.exists(modules_path):\n        try:\n            with open(modules_path, \"r\") as f:\n                modules = json.load(f)\n            for m in modules:\n                if m.get(\"type\", \"\").endswith(\"Transformer\"):\n                    transformer_path = m.get(\"path\", \"\")\n                    break\n        except:\n            pass\n\n    # This is where Unsloth will perform the save + conversion operations\n    transformer_dir = os.path.join(save_directory, transformer_path)\n    # Ensure this path is absolute for consistent comparison later\n    transformer_dir = os.path.abspath(transformer_dir)\n\n    if tokenizer is None:\n        tokenizer = self.tokenizer\n\n    # 4. Patch environment to ensure Unsloth treats this embedding model correctly\n    @contextlib.contextmanager\n    def patch_unsloth_gguf_save():\n        # Prevent deletion of the directory we just created via self.save_pretrained\n        original_rmtree = shutil.rmtree\n        shutil.rmtree = lambda *args, **kwargs: None\n        try:\n            yield\n        finally:\n            # unpatch\n            shutil.rmtree = original_rmtree\n\n    # 5. 
Call Unsloth's GGUF saver on the inner model targeting the transformer subdirectory\n    with patch_unsloth_gguf_save():\n        result = unsloth_save_pretrained_gguf(\n            inner_model,\n            save_directory = transformer_dir,\n            tokenizer = tokenizer,\n            quantization_method = quantization_method,\n            first_conversion = first_conversion,\n            push_to_hub = False,  # Force local first to move files\n            token = token,\n            max_shard_size = max_shard_size,\n            temporary_location = temporary_location,\n            maximum_memory_usage = maximum_memory_usage,\n        )\n\n    # 6. Move GGUF files from the subdirectory (0_Transformer) to the root save_directory\n    gguf_files = result.get(\"gguf_files\", [])\n\n    new_gguf_locations = []\n\n    for gguf_file in gguf_files:\n        if os.path.exists(gguf_file):\n            filename = os.path.basename(gguf_file)\n            dest_path = os.path.join(save_directory, filename)\n\n            # Convert to absolute path to avoid mixing relative/absolute in commonpath\n            abs_gguf_file = os.path.abspath(gguf_file)\n\n            # Check if file is inside transformer_dir (subpath)\n            try:\n                is_subpath = (\n                    os.path.commonpath([abs_gguf_file, transformer_dir])\n                    == transformer_dir\n                )\n            except ValueError:\n                # Can happen on Windows with different drives, or mix of absolute/relative (handled by abspath above)\n                is_subpath = False\n\n            if is_subpath:\n                # If the GGUF file is inside the transformer_dir, move it out to root\n                shutil.move(gguf_file, dest_path)\n                new_gguf_locations.append(dest_path)\n            else:\n                # If it's elsewhere, move it to root if not already there\n                if os.path.abspath(dest_path) != abs_gguf_file:\n                    shutil.move(gguf_file, dest_path)\n                new_gguf_locations.append(dest_path)\n\n    # Update result with new locations\n    result[\"gguf_files\"] = new_gguf_locations\n\n    # 7. Add branding\n    try:\n        FastSentenceTransformer._add_unsloth_branding(save_directory)\n\n        # Add GGUF details to README\n        readme_path = os.path.join(save_directory, \"README.md\")\n        if os.path.exists(readme_path):\n            with open(readme_path, \"a\", encoding = \"utf-8\") as f:\n                f.write(\"\\n## GGUF Quantization\\n\")\n                f.write(\n                    f\"This model contains GGUF quantized versions in: {', '.join([os.path.basename(f) for f in new_gguf_locations])}\\n\"\n                )\n    except:\n        pass\n\n    # 8. 
Handle Push to Hub if requested\n    if push_to_hub:\n        if token is None:\n            token = get_token()\n\n        api = HfApi(token = token)\n        repo_id = save_directory  # Assuming save_directory is the repo name if pushing\n\n        print(f\"Unsloth: Uploading to {repo_id}...\")\n        try:\n            api.create_repo(\n                repo_id = repo_id, exist_ok = True, private = kwargs.get(\"private\", False)\n            )\n            api.upload_folder(\n                folder_path = save_directory,\n                repo_id = repo_id,\n                commit_message = \"Upload GGUF and SentenceTransformer model\",\n            )\n            print(f\"Unsloth: Uploaded to https://huggingface.co/{repo_id}\")\n        except Exception as e:\n            print(f\"Unsloth: Upload failed: {e}\")\n\n    return result\n\n\ndef _push_to_hub_gguf(\n    self,\n    repo_id,\n    tokenizer = None,\n    quantization_method = \"fast_quantized\",\n    first_conversion = None,\n    token = None,\n    private = None,\n    commit_message = \"Upload GGUF SentenceTransformer model trained with Unsloth\",\n    commit_description = \"Upload GGUF model trained with Unsloth 2x faster\",\n    max_shard_size = \"5GB\",\n    temporary_location = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage = 0.85,\n    create_pr = False,\n    revision = None,\n    tags = None,\n    **kwargs,\n):\n    \"\"\"\n    Converts the SentenceTransformer model to GGUF format and pushes to the Hugging Face Hub.\n\n    This method:\n    1. Saves the model locally to a temporary directory in GGUF format.\n    2. Uploads the GGUF files, config, Ollama Modelfile, and README to the Hub.\n    3. Cleans up the temporary directory.\n\n    Args:\n        repo_id (str): The Hugging Face Hub repo ID (e.g., \"username/model-name\").\n        tokenizer: The tokenizer to save. Defaults to `self.tokenizer`.\n        quantization_method (str or list): GGUF quantization method(s). Can be a string or list of strings.\n            Choose from the following options:\n            * \"not_quantized\"  : Recommended. Fast conversion. Slow inference, big files.\n            * \"fast_quantized\" : Recommended. Fast conversion. OK inference, OK file size.\n            * \"quantized\"      : Recommended. Slow conversion. Fast inference, small files.\n            * \"f32\"     : Not recommended. Retains 100% accuracy, but super slow and memory hungry.\n            * \"f16\"     : Fastest conversion + retains 100% accuracy. Slow and memory hungry.\n            * \"q8_0\"    : Fast conversion. High resource use, but generally acceptable.\n            * \"q4_k_m\"  : Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K\n            * \"q5_k_m\"  : Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K\n            * \"q2_k\"    : Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.\n            * \"q3_k_l\"  : Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\n            * \"q3_k_m\"  : Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\n            * \"q3_k_s\"  : Uses Q3_K for all tensors\n            * \"q4_0\"    : Original quant method, 4-bit.\n            * \"q4_1\"    : Higher accuracy than q4_0 but not as high as q5_0. 
However has quicker inference than q5 models.\n            * \"q4_k_s\"  : Uses Q4_K for all tensors\n            * \"q5_0\"    : Higher accuracy, higher resource usage and slower inference.\n            * \"q5_1\"    : Even higher accuracy, resource usage and slower inference.\n            * \"q5_k_s\"  : Uses Q5_K for all tensors\n            * \"q6_k\"    : Uses Q8_K for all tensors\n        first_conversion (str, optional): The initial conversion format before quantization.\n        token (str, optional): Hugging Face token. Uses cached token if not provided.\n        private (bool, optional): Whether the repo should be private.\n        commit_message (str): Commit message for the upload.\n        commit_description (str): Commit description for the upload.\n        max_shard_size (str): Maximum shard size for saving.\n        temporary_location (str): Temp directory for intermediate files.\n        maximum_memory_usage (float): Max fraction of memory to use.\n        create_pr (bool): Whether to create a pull request instead of pushing directly.\n        revision (str, optional): Branch/revision to push to.\n        tags (list, optional): Additional tags for the repo.\n\n    Returns:\n        str: The full repo ID on Hugging Face Hub.\n    \"\"\"\n    if token is None:\n        token = get_token()\n    if token is None:\n        raise ValueError(\n            \"No HF token provided. Please provide a token or login with `huggingface-cli login`\"\n        )\n\n    api = HfApi(token = token)\n\n    # Determine full repo_id\n    if \"/\" not in repo_id:\n        username = api.whoami()[\"name\"]\n        full_repo_id = f\"{username}/{repo_id}\"\n    else:\n        full_repo_id = repo_id\n\n    model_name = full_repo_id.split(\"/\")[-1]\n\n    # Create repo\n    try:\n        api.create_repo(\n            repo_id = full_repo_id,\n            private = private,\n            exist_ok = True,\n            repo_type = \"model\",\n        )\n    except Exception as e:\n        print(f\"Unsloth Warning: Could not create repo: {e}\")\n\n    # Save to temporary directory first\n    with tempfile.TemporaryDirectory(prefix = \"unsloth_st_gguf_\") as temp_dir:\n        print(f\"Unsloth: Converting SentenceTransformer to GGUF format...\")\n\n        # Call save_pretrained_gguf to do the local conversion\n        result = _save_pretrained_gguf(\n            self,\n            save_directory = temp_dir,\n            tokenizer = tokenizer,\n            quantization_method = quantization_method,\n            first_conversion = first_conversion,\n            push_to_hub = False,  # We handle upload ourselves\n            token = token,\n            max_shard_size = max_shard_size,\n            temporary_location = temporary_location,\n            maximum_memory_usage = maximum_memory_usage,\n        )\n\n        gguf_files = result.get(\"gguf_files\", [])\n        modelfile_location = result.get(\"modelfile_location\", None)\n        is_vlm = result.get(\"is_vlm\", False)\n        fix_bos_token = result.get(\"fix_bos_token\", False)\n\n        print(f\"Unsloth: Uploading GGUF to https://huggingface.co/{full_repo_id}...\")\n\n        # Upload GGUF files\n        for file_location in gguf_files:\n            if os.path.exists(file_location):\n                filename = os.path.basename(file_location)\n                print(f\"  Uploading {filename}...\")\n                api.upload_file(\n                    path_or_fileobj = file_location,\n                    path_in_repo = filename,\n                    
repo_id = full_repo_id,\n                    repo_type = \"model\",\n                    commit_message = commit_message,\n                    commit_description = commit_description,\n                    create_pr = create_pr,\n                    revision = revision,\n                )\n\n        # Upload Modelfile if exists\n        if modelfile_location and os.path.exists(modelfile_location):\n            print(\"  Uploading Ollama Modelfile...\")\n            api.upload_file(\n                path_or_fileobj = modelfile_location,\n                path_in_repo = \"Modelfile\",\n                repo_id = full_repo_id,\n                repo_type = \"model\",\n                commit_message = f\"{commit_message} - Ollama Modelfile\",\n                create_pr = create_pr,\n                revision = revision,\n            )\n\n        # Upload config.json if exists\n        config_path = os.path.join(temp_dir, \"config.json\")\n        if os.path.exists(config_path):\n            print(\"  Uploading config.json...\")\n            api.upload_file(\n                path_or_fileobj = config_path,\n                path_in_repo = \"config.json\",\n                repo_id = full_repo_id,\n                repo_type = \"model\",\n                commit_message = f\"{commit_message} - config\",\n                create_pr = create_pr,\n                revision = revision,\n            )\n\n        # Create and upload README\n        gguf_basenames = [os.path.basename(f) for f in gguf_files if os.path.exists(f)]\n        readme_content = f\"\"\"---\ntags:\n- gguf\n- llama.cpp\n- unsloth\n- sentence-transformers\n{\"- vision-language-model\" if is_vlm else \"\"}\n---\n\n# {model_name} - GGUF\n\nThis sentence-transformers model was finetuned and converted to GGUF format using [Unsloth](https://github.com/unslothai/unsloth).\n\n## Available Model files:\n\"\"\"\n        for fname in gguf_basenames:\n            readme_content += f\"- `{fname}`\\n\"\n\n        if modelfile_location and os.path.exists(modelfile_location):\n            readme_content += \"\\n## Ollama\\n\"\n            readme_content += \"An Ollama Modelfile is included for easy deployment.\\n\"\n\n        if fix_bos_token:\n            readme_content += \"\\n## Note\\n\"\n            readme_content += (\n                \"The model's BOS token behavior was adjusted for GGUF compatibility.\\n\"\n            )\n\n        readme_content += (\n            \"\\nThis was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)\\n\"\n            '[<img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png\" width=\"200\"/>](https://github.com/unslothai/unsloth)\\n'\n        )\n\n        readme_path = os.path.join(temp_dir, \"README.md\")\n        with open(readme_path, \"w\", encoding = \"utf-8\") as f:\n            f.write(readme_content)\n\n        api.upload_file(\n            path_or_fileobj = readme_path,\n            path_in_repo = \"README.md\",\n            repo_id = full_repo_id,\n            repo_type = \"model\",\n            commit_message = \"Add README\",\n            create_pr = create_pr,\n            revision = revision,\n        )\n\n    # Add tags\n    all_tags = [\"gguf\", \"llama-cpp\", \"unsloth\", \"sentence-transformers\"]\n    if is_vlm:\n        all_tags.append(\"vision-language-model\")\n    if tags is not None:\n        if isinstance(tags, (list, tuple)):\n            all_tags.extend(tags)\n        else:\n            all_tags.append(tags)\n    try:\n     
   api.add_tags(repo_id = full_repo_id, tags = all_tags, repo_type = \"model\")\n    except:\n        pass\n\n    print(\n        f\"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}\"\n    )\n    return full_repo_id\n\n\nclass FastSentenceTransformer(FastModel):\n    @staticmethod\n    def _read_pooling_mode(model_name, token):\n        \"\"\"\n        Read the pooling mode from the modules.json file if it exists, otherwise return \"mean\".\n        \"\"\"\n        try:\n            if os.path.exists(model_name) and os.path.exists(\n                os.path.join(model_name, \"modules.json\")\n            ):\n                modules_json_path = os.path.join(model_name, \"modules.json\")\n            else:\n                modules_json_path = hf_hub_download(\n                    model_name, \"modules.json\", token = token\n                )\n\n            with open(modules_json_path, \"r\") as f:\n                modules_config = json.load(f)\n\n            pooling_config_path = None\n            for module in modules_config:\n                if module.get(\"type\", \"\") == \"sentence_transformers.models.Pooling\":\n                    pooling_path = module.get(\"path\", \"\")\n                    if pooling_path:\n                        # try to find config.json for pooling module\n                        if os.path.exists(model_name) and os.path.exists(\n                            os.path.join(model_name, pooling_path, \"config.json\")\n                        ):\n                            pooling_config_path = os.path.join(\n                                model_name, pooling_path, \"config.json\"\n                            )\n                        else:\n                            pooling_config_path = hf_hub_download(\n                                model_name,\n                                os.path.join(pooling_path, \"config.json\"),\n                                token = token,\n                            )\n                        break\n\n            if pooling_config_path:\n                with open(pooling_config_path, \"r\") as f:\n                    pooling_config = json.load(f)\n                    # from here:\n                    # https://github.com/huggingface/sentence-transformers/blob/main/sentence_transformers/models/Pooling.py#L43\n                    pooling_map = {\n                        \"pooling_mode_cls_token\": \"cls\",\n                        \"pooling_mode_mean_tokens\": \"mean\",\n                        \"pooling_mode_max_tokens\": \"max\",\n                        \"pooling_mode_mean_sqrt_len_tokens\": \"mean_sqrt_len\",\n                        \"pooling_mode_weightedmean_tokens\": \"weightedmean\",\n                        \"pooling_mode_lasttoken\": \"lasttoken\",\n                    }\n                    for config_key, mode in pooling_map.items():\n                        if pooling_config.get(config_key):\n                            if mode != \"mean\":\n                                print(f\"Pooling mode detected as {mode}, updating...\")\n                            return mode\n\n        except Exception as e:\n            print(\n                f\"Failed to detect pooling mode, not a sentence-transformers model. 
Using default pooling mode 'mean', this may or may not work.\"\n            )\n            return \"mean\"\n\n    # should prolly be done upstream instead of this hackfest here\n    @staticmethod\n    def _patch_mpnet_v4():\n        \"\"\"\n        Patch the MPNetModel to support gradient checkpointing.\n        Supports transformers 4.\n        \"\"\"\n        from transformers.models.mpnet import modeling_mpnet\n\n        # add supports_gradient_checkpointing flag\n        modeling_mpnet.MPNetModel.supports_gradient_checkpointing = True\n\n        # add _set_gradient_checkpointing method\n        def _set_gradient_checkpointing(self, module = None, value = True):\n            if module is None:\n                module = self.encoder\n            if isinstance(module, modeling_mpnet.MPNetEncoder):\n                module.gradient_checkpointing = value\n\n        modeling_mpnet.MPNetModel._set_gradient_checkpointing = (\n            _set_gradient_checkpointing\n        )\n\n        # patch MPNetEncoder.forward to support checkpointing\n        # based on:\n        # https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/mpnet/modeling_mpnet.py#L321\n        def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            head_mask: Optional[torch.Tensor] = None,\n            output_attentions: bool = False,\n            output_hidden_states: bool = False,\n            return_dict: bool = False,\n            **kwargs,\n        ):\n            position_bias = self.compute_position_bias(hidden_states)\n            all_hidden_states = () if output_hidden_states else None\n            all_attentions = () if output_attentions else None\n\n            for i, layer_module in enumerate(self.layer):\n                if output_hidden_states:\n                    all_hidden_states = all_hidden_states + (hidden_states,)\n\n                # do gradient checkpointing if enabled and training\n                if getattr(self, \"gradient_checkpointing\", False) and self.training:\n\n                    def create_custom_forward(module):\n                        # bog standard checkpoint\n                        def custom_forward(*inputs):\n                            return module(*inputs, output_attentions = output_attentions)\n\n                        return custom_forward\n\n                    layer_outputs = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(layer_module),\n                        hidden_states,\n                        attention_mask,\n                        head_mask[i] if head_mask is not None else None,\n                        position_bias,\n                        use_reentrant = True,  # fix for torch 2.9\n                    )\n                else:\n                    # original code from here on\n                    layer_outputs = layer_module(\n                        hidden_states,\n                        attention_mask,\n                        head_mask[i] if head_mask is not None else None,\n                        position_bias,\n                        output_attentions = output_attentions,\n                        **kwargs,\n                    )\n\n                hidden_states = layer_outputs[0]\n\n                if output_attentions:\n                    all_attentions = all_attentions + (layer_outputs[1],)\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + 
(hidden_states,)\n\n            if not return_dict:\n                return tuple(\n                    v\n                    for v in [hidden_states, all_hidden_states, all_attentions]\n                    if v is not None\n                )\n            return BaseModelOutput(\n                last_hidden_state = hidden_states,\n                hidden_states = all_hidden_states,\n                attentions = all_attentions,\n            )\n\n        # assign the patched forward\n        modeling_mpnet.MPNetEncoder.forward = forward\n\n    @staticmethod\n    def _patch_mpnet_v5():\n        \"\"\"\n        Patch the MPNetModel to support gradient checkpointing.\n        Supports transformers 5.\n        \"\"\"\n        from transformers.models.mpnet import modeling_mpnet\n\n        # add supports_gradient_checkpointing flag\n        modeling_mpnet.MPNetModel.supports_gradient_checkpointing = True\n\n        # add _set_gradient_checkpointing method\n        def _set_gradient_checkpointing(self, module = None, value = True):\n            if module is None:\n                module = self.encoder\n            if isinstance(module, modeling_mpnet.MPNetEncoder):\n                module.gradient_checkpointing = value\n\n        modeling_mpnet.MPNetModel._set_gradient_checkpointing = (\n            _set_gradient_checkpointing\n        )\n\n        # patch MPNetEncoder.forward to support checkpointing\n        # based on:\n        # https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/mpnet/modeling_mpnet.py#L284\n        def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            output_attentions: bool = False,\n            output_hidden_states: bool = False,\n            return_dict: bool = False,\n            **kwargs,\n        ):\n            position_bias = self.compute_position_bias(hidden_states)\n            all_hidden_states = () if output_hidden_states else None\n            all_attentions = () if output_attentions else None\n\n            for i, layer_module in enumerate(self.layer):\n                if output_hidden_states:\n                    all_hidden_states = all_hidden_states + (hidden_states,)\n\n                # do gradient checkpointing if enabled and training\n                if getattr(self, \"gradient_checkpointing\", False) and self.training:\n\n                    def create_custom_forward(module):\n                        # checkpoint\n                        def custom_forward(*inputs):\n                            return module(*inputs, output_attentions = output_attentions)\n\n                        return custom_forward\n\n                    layer_outputs = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(layer_module),\n                        hidden_states,\n                        attention_mask,\n                        position_bias,\n                        use_reentrant = True,  # required for torch >= 2.9\n                    )\n                else:\n                    # original code from here on\n                    layer_outputs = layer_module(\n                        hidden_states,\n                        attention_mask,\n                        position_bias,\n                        output_attentions,\n                        **kwargs,\n                    )\n\n                hidden_states = layer_outputs[0]\n\n                if output_attentions:\n                    all_attentions = 
all_attentions + (layer_outputs[1],)\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if not return_dict:\n                return tuple(\n                    v\n                    for v in [hidden_states, all_hidden_states, all_attentions]\n                    if v is not None\n                )\n            return BaseModelOutput(\n                last_hidden_state = hidden_states,\n                hidden_states = all_hidden_states,\n                attentions = all_attentions,\n            )\n\n        modeling_mpnet.MPNetEncoder.forward = forward\n\n    @staticmethod\n    def _patch_distilbert_v4():\n        # change kwargs to positional args to be compatible with peft_utils\n        \"\"\"\n        Patch the forward method of the DistilBertModel to use positional arguments instead of keyword arguments.\n        Transformers 4 version.\n        \"\"\"\n\n        # based on:\n        # https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/distilbert/modeling_distilbert.py#L666\n        # original code from here on:\n        def forward(\n            self,\n            input_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            head_mask: Optional[torch.Tensor] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n        ):\n            output_attentions = (\n                output_attentions\n                if output_attentions is not None\n                else self.config.output_attentions\n            )\n            output_hidden_states = (\n                output_hidden_states\n                if output_hidden_states is not None\n                else self.config.output_hidden_states\n            )\n            return_dict = (\n                return_dict if return_dict is not None else self.config.use_return_dict\n            )\n\n            if input_ids is not None and inputs_embeds is not None:\n                raise ValueError(\n                    \"You cannot specify both input_ids and inputs_embeds at the same time\"\n                )\n            elif input_ids is not None:\n                self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)\n                input_shape = input_ids.size()\n            elif inputs_embeds is not None:\n                input_shape = inputs_embeds.size()[:-1]\n            else:\n                raise ValueError(\n                    \"You have to specify either input_ids or inputs_embeds\"\n                )\n\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n            head_mask_is_none = head_mask is None\n            # Prepare head mask if needed\n            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n            embeddings = self.embeddings(\n                input_ids, inputs_embeds\n            )  # (bs, seq_length, dim)\n\n            if self.config._attn_implementation == \"flash_attention_2\":\n                attention_mask = (\n                    attention_mask\n                    if (attention_mask is not None and 0 in attention_mask)\n                    else None\n                )\n            else:\n                if attention_mask is None:\n                    attention_mask = torch.ones(\n           
             input_shape, device = device\n                    )  # (bs, seq_length)\n\n                if (\n                    self.config._attn_implementation == \"sdpa\"\n                    and head_mask_is_none\n                    and not output_attentions\n                ):\n                    attention_mask = _prepare_4d_attention_mask_for_sdpa(\n                        attention_mask, embeddings.dtype, tgt_len = input_shape[1]\n                    )\n            # patch here, change kwargs to positional args:\n            return self.transformer(\n                embeddings,\n                attention_mask,\n                head_mask,\n                output_attentions,\n                output_hidden_states,\n                return_dict,\n            )\n\n        modeling_distilbert.DistilBertModel.forward = forward\n\n    @staticmethod\n    def _has_add_pooling_layer(config, auto_model_class = None):\n        \"\"\"\n        Checks if the model class supports the `add_pooling_layer` argument\n        \"\"\"\n        try:\n            if auto_model_class is None:\n                auto_model_class = AutoModel\n            # try to resolve the class\n            model_class = _get_model_class(config, auto_model_class._model_mapping)\n\n            if model_class:\n                sig = inspect.signature(model_class.__init__)\n                return \"add_pooling_layer\" in sig.parameters\n        except:\n            pass\n\n        return False\n\n    @staticmethod\n    def _patch_distilbert_v5():\n        \"\"\"\n        Patch the forward method of the DistilBertModel to use positional arguments instead of keyword arguments.\n        Transformers 5 version.\n        \"\"\"\n        # based on:\n        # https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/distilbert/modeling_distilbert.py#L386\n        # original code from here on:\n        from transformers.masking_utils import create_bidirectional_mask\n\n        def forward(\n            self,\n            input_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            **kwargs,\n        ):\n            if (input_ids is None) ^ (inputs_embeds is not None):\n                raise ValueError(\n                    \"You must specify exactly one of input_ids or inputs_embeds\"\n                )\n\n            embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)\n\n            attention_mask = create_bidirectional_mask(\n                config = self.config,\n                input_embeds = embeddings,\n                attention_mask = attention_mask,\n            )\n\n            # patch here: unsloth gradient checkpointing hook needs positional arguments\n            return self.transformer(\n                embeddings,\n                attention_mask,\n                **kwargs,\n            )\n\n        modeling_distilbert.DistilBertModel.forward = forward\n\n    @staticmethod\n    def _add_unsloth_tags(repo_id, token, tags = None):\n        \"\"\"\n        Add Unsloth and sentence-transformers tags to the Hugging Face Hub repository.\n        \"\"\"\n        from huggingface_hub import HfApi\n\n        api = HfApi(token = token)\n        if tags is None:\n            tags = []\n        tags.extend([\"unsloth\", \"sentence-transformers\"])\n        try:\n            api.add_tags(\n                repo_id = 
repo_id,\n                tags = tags,\n                repo_type = \"model\",\n            )\n        except:\n            pass\n\n    @staticmethod\n    def _add_unsloth_branding(save_directory):\n        \"\"\"\n        Add Unsloth branding to the README.md file generated by sentence-transformers.\n        \"\"\"\n        readme_path = os.path.join(save_directory, \"README.md\")\n        if not os.path.exists(readme_path):\n            return\n\n        with open(readme_path, \"r\", encoding = \"utf-8\") as f:\n            content = f.read()\n\n        # add unsloth tag to frontmatter\n        if \"---\\ntags:\\n\" in content:\n            content = content.replace(\"---\\ntags:\\n\", \"---\\ntags:\\n- unsloth\\n\")\n        else:\n            # if tags exist but not right at start, use regex to append\n            pattern = r\"(^tags:\\s*\\n)\"\n            if re.search(pattern, content, re.MULTILINE):\n                content = re.sub(\n                    pattern, r\"\\1- unsloth\\n\", content, count = 1, flags = re.MULTILINE\n                )\n\n        # add branding badge and text\n        branding = (\n            \"\\n\\nThis model was finetuned with [Unsloth](https://github.com/unslothai/unsloth).\\n\\n\"\n            '[<img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png\" width=\"200\"/>](https://github.com/unslothai/unsloth)\\n'\n        )\n\n        # add to description\n        if \"# SentenceTransformer\" in content:\n            parts = content.split(\"# SentenceTransformer\", 1)\n            content = parts[0] + \"# SentenceTransformer\" + branding + parts[1]\n        else:\n            content += branding\n\n        with open(readme_path, \"w\", encoding = \"utf-8\") as f:\n            f.write(content)\n\n    @staticmethod\n    def _module_path(model_name, token = None):\n        \"\"\"\n        Returns the path to the modules.json file or None\n        \"\"\"\n        try:\n            if os.path.exists(model_name) and os.path.isdir(model_name):\n                path = os.path.join(model_name, \"modules.json\")\n                return path if os.path.exists(path) else None\n            else:\n                try:\n                    return hf_hub_download(model_name, \"modules.json\", token = token)\n                except:\n                    return None\n        except:\n            return None\n\n    @staticmethod\n    def _create_transformer_module(\n        model_name,\n        model,\n        tokenizer,\n        max_seq_length,\n        trust_remote_code,\n    ):\n        \"\"\"Helper to create and configure a Transformer module.\"\"\"\n        from sentence_transformers.models import Transformer\n\n        # prevents sentence-transformers from loading the model a second time, thanks Etherl\n        original_from_pretrained = AutoModel.from_pretrained\n\n        def return_existing_model(*args, **kwargs):\n            return model\n\n        try:\n            # Temporarily redirect AutoModel loading to return our pre-loaded model\n            AutoModel.from_pretrained = return_existing_model\n\n            # Initialize Transformer\n            transformer_module = Transformer(\n                model_name,\n                max_seq_length = max_seq_length,\n                model_args = {\"trust_remote_code\": trust_remote_code},\n                config_args = {\"trust_remote_code\": trust_remote_code},\n            )\n        finally:\n            # Restore original functionality immediately\n            
AutoModel.from_pretrained = original_from_pretrained\n\n        transformer_module.tokenizer = tokenizer\n        transformer_module.do_lower_case = getattr(tokenizer, \"do_lower_case\", False)\n\n        # sentence-transformers only passes along known keys to model.forward\n        model_forward_params = list(inspect.signature(model.forward).parameters)\n        transformer_module.model_forward_params = set(model_forward_params) | {\n            \"input_ids\",\n            \"attention_mask\",\n            \"token_type_ids\",\n            \"inputs_embeds\",\n        }\n\n        # determine max_seq_length if not provided\n        if max_seq_length is None:\n            if hasattr(model, \"config\") and hasattr(\n                model.config, \"max_position_embeddings\"\n            ):\n                max_seq_length = model.config.max_position_embeddings\n            elif hasattr(tokenizer, \"model_max_length\"):\n                max_seq_length = tokenizer.model_max_length\n            else:\n                max_seq_length = 512\n\n        transformer_module.max_seq_length = max_seq_length\n        transformer_module.config_keys = [\"max_seq_length\", \"do_lower_case\"]\n        transformer_module.save_in_root = True\n\n        if hasattr(model, \"config\"):\n            model.config.tokenizer_class = tokenizer.__class__.__name__\n\n        return transformer_module\n\n    @staticmethod\n    def _load_modules(\n        model_name,\n        token,\n        model,\n        tokenizer,\n        max_seq_length,\n        pooling_mode,\n        trust_remote_code = False,\n    ) -> tuple[OrderedDict, bool]:\n        \"\"\"\n        Load modules from modules.json if available, otherwise fallback to hard-coded modules.\n\n        Returns:\n            tuple[OrderedDict, bool]: (modules, no_modules_json)\n        \"\"\"\n        from sentence_transformers.util import import_from_string, load_dir_path\n        from sentence_transformers.models import Pooling, Normalize\n\n        modules = OrderedDict()\n        modules_json_path = FastSentenceTransformer._module_path(model_name, token)\n\n        if modules_json_path:\n            with open(modules_json_path, encoding = \"utf8\") as f:\n                modules_config = json.load(f)\n\n            for module_config in modules_config:\n                class_ref = module_config[\"type\"]\n                name = module_config.get(\n                    \"name\", str(module_config.get(\"idx\", len(modules)))\n                )\n\n                if class_ref == \"sentence_transformers.models.Transformer\":\n                    transformer_module = (\n                        FastSentenceTransformer._create_transformer_module(\n                            model_name,\n                            model,\n                            tokenizer,\n                            max_seq_length,\n                            trust_remote_code,\n                        )\n                    )\n                    modules[name] = transformer_module\n                else:\n                    # load other modules (Pooling, Normalize, etc.)\n                    module_path = module_config[\"path\"]\n                    if os.path.isdir(model_name):\n                        load_path = os.path.join(model_name, module_path)\n                    else:\n                        try:\n                            load_path = load_dir_path(\n                                model_name, module_path, token = token\n                            )\n                        except 
Exception as e:\n                            print(\n                                f\"Unsloth Warning: Could not download module {module_path}: {e}\"\n                            )\n                            continue\n\n                    module_class = import_from_string(class_ref)\n                    try:\n                        module = module_class.load(load_path)\n                        modules[name] = module\n                    except Exception as e:\n                        print(\n                            f\"Unsloth Warning: Failed to load module {name} ({class_ref}): {e}\"\n                        )\n\n            return modules, False\n\n        # fallback if no modules.json (non sentence-transformers models)\n        print(\n            \"Unsloth: No modules.json found, falling back to [Transformer, Pooling, Normalize]. This may or may not work.\"\n        )\n\n        transformer_module = FastSentenceTransformer._create_transformer_module(\n            model_name, model, tokenizer, max_seq_length, trust_remote_code\n        )\n        modules[\"0\"] = transformer_module\n\n        hidden_size = getattr(model.config, \"hidden_size\", 768)\n\n        if pooling_mode == \"mean\":\n            pooling_mode = FastSentenceTransformer._read_pooling_mode(model_name, token)\n\n        modules[\"1\"] = Pooling(\n            word_embedding_dimension = hidden_size, pooling_mode = pooling_mode\n        )\n        modules[\"2\"] = Normalize()\n\n        return modules, True\n\n    # Encoder model types that benefit from native torch.compile instead of Unsloth patching\n    ENCODER_MODEL_TYPES = {\n        \"mpnet\",\n        \"bert\",\n        \"distilbert\",\n        \"modernbert\",\n        \"roberta\",\n        \"xlm-roberta\",\n        \"albert\",\n        \"electra\",\n    }\n\n    @staticmethod\n    def _estimate_compile_threshold(\n        model,\n        batch_size = None,\n        grad_accum = None,\n        max_seq_length = None,\n    ):\n        \"\"\"\n        Estimate the minimum training steps needed for torch.compile to be beneficial.\n        Returns the threshold with a 1.2x safety margin built in.\n\n        Based on empirical benchmarks:\n        - Larger models have lower breakeven (more time saved per step)\n        - Warmup time scales with model size but speedup also increases\n\n        Optional inputs (batch_size, grad_accum, max_seq_length) allow\n        a coarse pre-run adjustment. 
These are intentionally conservative\n        and avoid any runtime measurements.\n        \"\"\"\n        # Get parameter count from inner model\n        if hasattr(model, \"__getitem__\"):\n            try:\n                inner = model[0].auto_model\n                params = sum(p.numel() for p in inner.parameters())\n            except:\n                params = 100_000_000  # Default to 100M if can't determine\n        else:\n            params = sum(p.numel() for p in model.parameters())\n\n        model_type = None\n        try:\n            if \"inner\" in locals():\n                model_type = getattr(getattr(inner, \"config\", None), \"model_type\", None)\n        except Exception:\n            model_type = None\n        if isinstance(model_type, str):\n            model_type = model_type.lower()\n\n        params_m = params / 1e6\n\n        # Empirical formula based on benchmarks with batch_size=2, grad_accum=4\n        # Small models: high fixed overhead, lower speedup\n        # Large models: warmup scales but speedup is significant\n        if params_m < 50:\n            estimated_warmup = 35 + params_m * 0.3\n            base_speedup = 1.35\n        elif params_m < 200:\n            estimated_warmup = 12 + params_m * 0.03\n            base_speedup = 1.75\n        else:\n            estimated_warmup = 15 + params_m * 0.04\n            base_speedup = 1.60\n\n        # Estimate time per step (ms) and time saved\n        naive_ms = 50 + params_m * 1.0\n        compiled_ms = naive_ms / base_speedup\n        time_saved_per_step_s = (naive_ms - compiled_ms) / 1000\n\n        if time_saved_per_step_s > 0:\n            breakeven = estimated_warmup / time_saved_per_step_s\n        else:\n            breakeven = float(\"inf\")\n\n        # Return threshold with 1.2x safety margin\n        threshold = breakeven * 1.2\n\n        # Optional adjustment based on expected work per step.\n        # This uses only pre-run information (batch size, grad accum, seq length).\n        generic_scale = 1.0\n        fast_scale = 1.0\n        if (\n            batch_size is not None\n            or grad_accum is not None\n            or max_seq_length is not None\n        ):\n            try:\n                bs = int(batch_size) if batch_size is not None else 2\n                ga = int(grad_accum) if grad_accum is not None else 4\n                seq = int(max_seq_length) if max_seq_length is not None else 512\n            except Exception:\n                bs, ga, seq = 2, 4, 512\n\n            bs = max(1, bs)\n            ga = max(1, ga)\n            # Guard against unbounded tokenizer.model_max_length\n            seq = max(64, min(seq, 8192))\n\n            ref_bs, ref_ga, ref_seq = 2, 4, 512\n\n            # Generic path: lighter scaling, less conservative than params-only.\n            ga_scale = (ref_ga / ga) ** 1.0\n            bs_seq_scale = ((ref_bs * ref_seq) / (bs * seq)) ** 0.15\n            generic_scale = 0.35 * ga_scale * bs_seq_scale\n            generic_scale = max(0.05, min(generic_scale, 5.0))\n\n            # Fast encoder path: stronger scaling based on observed behavior.\n            fast_ga_scale = (ref_ga / ga) ** 1.5\n            fast_bs_seq_scale = ((ref_bs * ref_seq) / (bs * seq)) ** 0.25\n            fast_scale = 0.2 * fast_ga_scale * fast_bs_seq_scale\n            fast_scale = max(0.05, min(fast_scale, 5.0))\n\n        # Conservative safety factors: generic is less conservative than fast.\n        generic_threshold = threshold * generic_scale * 1.25\n\n        
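# Illustrative arithmetic for the generic path (hypothetical ~110M parameter encoder, no training-args\n        # info, so generic_scale stays 1.0): estimated_warmup ~ 15.3, base_speedup = 1.75, naive_ms = 160,\n        # compiled_ms ~ 91, giving breakeven ~ 223 steps and generic_threshold ~ 223 * 1.2 * 1.25 ~ 335 steps.\n        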
is_fast_type = (\n            isinstance(model_type, str)\n            and model_type in FastSentenceTransformer.ENCODER_MODEL_TYPES\n        )\n        if is_fast_type:\n            fast_threshold = threshold * fast_scale * 1.5\n            # Prefer the smaller (less conservative) of the two estimates.\n            final_threshold = min(generic_threshold, fast_threshold)\n        else:\n            final_threshold = generic_threshold\n\n        # Reduce mpnet overestimation slightly.\n        if model_type == \"mpnet\":\n            final_threshold *= 0.7\n\n        # Lower bound to avoid compiling on extremely short runs.\n        return int(max(20, final_threshold))\n\n    @staticmethod\n    def _apply_torch_compile(model, mode = \"default\"):\n        \"\"\"\n        Apply torch.compile to a SentenceTransformer model.\n        Includes workaround for accelerate's unwrap_model bug.\n        \"\"\"\n        if hasattr(model, \"__getitem__\"):\n            inner_model = model[0].auto_model\n            compiled = torch.compile(inner_model, mode = mode)\n            model[0].auto_model = compiled\n            # Fix for accelerate unwrap_model bug:\n            # When SentenceTransformer contains a compiled inner model,\n            # accelerate checks has_compiled_regions() which returns True,\n            # then tries to access model.__dict__[\"_orig_mod\"] which fails.\n            # This workaround sets _orig_mod to satisfy accelerate.\n            model.__dict__[\"_orig_mod\"] = model\n        else:\n            model = torch.compile(model, mode = mode)\n        return model\n\n    @staticmethod\n    def from_pretrained(\n        model_name,\n        max_seq_length = None,\n        dtype = None,\n        load_in_4bit = False,  # Changed default: 4-bit is slow for encoders\n        load_in_8bit = False,\n        load_in_16bit = True,  # Changed default: 16-bit is optimal for encoders\n        full_finetuning = False,\n        token = None,\n        device_map = \"sequential\",\n        rope_scaling = None,\n        fix_tokenizer = True,\n        trust_remote_code = False,\n        use_gradient_checkpointing = False,  # Changed default: conflicts with torch.compile\n        resize_model_vocab = None,\n        revision = None,\n        use_exact_model_name = False,\n        offload_embedding = False,\n        random_state = 3407,\n        max_lora_rank = 64,\n        disable_log_stats = True,\n        qat_scheme = None,\n        unsloth_tiled_mlp = False,\n        pooling_mode = \"mean\",\n        for_inference = False,\n        **kwargs,\n    ):\n        try:\n            from sentence_transformers import SentenceTransformer\n            from sentence_transformers.models import Transformer, Pooling, Normalize\n        except ImportError:\n            raise ImportError(\n                \"Unsloth: To use `FastSentenceTransformer`, you must install `sentence-transformers`.\\n\"\n                \"Run `pip install sentence-transformers` to install it.\"\n            )\n\n        # if for_inference == True, skip Unsloth optimizations to avoid torch compile issues\n        if for_inference:\n            st_device = device_map\n            if isinstance(st_device, dict) or (\n                isinstance(st_device, str) and st_device in [\"auto\", \"sequential\"]\n            ):\n                st_device = None\n\n            # this was added because when loading for inference it was defaulting to float32\n            # propagate dtype to model_kwargs, default to \"auto\"\n            
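# e.g. (illustrative) calling from_pretrained(..., dtype = torch.bfloat16, for_inference = True)\n            # results in model_kwargs = {\"dtype\": torch.bfloat16} being handed to SentenceTransformer below\n            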
model_kwargs = kwargs.get(\"model_kwargs\", {})\n            model_kwargs[\"dtype\"] = dtype if dtype is not None else \"auto\"\n\n            # filter kwargs for SentenceTransformer\n            st_kwargs = {\n                \"device\": st_device,\n                \"trust_remote_code\": trust_remote_code,\n                \"token\": token,\n                \"revision\": revision,\n                \"model_kwargs\": model_kwargs,\n            }\n\n            # add other known kwargs if present\n            known_keys = [\n                \"cache_folder\",\n                \"truncate_dim\",\n                \"tokenizer_kwargs\",\n                \"config_kwargs\",\n            ]\n            for k in known_keys:\n                if k in kwargs:\n                    st_kwargs[k] = kwargs[k]\n\n            st_model = SentenceTransformer(model_name, **st_kwargs)\n            return st_model\n\n        # sanity check, thanks Etherl:\n        if full_finetuning and (load_in_4bit or load_in_8bit):\n            print(\n                \"Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.\"\n            )\n            load_in_4bit = False\n            load_in_8bit = False\n            load_in_fp8 = False\n            load_in_16bit = False\n\n        if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:\n            raise RuntimeError(\n                \"Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\\n\"\n                \"Also, we by default set `load_in_16bit = True`.\\n\"\n                \"If you want 4bit LoRA finetuning, set `load_in_16bit = False` and `load_in_4bit = True`\\n\"\n                \"If you want 8bit finetuning, set both `load_in_16bit = False` and `load_in_8bit = True`\"\n            )\n\n        if \"auto_model\" not in kwargs:\n            kwargs[\"auto_model\"] = AutoModel\n\n        transformers4 = Version(transformers.__version__).major < 5\n        model_type = \"\"\n        config = None\n        try:\n            config = AutoConfig.from_pretrained(\n                model_name, token = token, trust_remote_code = trust_remote_code\n            )\n            model_type = getattr(config, \"model_type\", \"\")\n        except:\n            pass\n\n        # Fast encoder path: Use native torch.compile for encoder models (6x speedup)\n        # This bypasses Unsloth's auto-compiler which adds @torch.compiler.disable decorators\n        # that interfere with torch.compile and cause runtime errors for encoder models.\n        # NOTE: The old Unsloth path is BROKEN for encoder models with torch 2.9+ due to\n        # conflicting @torch.compile and @torch.compiler.disable decorators.\n        # Set UNSLOTH_COMPILE_DISABLE=1 to disable torch.compile and use the old path.\n        is_encoder_model = (\n            model_type.lower() in FastSentenceTransformer.ENCODER_MODEL_TYPES\n        )\n        use_fast_encoder = os.environ.get(\"UNSLOTH_COMPILE_DISABLE\", \"0\") != \"1\"\n        if use_fast_encoder and is_encoder_model:\n            # torch.compile mode: \"default\" is safest for PEFT/LoRA training\n            # Note: \"reduce-overhead\" uses CUDA Graphs which is incompatible with PEFT\n            compile_mode = \"default\"\n\n            # Determine dtype - handle float16 machines that don't support bfloat16\n            if dtype is None:\n                if load_in_16bit:\n                    dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16\n                else:\n       
             dtype = torch.float32\n            elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:\n                print(\n                    \"Unsloth: Device does not support bfloat16. Using float16 instead.\"\n                )\n                dtype = torch.float16\n\n            # Determine device\n            st_device = device_map\n            if isinstance(st_device, dict) or (\n                isinstance(st_device, str) and st_device in [\"auto\", \"sequential\"]\n            ):\n                st_device = \"cuda\"\n\n            # Check if model supports SDPA (Scaled Dot Product Attention) for extra speedup\n            supports_sdpa = False\n            if config is not None:\n                try:\n                    model_class = _get_model_class(\n                        config, kwargs.get(\"auto_model\", AutoModel)._model_mapping\n                    )\n                    supports_sdpa = getattr(model_class, \"_supports_sdpa\", False)\n                except:\n                    pass\n\n            # Build model_kwargs for SentenceTransformer\n            model_kwargs = {\"torch_dtype\": dtype}\n\n            # Enable SDPA if supported (1.2x extra speedup on top of torch.compile)\n            # But disable for models with known SDPA + torch.compile backward issues\n            _force_eager = False\n            for _sdpa_model in DISABLE_SDPA_MODEL_NAMES:\n                if _sdpa_model in model_type.lower():\n                    supports_sdpa = False\n                    _force_eager = True\n                    break\n            if supports_sdpa:\n                model_kwargs[\"attn_implementation\"] = \"sdpa\"\n            elif _force_eager:\n                model_kwargs[\"attn_implementation\"] = \"eager\"\n\n            # Print optimization status\n            sdpa_str = \" + SDPA\" if supports_sdpa else \"\"\n            if load_in_4bit:\n                print(\n                    f\"Unsloth: Using fast encoder path for {model_type} with 4-bit quantization{sdpa_str}\"\n                )\n            else:\n                print(\n                    f\"Unsloth: Using fast encoder path for {model_type} (torch.compile{sdpa_str})\"\n                )\n\n            # Handle 4-bit quantization via BitsAndBytesConfig\n            if load_in_4bit:\n                from transformers import BitsAndBytesConfig\n\n                bnb_config = BitsAndBytesConfig(\n                    load_in_4bit = True,\n                    bnb_4bit_compute_dtype = dtype,\n                    bnb_4bit_quant_type = \"nf4\",\n                    bnb_4bit_use_double_quant = True,\n                )\n                model_kwargs[\"quantization_config\"] = bnb_config\n                # When using quantization, device must be handled by accelerate\n                st_device = None\n\n            # Handle gradient checkpointing - warn user it conflicts with torch.compile\n            _use_gc = use_gradient_checkpointing\n            if _use_gc and _use_gc != False:\n                print(\n                    \"Unsloth Warning: Gradient checkpointing is incompatible with torch.compile.\"\n                )\n                print(\"Disabling torch.compile to enable gradient checkpointing.\")\n                compile_mode = None  # Disable compilation\n\n                is_mpnet = \"mpnet\" == model_type.lower()\n\n                if is_mpnet and transformers4:\n                    FastSentenceTransformer._patch_mpnet_v4()\n                elif is_mpnet:\n                    
FastSentenceTransformer._patch_mpnet_v5()\n\n            # Load via native SentenceTransformer (bypasses Unsloth patching)\n            st_model = SentenceTransformer(\n                model_name,\n                device = st_device,\n                trust_remote_code = trust_remote_code,\n                token = token,\n                revision = revision,\n                model_kwargs = model_kwargs,\n            )\n\n            # Store metadata for get_peft_model\n            st_model._unsloth_fast_encoder = True\n            st_model._compile_mode = compile_mode\n            st_model._dtype = dtype\n            st_model._load_in_4bit = load_in_4bit\n            st_model.no_modules = False\n\n            # Add save methods\n            def _save_pretrained_merged(self, save_directory, **save_kwargs):\n                self.save_pretrained(save_directory)\n                tokenizer = save_kwargs.pop(\"tokenizer\", self.tokenizer)\n                if hasattr(self[0], \"auto_model\"):\n                    inner = self[0].auto_model\n                    # Handle compiled model\n                    if hasattr(inner, \"_orig_mod\"):\n                        inner = inner._orig_mod\n                    if hasattr(inner, \"merge_and_unload\"):\n                        merged = inner.merge_and_unload()\n                        merged.save_pretrained(save_directory)\n                    elif hasattr(inner, \"save_pretrained\"):\n                        inner.save_pretrained(save_directory)\n                if tokenizer is not None:\n                    tokenizer.save_pretrained(save_directory)\n                FastSentenceTransformer._add_unsloth_branding(save_directory)\n\n            st_model.save_pretrained_merged = types.MethodType(\n                _save_pretrained_merged, st_model\n            )\n\n            st_model.save_pretrained_torchao = types.MethodType(\n                _save_pretrained_torchao, st_model\n            )\n\n            st_model.save_pretrained_gguf = types.MethodType(\n                _save_pretrained_gguf, st_model\n            )\n\n            st_model.push_to_hub_gguf = types.MethodType(_push_to_hub_gguf, st_model)\n\n            def _push_to_hub_merged(self, repo_id, **push_kwargs):\n                hub_token = push_kwargs.get(\"token\", None) or get_token()\n                if hub_token is None:\n                    raise ValueError(\"No HF token provided\")\n                api = HfApi(token = hub_token)\n                try:\n                    api.create_repo(\n                        repo_id = repo_id,\n                        private = push_kwargs.get(\"private\"),\n                        exist_ok = True,\n                        repo_type = \"model\",\n                    )\n                except:\n                    pass\n                FastSentenceTransformer._add_unsloth_tags(repo_id, hub_token)\n                with tempfile.TemporaryDirectory() as temp_dir:\n                    self.save_pretrained_merged(temp_dir, **push_kwargs)\n                    api.upload_folder(\n                        folder_path = temp_dir,\n                        repo_id = repo_id,\n                        commit_message = push_kwargs.get(\n                            \"commit_message\", \"Upload model\"\n                        ),\n                    )\n                print(f\"Unsloth: Pushed to https://huggingface.co/{repo_id}\")\n\n            st_model.push_to_hub_merged = types.MethodType(\n                _push_to_hub_merged, st_model\n            )\n\n         
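   # The fast encoder path hands back a plain SentenceTransformer here; LoRA wrapping and the deferred\n            # torch.compile decision happen later in get_peft_model and the patched SentenceTransformerTrainer.\n         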
   return st_model\n\n        # Warn if using 4-bit with encoder (slow due to dequantization overhead)\n        if is_encoder_model and load_in_4bit:\n            print(\n                \"Unsloth Warning: 4-bit quantization adds ~2.3x overhead for encoder models.\"\n            )\n            print(\"Consider using load_in_16bit=True for better performance.\")\n\n        # check if the model supports add_pooling_layer\n        if \"add_pooling_layer\" not in kwargs:\n            supported = FastSentenceTransformer._has_add_pooling_layer(\n                config, kwargs.get(\"auto_model\", AutoModel)\n            )\n            if supported:\n                kwargs[\"add_pooling_layer\"] = False\n\n        # forces fp8 to be False since it's not supported\n        fp8 = kwargs.pop(\"load_in_fp8\", None)\n        if fp8:\n            logging.info(\"Unsloth: Disabling fp8 for model\")\n        load_in_fp8 = False\n\n        # this is a fix for Snowflake/snowflake-arctic-embed-l-v2.0\n        # it has pooler weights which we don't care about for training,\n        # however unsloth throws an exception if \"UNSLOTH_WARN_UNINITIALIZED\" == 1 and it sees unused weights\n        old_environ = os.environ.get(\"UNSLOTH_WARN_UNINITIALIZED\", \"1\")\n        os.environ[\"UNSLOTH_WARN_UNINITIALIZED\"] = \"0\"\n\n        is_distilbert = \"distilbert\" == model_type.lower()\n        is_mpnet = \"mpnet\" == model_type.lower()\n\n        if is_distilbert and transformers4:\n            FastSentenceTransformer._patch_distilbert_v4()\n        elif is_distilbert:\n            FastSentenceTransformer._patch_distilbert_v5()\n        elif is_mpnet and transformers4:\n            FastSentenceTransformer._patch_mpnet_v4()\n        elif is_mpnet:\n            FastSentenceTransformer._patch_mpnet_v5()\n\n        # check if modules.json exists - if not, force 16-bit training\n        # why? because i have to implement saving myself for these models, and i don't feel like adding dequantization\n        # to the save_pretrained_merged for a model that really should be trained in 16-bit anyway\n        has_modules_json = (\n            FastSentenceTransformer._module_path(model_name, token) is not None\n        )\n\n        if not has_modules_json and load_in_4bit:\n            print(\n                \"Unsloth: No modules.json found. 
This is not a sentence-transformers model.\\n\"\n                \"Forcing 16-bit loading to simplify merged model saving.\"\n            )\n            load_in_4bit = False\n            load_in_16bit = True\n\n        try:\n            model, tokenizer = FastModel.from_pretrained(\n                model_name = model_name,\n                max_seq_length = max_seq_length,\n                dtype = dtype,\n                load_in_4bit = load_in_4bit,\n                load_in_8bit = load_in_8bit,\n                load_in_16bit = load_in_16bit,\n                full_finetuning = full_finetuning,\n                token = token,\n                device_map = device_map,\n                rope_scaling = rope_scaling,\n                fix_tokenizer = fix_tokenizer,\n                trust_remote_code = trust_remote_code,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                resize_model_vocab = resize_model_vocab,\n                revision = revision,\n                return_logits = False,\n                use_exact_model_name = use_exact_model_name,\n                offload_embedding = offload_embedding,\n                random_state = random_state,\n                max_lora_rank = max_lora_rank,\n                disable_log_stats = disable_log_stats,\n                qat_scheme = qat_scheme,\n                load_in_fp8 = load_in_fp8,\n                unsloth_tiled_mlp = unsloth_tiled_mlp,\n                **kwargs,\n            )\n        finally:\n            os.environ[\"UNSLOTH_WARN_UNINITIALIZED\"] = old_environ\n\n        # try to load modules, otherwise fallback to old hard-coded modules\n        from sentence_transformers import SentenceTransformer\n\n        modules, no_modules = FastSentenceTransformer._load_modules(\n            model_name,\n            token,\n            model,\n            tokenizer,\n            max_seq_length,\n            pooling_mode,\n            trust_remote_code = trust_remote_code,\n        )\n\n        st_device = device_map\n        if isinstance(st_device, dict) or (\n            isinstance(st_device, str) and st_device in [\"auto\", \"sequential\"]\n        ):\n            st_device = None\n\n        st_model = SentenceTransformer(modules = modules, device = st_device)\n        st_model.no_modules = no_modules\n\n        def _save_pretrained_merged(self, save_directory, **kwargs):\n            # check which adapter files exist before save_pretrained\n            adapter_files = [\"adapter_model.safetensors\", \"adapter_config.json\"]\n            existing_before = {\n                f\n                for f in adapter_files\n                if os.path.exists(os.path.join(save_directory, f))\n            }\n\n            # sentence-transformers config and modules only get saved if we call save_pretrained\n            self.save_pretrained(save_directory)\n\n            # remove LoRA adapters only if they were created by save_pretrained (not pre-existing)\n            for file in adapter_files:\n                if file not in existing_before:\n                    try:\n                        os.remove(os.path.join(save_directory, file))\n                    except:\n                        pass\n\n            tokenizer = kwargs.pop(\"tokenizer\", self.tokenizer)\n            if self.no_modules:\n                # fallback for non-sentence-transformers models\n                print(\n                    \"Unsloth: No modules detected. 
Using standard merge_and_unload for saving...\"\n                )\n                safe_kwargs = kwargs.copy()\n                # filter out Unsloth-specific args that are not in huggingface's save_pretrained\n                unsloth_args = [\n                    \"save_method\",\n                    \"temporary_location\",\n                    \"maximum_memory_usage\",\n                ]\n                for k in unsloth_args:\n                    safe_kwargs.pop(k, None)\n\n                merged_model = self[0].auto_model.merge_and_unload()\n                merged_model.save_pretrained(save_directory, **safe_kwargs)\n                if tokenizer is not None:\n                    tokenizer.save_pretrained(save_directory)\n            else:\n                self[0].auto_model.save_pretrained_merged(\n                    save_directory, tokenizer = tokenizer, **kwargs\n                )\n\n            # add Unsloth branding to the generated README\n            try:\n                FastSentenceTransformer._add_unsloth_branding(save_directory)\n            except Exception as e:\n                print(f\"Unsloth Warning: Failed to add branding to README: {e}\")\n\n        st_model.save_pretrained_merged = types.MethodType(\n            _save_pretrained_merged, st_model\n        )\n\n        st_model.save_pretrained_torchao = types.MethodType(\n            _save_pretrained_torchao, st_model\n        )\n\n        st_model.save_pretrained_gguf = types.MethodType(\n            _save_pretrained_gguf, st_model\n        )\n\n        st_model.push_to_hub_gguf = types.MethodType(_push_to_hub_gguf, st_model)\n\n        def _push_to_hub_merged(self, repo_id, **kwargs):\n            token = kwargs.get(\"token\", None) or get_token()\n            if token is None:\n                raise ValueError(\n                    \"No HF token provided. 
Please provide a token or login with `hf auth login`\"\n                )\n            private = kwargs.get(\"private\", None)\n            commit_message = kwargs.get(\"commit_message\", \"Upload model\")\n\n            from huggingface_hub import HfApi\n\n            api = HfApi(token = token)\n            try:\n                api.create_repo(\n                    repo_id = repo_id,\n                    private = private,\n                    exist_ok = True,\n                    repo_type = \"model\",\n                )\n            except:\n                pass\n\n            # order doesn't seem to matter for this after repo creation...\n            FastSentenceTransformer._add_unsloth_tags(repo_id, token)\n\n            with tempfile.TemporaryDirectory() as temp_dir:\n                self.save_pretrained_merged(temp_dir, **kwargs)\n                api.upload_folder(\n                    folder_path = temp_dir,\n                    repo_id = repo_id,\n                    commit_message = commit_message,\n                )\n            print(\n                f\"Unsloth: Successfully pushed merged model to https://huggingface.co/{repo_id}\"\n            )\n\n        st_model.push_to_hub_merged = types.MethodType(_push_to_hub_merged, st_model)\n        return st_model\n\n    @staticmethod\n    def get_peft_model(\n        model,\n        r = 16,\n        target_modules = [\n            \"query\",\n            \"key\",\n            \"value\",\n            \"dense\",\n        ],\n        lora_alpha = 16,\n        lora_dropout = 0.0,\n        bias = \"none\",\n        layers_to_transform = None,\n        layers_pattern = None,\n        use_gradient_checkpointing = False,  # Changed default: conflicts with torch.compile\n        random_state = 3407,\n        max_seq_length = 2048,\n        use_rslora = False,\n        modules_to_save = None,\n        init_lora_weights = True,\n        loftq_config = {},\n        **kwargs,\n    ):\n        from sentence_transformers import SentenceTransformer\n        from peft import LoraConfig, get_peft_model as peft_get_peft_model\n\n        if \"task_type\" not in kwargs:\n            kwargs[\"task_type\"] = \"FEATURE_EXTRACTION\"\n            print(\"Setting task_type to FEATURE_EXTRACTION\")\n\n        if isinstance(model, SentenceTransformer):\n            # Check if this is a fast encoder model (uses torch.compile instead of Unsloth patching)\n            is_fast_encoder = getattr(model, \"_unsloth_fast_encoder\", False)\n\n            if is_fast_encoder:\n                # Fast encoder path: Use native PEFT + torch.compile (6x speedup)\n                transformer_module = model[0]\n                inner_model = transformer_module.auto_model\n\n                # Check if model is quantized (4-bit/8-bit)\n                is_quantized = (\n                    getattr(inner_model, \"is_quantized\", False)\n                    or getattr(inner_model.config, \"quantization_config\", None)\n                    is not None\n                )\n\n                # Track if gradient checkpointing was actually enabled\n                gc_enabled = False\n\n                # this is needed when from_pretrained was called without gradient\n                # checkpointing but get_peft_model requests it\n                if use_gradient_checkpointing and use_gradient_checkpointing != False:\n                    import transformers\n                    from packaging.version import Version\n\n                    transformers4 = 
Version(transformers.__version__).major < 5\n                    model_type = getattr(inner_model.config, \"model_type\", \"\").lower()\n\n                    if model_type == \"mpnet\" and transformers4:\n                        FastSentenceTransformer._patch_mpnet_v4()\n                    elif model_type == \"mpnet\":\n                        FastSentenceTransformer._patch_mpnet_v5()\n\n                # Prepare for k-bit training if quantized\n                if is_quantized:\n                    from ._utils import prepare_model_for_kbit_training\n\n                    _gc_for_kbit = (\n                        use_gradient_checkpointing\n                        if use_gradient_checkpointing\n                        else False\n                    )\n                    try:\n                        inner_model = prepare_model_for_kbit_training(\n                            inner_model,\n                            use_gradient_checkpointing = _gc_for_kbit,\n                        )\n                        print(\"Unsloth: Prepared quantized model for k-bit training\")\n                        gc_enabled = bool(_gc_for_kbit)\n                    except ValueError as e:\n                        if \"does not support gradient checkpointing\" in str(e):\n                            # Model doesn't support gradient checkpointing, disable it\n                            print(\n                                f\"Unsloth Warning: {inner_model.__class__.__name__} does not support gradient checkpointing. Skipping.\"\n                            )\n                            inner_model = prepare_model_for_kbit_training(\n                                inner_model,\n                                use_gradient_checkpointing = False,\n                            )\n                            print(\n                                \"Unsloth: Prepared quantized model for k-bit training (without gradient checkpointing)\"\n                            )\n                        else:\n                            raise\n\n                # Enable gradient checkpointing if requested (only for non-quantized, since prepare_model handles it)\n                elif use_gradient_checkpointing and use_gradient_checkpointing != False:\n                    if hasattr(inner_model, \"gradient_checkpointing_enable\"):\n                        try:\n                            inner_model.gradient_checkpointing_enable()\n                            print(\"Unsloth: Enabled gradient checkpointing\")\n                            gc_enabled = True\n                        except ValueError as e:\n                            if \"does not support gradient checkpointing\" in str(e):\n                                print(\n                                    f\"Unsloth Warning: {inner_model.__class__.__name__} does not support gradient checkpointing. 
Skipping.\"\n                                )\n\n                # Create LoRA config\n                lora_config = LoraConfig(\n                    r = r,\n                    lora_alpha = lora_alpha,\n                    target_modules = target_modules,\n                    lora_dropout = lora_dropout,\n                    bias = bias,\n                    task_type = kwargs.get(\"task_type\", \"FEATURE_EXTRACTION\"),\n                )\n\n                # Apply PEFT directly (not through FastModel)\n                peft_model = peft_get_peft_model(inner_model, lora_config)\n\n                # Apply QAT if specified\n                qat_scheme = kwargs.get(\"qat_scheme\", None)\n                if qat_scheme is not None:\n                    from ._utils import _prepare_model_for_qat\n\n                    peft_model = _prepare_model_for_qat(peft_model, qat_scheme)\n\n                # Determine compile mode (only if not using gradient checkpointing)\n                compile_mode = getattr(model, \"_compile_mode\", \"default\")\n                # Re-enable torch.compile if gradient checkpointing was requested but couldn't be enabled\n                if compile_mode is None and not gc_enabled:\n                    compile_mode = \"default\"\n                    print(\n                        \"Unsloth: Re-enabling torch.compile since gradient checkpointing is not supported\"\n                    )\n\n                # Re-assign the peft model back to the transformer module\n                transformer_module.auto_model = peft_model\n\n                # Store compile info for auto-compile at trainer time\n                # torch.compile is deferred until training starts so we can check max_steps\n                if compile_mode is not None:\n                    model._compile_mode = compile_mode\n                    model._compile_threshold = (\n                        FastSentenceTransformer._estimate_compile_threshold(model)\n                    )\n                    # Flag to indicate compile has not been applied yet\n                    model._compile_pending = True\n                    print(\n                        f\"Unsloth: torch.compile will be applied automatically if max_steps > {model._compile_threshold}\"\n                    )\n                else:\n                    model._compile_mode = None\n                    model._compile_pending = False\n                    print(\n                        \"Unsloth: torch.compile disabled (gradient checkpointing enabled)\"\n                    )\n\n                return model\n\n            # Original path for non-fast-encoder models\n            transformer_module = model[0]\n            inner_model = transformer_module.auto_model\n\n            peft_model = FastModel.get_peft_model(\n                model = inner_model,\n                r = r,\n                target_modules = target_modules,\n                lora_alpha = lora_alpha,\n                lora_dropout = lora_dropout,\n                bias = bias,\n                layers_to_transform = layers_to_transform,\n                layers_pattern = layers_pattern,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                random_state = random_state,\n                max_seq_length = max_seq_length,\n                use_rslora = use_rslora,\n                modules_to_save = modules_to_save,\n                init_lora_weights = init_lora_weights,\n                loftq_config = loftq_config,\n                **kwargs,\n            )\n\n  
          # re-assign the peft model back to the transformer module\n            transformer_module.auto_model = peft_model\n            return model\n        else:\n            return FastModel.get_peft_model(\n                model = model,\n                r = r,\n                target_modules = target_modules,\n                lora_alpha = lora_alpha,\n                lora_dropout = lora_dropout,\n                bias = bias,\n                layers_to_transform = layers_to_transform,\n                layers_pattern = layers_pattern,\n                use_gradient_checkpointing = use_gradient_checkpointing,\n                random_state = random_state,\n                max_seq_length = max_seq_length,\n                use_rslora = use_rslora,\n                modules_to_save = modules_to_save,\n                init_lora_weights = init_lora_weights,\n                loftq_config = loftq_config,\n                **kwargs,\n            )\n\n\ndef _patch_sentence_transformer_trainer():\n    \"\"\"\n    Patch SentenceTransformerTrainer to automatically apply torch.compile\n    when training steps exceed the breakeven threshold.\n\n    This is called automatically when this module is imported.\n    \"\"\"\n    try:\n        from sentence_transformers import SentenceTransformerTrainer\n    except ImportError:\n        return  # sentence_transformers not installed\n\n    if getattr(SentenceTransformerTrainer, \"_unsloth_auto_compile_patched\", False):\n        return  # Already patched\n\n    from functools import wraps\n\n    _original_init = SentenceTransformerTrainer.__init__\n\n    @wraps(_original_init)\n    def _patched_init(self, *args, **kwargs):\n        # Extract model and training_args\n        model = kwargs.get(\"model\") or (args[0] if args else None)\n        training_args = kwargs.get(\"args\") or (args[1] if len(args) > 1 else None)\n\n        # Check if model has pending compile\n        if (\n            model is not None\n            and training_args is not None\n            and getattr(model, \"_compile_pending\", False)\n        ):\n            max_steps = getattr(training_args, \"max_steps\", -1)\n            compile_mode = getattr(model, \"_compile_mode\", \"default\")\n\n            # Re-estimate threshold now that training args are available\n            batch_size = getattr(training_args, \"per_device_train_batch_size\", None)\n            grad_accum = getattr(training_args, \"gradient_accumulation_steps\", None)\n            max_seq_length = getattr(model, \"max_seq_length\", None)\n            if max_seq_length is None and hasattr(model, \"__getitem__\"):\n                try:\n                    max_seq_length = getattr(model[0], \"max_seq_length\", None)\n                except Exception:\n                    max_seq_length = None\n            if max_seq_length is None:\n                tokenizer = getattr(model, \"tokenizer\", None)\n                max_seq_length = (\n                    getattr(tokenizer, \"model_max_length\", None)\n                    if tokenizer is not None\n                    else None\n                )\n\n            threshold = FastSentenceTransformer._estimate_compile_threshold(\n                model,\n                batch_size = batch_size,\n                grad_accum = grad_accum,\n                max_seq_length = max_seq_length,\n            )\n            model._compile_threshold = threshold\n\n            if max_steps > 0 and max_steps >= threshold:\n                print(\n                    f\"Unsloth: Auto-compiling 
model ({max_steps} steps >= {threshold} threshold)\"\n                )\n                FastSentenceTransformer._apply_torch_compile(model, mode = compile_mode)\n                model._compile_pending = False\n            elif max_steps > 0:\n                print(\n                    f\"Unsloth: Skipping torch.compile ({max_steps} steps < {threshold} threshold)\"\n                )\n                model._compile_pending = False\n\n        # Call original __init__\n        _original_init(self, *args, **kwargs)\n\n        # Disable mixed precision when FORCE_FLOAT32 is active (matches rl.py behavior)\n        if os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n            if hasattr(self, \"args\") and self.args is not None:\n                if self.args.fp16 or self.args.bf16:\n                    print(\n                        \"Unsloth: Switching to float32 training since model cannot work with float16\"\n                    )\n                    self.args.fp16 = False\n                    self.args.bf16 = False\n                    if hasattr(self.args, \"bf16_full_eval\"):\n                        self.args.bf16_full_eval = False\n                    if hasattr(self.args, \"fp16_full_eval\"):\n                        self.args.fp16_full_eval = False\n\n    SentenceTransformerTrainer.__init__ = _patched_init\n    SentenceTransformerTrainer._unsloth_auto_compile_patched = True\n\n\n# Auto-patch trainer on module import\n_patch_sentence_transformer_trainer()\n"
  },
  {
    "path": "unsloth/models/vision.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom transformers import (\n    BitsAndBytesConfig,\n    AutoProcessor,\n    AutoTokenizer,\n    AutoModelForCausalLM,\n)\n\ntry:\n    from transformers import AutoModelForImageTextToText\n\n    AutoModelForVision2Seq = AutoModelForImageTextToText\nexcept:\n    from transformers import AutoModelForVision2Seq\nfrom ..kernels import (\n    post_patch_loss_function,\n)\nfrom ._utils import __version__, importlib_version, _prepare_model_for_qat\nfrom ._utils import *\nfrom .loader_utils import _get_fp8_mode_and_check_settings\nfrom ..save import patch_saving_functions\nfrom ..models.loader_utils import is_distributed\nfrom unsloth_zoo.gradient_checkpointing import (\n    unpatch_unsloth_gradient_checkpointing,\n    unpatch_unsloth_smart_gradient_checkpointing,\n)\nimport torch.utils.checkpoint as torch_checkpoint\nimport transformers.modeling_utils as hf_modeling_utils\nfrom peft import LoraConfig, TaskType, get_peft_model as _get_peft_model\nfrom peft import PeftModelForCausalLM\nfrom transformers import set_seed as transformers_set_seed\nfrom unsloth_zoo.peft_utils import (\n    get_peft_regex,\n    SKIP_QUANTIZATION_MODULES,\n    requires_grad_for_gradient_checkpointing,\n)\nfrom transformers.models.llama.modeling_llama import logger\nfrom transformers import __version__ as transformers_version\nfrom triton import __version__ as triton_version\nfrom unsloth_zoo.utils import _get_dtype\nfrom unsloth_zoo.hf_utils import (\n    dtype_from_config,\n    add_dtype_kwargs,\n    fix_lora_auto_mapping,\n    get_auto_processor,\n)\nfrom unsloth_zoo.patching_utils import patch_model_and_tokenizer\nfrom unsloth_zoo.training_utils import prepare_model_for_training\n\nfrom unsloth_zoo.utils import Version\nfrom transformers import __version__ as transformers_version\n\nimport types\nimport functools\nimport os\nimport gc\nimport math\nfrom typing import Optional, Tuple, List, Union\nimport re, inspect, sys\nimport contextlib\n\ntry:\n    from huggingface_hub.utils import get_token\nexcept:\n    # Old HF Hub versions <= 0.0.25\n    from huggingface_hub.utils._token import get_token\nfrom ..device_type import (\n    is_hip,\n    get_device_type,\n    DEVICE_TYPE,\n    DEVICE_TYPE_TORCH,\n    DEVICE_COUNT,\n    ALLOW_PREQUANTIZED_MODELS,\n)\n\n__all__ = [\n    \"FastBaseModel\",\n]\n\nglobal NUM_LOGITS_TO_KEEP\nNUM_LOGITS_TO_KEEP = dict()\n\nVLLM_SUPPORTED_VLM = [\n    \"qwen2_5_vl\",\n    \"gemma3\",\n    \"mistral3\",\n    \"qwen3_vl\",\n    \"qwen3_vl_moe\",\n]\nVLLM_NON_LORA_VLM = [\n    \"mllama\",\n]\nPRE_COMPILE_INFERENCE = [\n    \"gpt_oss\",\n]\n\nfrom transformers import GenerationConfig, CompileConfig, AutoConfig\n\ntry:\n    from transformers import PreTrainedConfig\n\n    PretrainedConfig = PreTrainedConfig\nexcept:\n    from transformers import PretrainedConfig\n\nHAS_TORCH_DTYPE = \"torch_dtype\" in 
PretrainedConfig.__doc__\n\n_compile_config = CompileConfig(\n    fullgraph = False,\n    dynamic = None,\n    mode = \"reduce-overhead\",\n)\n_compile_config.disable = True  # Must set manually\n\ntry:\n    torch_compiler_set_stance = torch.compiler.set_stance\nexcept:\n    torch_compiler_set_stance = None\n\n\ndef unsloth_base_fast_generate(\n    self,\n    *args,\n    **kwargs,\n):\n    if len(args) != 0:\n        input_ids = args[0]\n    elif \"input_ids\" in kwargs:\n        input_ids = kwargs[\"input_ids\"]\n    elif \"input\" in kwargs:\n        input_ids = kwargs[\"input\"]\n    elif \"input_features\" in kwargs:\n        input_ids = kwargs[\"input_features\"]\n    elif \"input_embeds\" in kwargs:\n        input_ids = kwargs[\"input_embeds\"]\n    elif \"inputs\" in kwargs:\n        input_ids = kwargs[\"inputs\"]\n    else:\n        key = next(iter(kwargs.keys()))\n        if type(kwargs[key]) is not torch.Tensor:\n            raise TypeError(\"Unsloth: You need to pass in input_ids to .generate!\")\n        input_ids = kwargs[key]\n    assert type(input_ids) is torch.Tensor\n    bsz = input_ids.shape[0]\n\n    FastBaseModel.for_inference(self)\n    dtype = _get_dtype(dtype_from_config(self.config))\n    # Handle full float32 cases as config.dtype == torch.float32!\n    do_bfloat16_mixed_precision = (\n        os.environ.get(\"UNSLOTH_BFLOAT16_MIXED_PRECISION\", \"0\") == \"1\"\n    )\n    if do_bfloat16_mixed_precision:\n        dtype = torch.bfloat16\n\n    # Check if VLM\n    is_vlm = any(\n        x.endswith((\"ForConditionalGeneration\", \"ForVisionText2Text\"))\n        for x in self.config.architectures\n    )\n    is_vlm = is_vlm or hasattr(self.config, \"vision_config\")\n    arch = self.config.architectures[0]\n\n    # Remove token_type_ids - WRONG for Gemma 3 since bidirectional attention\n    if hasattr(self, \"generate\") and hasattr(self, \"forward\"):\n        # did not combine with below since self might not have model\n        keys = inspect.signature(self.forward).parameters.keys()\n        if \"token_type_ids\" not in keys:\n            kwargs.pop(\"token_type_ids\", None)\n    # kwargs.pop(\"token_type_ids\", None)\n\n    # VLMs do not allow logits_to_keep\n    global NUM_LOGITS_TO_KEEP\n    if arch not in NUM_LOGITS_TO_KEEP:\n        m = self\n        # Find which is needed ie\n        # num_logits_to_keep or logits_to_keep\n        while hasattr(m, \"model\"):\n            if hasattr(m, \"forward\"):\n                keys = inspect.signature(m.forward).parameters.keys()\n                if \"num_logits_to_keep\" in keys:\n                    NUM_LOGITS_TO_KEEP[arch] = \"num_logits_to_keep\"\n                    break\n                elif \"logits_to_keep\" in keys:\n                    NUM_LOGITS_TO_KEEP[arch] = \"logits_to_keep\"\n                    break\n            m = m.model\n        if arch not in NUM_LOGITS_TO_KEEP:\n            NUM_LOGITS_TO_KEEP[arch] = None\n    key = NUM_LOGITS_TO_KEEP[arch]\n    if key is not None and key not in kwargs:\n        kwargs[key] = 1\n\n    # Check pad_token\n    model_eos_token_id = getattr(self.config, \"eos_token_id\", None)\n    if model_eos_token_id is not None and hasattr(model_eos_token_id, \"__iter__\"):\n        model_eos_token_id = model_eos_token_id[0]\n\n    kwargs[\"pad_token_id\"] = kwargs.pop(\"pad_token_id\", model_eos_token_id)\n\n    # Get pixel values for VLMs\n    try:\n        kwargs[\"pixel_values\"] = kwargs[\"pixel_values\"].to(dtype)\n    except:\n        pass\n\n    # Mixed precision 
autocast\n    if os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n        autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = torch.float16)\n        dtype = torch.float16\n    else:\n        autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)\n    # Prepare LoRA\n    # state_dict = convert_lora_modules(self, dtype = dtype)\n\n    # Set compile dynamic shapes\n    torch._dynamo.mark_static(input_ids, 0)\n    torch._dynamo.mark_dynamic(input_ids, 1)\n    if \"attention_mask\" in kwargs:\n        torch._dynamo.mark_static(kwargs[\"attention_mask\"], 0)\n        torch._dynamo.mark_dynamic(kwargs[\"attention_mask\"], 1)\n    if \"token_type_ids\" in kwargs:\n        torch._dynamo.mark_static(kwargs[\"token_type_ids\"], 0)\n        torch._dynamo.mark_dynamic(kwargs[\"token_type_ids\"], 1)\n\n    # Fix generation_config\n    # Use hybrid if sliding window seen, otherwise try static\n    cache_implementation = getattr(self.config, \"cache_implementation\", None)\n    if getattr(\n        self, \"_supports_static_cache\", getattr(self, \"_can_compile_fullgraph\", True)\n    ):\n        if os.environ.get(\"UNSLOTH_DISABLE_STATIC_GENERATION\", \"0\") == \"0\":\n            cache_implementation = \"static\"\n        elif Version(transformers_version) < Version(\"4.56.0.dev0\"):\n            cache_implementation = None\n        else:\n            # Should work in latest transformers!\n            cache_implementation = \"static\"\n    else:\n        cache_implementation = None\n    if cache_implementation is not None:\n        swa = getattr(\n            getattr(self.config, \"text_config\", self.config), \"sliding_window\", None\n        )\n        if (swa == 0 or type(swa) is not int) and (\n            getattr(self, \"_can_compile_fullgraph\", True) is True\n        ):\n            cache_implementation = \"static\"\n        else:\n            if Version(transformers_version) < Version(\"4.56.0.dev0\"):\n                cache_implementation = \"hybrid\"\n            else:\n                cache_implementation = \"static\"\n    # [TODO] Unsure why static fails\n    if do_bfloat16_mixed_precision:\n        cache_implementation = None\n\n    if \"generation_config\" in kwargs:\n        kwargs[\"generation_config\"].cache_implementation = cache_implementation\n        if cache_implementation is not None:\n            kwargs[\"generation_config\"].compile_config = _compile_config\n    else:\n        kwargs[\"cache_implementation\"] = cache_implementation\n        if cache_implementation is not None:\n            kwargs[\"compile_config\"] = _compile_config\n\n    # Delete cached Flex Attention masks to reset inference\n    for name, module in self.named_modules():\n        if hasattr(module, \"_flex_attention_cache\"):\n            try:\n                del module._flex_attention_cache\n            except:\n                pass\n        # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'\n        if hasattr(module, \"_cache\") and \"cache_utils\" in str(module._cache.__class__):\n            try:\n                del module._cache\n            except:\n                pass\n\n    # DO INFERENCE\n    with torch.inference_mode(), autocaster:\n        output = self._old_generate(*args, **kwargs)\n\n    # Delete cached Flex Attention masks to reset inference\n    for name, module in self.named_modules():\n        if hasattr(module, \"_flex_attention_cache\"):\n            try:\n                del 
module._flex_attention_cache\n            except:\n                pass\n        # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'\n        if hasattr(module, \"_cache\") and \"cache_utils\" in str(module._cache.__class__):\n            try:\n                del module._cache\n            except:\n                pass\n\n    # FastBaseModel.for_training(self)\n    return output\n\n\ndef _construct_vlm_processor_fallback(\n    tokenizer_name, model_type, token, trust_remote_code\n):\n    \"\"\"Construct a VLM processor manually when AutoProcessor.from_pretrained fails.\n\n    Some VLMs (e.g., LFM2.5-VL) have tokenizer_class entries that AutoTokenizer\n    cannot resolve. This function loads the image processor and tokenizer separately,\n    sets required special token attributes, and constructs the processor.\n    \"\"\"\n    try:\n        from transformers import AutoImageProcessor, PreTrainedTokenizerFast, AutoConfig\n        from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES\n        import json\n\n        # Load image processor\n        image_processor = AutoImageProcessor.from_pretrained(\n            tokenizer_name,\n            token = token,\n            trust_remote_code = trust_remote_code,\n        )\n        # Load tokenizer via PreTrainedTokenizerFast (bypasses tokenizer_class check)\n        tok = PreTrainedTokenizerFast.from_pretrained(\n            tokenizer_name,\n            padding_side = \"left\",\n            token = token,\n            trust_remote_code = trust_remote_code,\n        )\n        # Read tokenizer_config.json for model-specific special tokens\n        try:\n            from huggingface_hub import hf_hub_download\n\n            config_path = hf_hub_download(\n                tokenizer_name, \"tokenizer_config.json\", token = token\n            )\n            with open(config_path, \"r\", encoding = \"utf-8\") as f:\n                tok_config = json.load(f)\n            # Set model-specific special tokens and their IDs\n            for key in (\n                \"image_token\",\n                \"image_start_token\",\n                \"image_end_token\",\n                \"image_thumbnail\",\n                \"video_token\",\n            ):\n                if key in tok_config and not hasattr(tok, key):\n                    setattr(tok, key, tok_config[key])\n                    id_key = key + \"_id\" if not key.endswith(\"_id\") else key\n                    token_id = tok.convert_tokens_to_ids(tok_config[key])\n                    if not hasattr(tok, id_key):\n                        setattr(tok, id_key, token_id)\n        except Exception:\n            pass\n\n        # Find the processor class - try model_type first, then top-level config model_type\n        proc_class_name = PROCESSOR_MAPPING_NAMES.get(model_type)\n        if proc_class_name is None:\n            # model_type might be a sub-model type (e.g. 
\"lfm2\" instead of \"lfm2_vl\").\n            # Try the top-level config.model_type which often has the processor mapping.\n            try:\n                config = AutoConfig.from_pretrained(\n                    tokenizer_name,\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n                proc_class_name = PROCESSOR_MAPPING_NAMES.get(config.model_type)\n            except Exception:\n                pass\n\n        if proc_class_name is not None:\n            import transformers\n\n            proc_class = getattr(transformers, proc_class_name, None)\n            if proc_class is not None:\n                processor = proc_class(image_processor = image_processor, tokenizer = tok)\n                # Copy chat_template from tokenizer to processor if needed\n                if not getattr(processor, \"chat_template\", None) and getattr(\n                    tok, \"chat_template\", None\n                ):\n                    processor.chat_template = tok.chat_template\n                return processor\n    except Exception:\n        pass\n    return None\n\n\nclass FastBaseModel:\n    @staticmethod\n    def from_pretrained(\n        model_name = \"unsloth/Llama-3.2-1B-Instruct\",\n        max_seq_length = 2048,\n        dtype = None,\n        load_in_4bit = True,\n        load_in_8bit = False,\n        load_in_16bit = False,\n        full_finetuning = False,\n        token = None,\n        device_map = \"sequential\",\n        trust_remote_code = False,\n        model_types = None,\n        tokenizer_name = None,\n        auto_model = AutoModelForVision2Seq,\n        use_gradient_checkpointing = \"unsloth\",\n        supports_sdpa = True,\n        whisper_language = None,\n        whisper_task = None,\n        auto_config = None,\n        offload_embedding = False,\n        float32_mixed_precision = None,  # Forces float32 mixed precision\n        # vLLM parameters\n        fast_inference = False,\n        gpu_memory_utilization = 0.5,\n        float8_kv_cache = False,\n        random_state = 3407,\n        max_lora_rank = 64,\n        disable_log_stats = False,\n        unsloth_vllm_standby = False,\n        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')\n        **kwargs,\n    ):\n        if unsloth_vllm_standby and os.environ.get(\"UNSLOTH_VLLM_STANDBY\", \"0\") != \"1\":\n            raise RuntimeError(\n                \"Unsloth: UNSLOTH_VLLM_STANDBY is True, but UNSLOTH_VLLM_STANDBY is not set to 1!\"\n            )\n\n        if model_types is None:\n            raise RuntimeError(\n                \"Unsloth: Please use FastModel or FastVisionModel and not use FastBaseModel directly!\"\n            )\n        if os.environ.get(\"UNSLOTH_MODEL_NAME\", \"\") == \"\":\n            os.environ[\"UNSLOTH_MODEL_NAME\"] = model_name.lower()\n\n        is_vlm = auto_model in [AutoModelForVision2Seq, AutoModelForImageTextToText]\n        is_whisper = whisper_language is not None and whisper_task is not None\n        auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer\n\n        model_type_arch = model_types[0]\n        if model_type_arch == \"siglip\":\n            for model_type_arch in model_types:\n                if model_type_arch != \"siglip\":\n                    break\n\n        vllm_enable_lora = True\n\n        if is_vlm and fast_inference:\n            if not any(arch in VLLM_SUPPORTED_VLM for arch in model_types):\n                raise RuntimeError(\n                    
f\"Unsloth: Fast inference is only supported for Language models and Qwen2.5-VL, Gemma3 among vision models. \"\n                    f\"Found architectures: {', '.join(model_types)}!\"\n                )\n\n        if any(arch in VLLM_NON_LORA_VLM for arch in model_types):\n            # mllama is still only in vllm v0 https://arc.net/l/quote/llwkfgmu\n            # https://docs.vllm.ai/en/stable/models/supported_models.html#text-generation_1\n            # vLLM V0 does not support LoRA on multi modal models.\n            # TODO: Update this once vLLM V1 supports Llama 3.2 aka mllama\n            vllm_enable_lora = False\n\n        os.environ[\"UNSLOTH_USE_NEW_MODEL\"] = \"1\"\n        if trust_remote_code:\n            print(\n                \"Unsloth: WARNING `trust_remote_code` is True.\\n\"\n                \"Are you certain you want to do remote code execution?\"\n            )\n        token = hf_login(token)\n        SUPPORTS_BFLOAT16 = is_bfloat16_supported()\n\n        if DEVICE_TYPE == \"cuda\":\n            gpu_stats = torch.cuda.get_device_properties(0)\n            gpu_stats_name = (\n                gpu_stats.name + \". \" if gpu_stats.name != \"\" else \"NVIDIA GPU Device. \"\n            )\n            gpu_version = torch.version.cuda\n            gpu_stats_snippet = f\"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}.\"\n            try:\n                vllm_version = f\" vLLM: {importlib_version('vllm')}.\"\n            except:\n                vllm_version = \"\"\n        elif DEVICE_TYPE == \"hip\":\n            gpu_stats = torch.cuda.get_device_properties(0)\n            gpu_stats_name = resolve_hip_gpu_stats_name(gpu_stats)\n            gpu_version = torch.version.hip\n            gpu_stats_snippet = f\"ROCm Toolkit: {gpu_version}.\"\n            try:\n                vllm_version = f\" vLLM: {importlib_version('vllm')}.\"\n            except:\n                vllm_version = \"\"\n        elif DEVICE_TYPE == \"xpu\":\n            gpu_stats = torch.xpu.get_device_properties(0)\n            gpu_stats_name = (\n                gpu_stats.name + \". \" if gpu_stats.name != \"\" else \"Intel XPU Device. \"\n            )\n            gpu_version = torch.version.xpu\n            gpu_stats_snippet = f\"Intel Toolkit: {gpu_version}.\"\n            # [TODO] After adding vLLM support for XPU, change this\n            vllm_version = \"\"\n        else:\n            raise ValueError(f\"Unsloth: Unsupported device type: {DEVICE_TYPE}\")\n\n        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n\n        arch_name = model_type_arch.title()\n        arch_name = arch_name.replace(\"_Vl_\", \"_VL_\").replace(\"_Moe\", \"_MoE\")\n        statistics = (\n            f\"==((====))==  Unsloth {__version__}: Fast {arch_name} patching. Transformers: {transformers_version}.{vllm_version}\\n\"\n            f\"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\\n\"\n            f\"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\\n\"\n            f\"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\\n\"\n            f' \"-____-\"     Free license: http://github.com/unslothai/unsloth'\n        )\n\n        print(statistics)\n\n        # Warn about fast transfers\n        if \"HF_HUB_ENABLE_HF_TRANSFER\" in os.environ:\n            old_hf_transfer = os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"]\n            if old_hf_transfer in (\"False\", \"false\"):\n                old_hf_transfer = \"0\"\n            if old_hf_transfer in (\"True\", \"true\"):\n                old_hf_transfer = \"1\"\n        else:\n            old_hf_transfer = \"0\"\n        if old_hf_transfer == \"1\":\n            print(\n                \"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\"\n            )\n        if old_hf_transfer != \"0\":\n            os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n\n        # For debugging - we use a download counter to see if environments are not breaking or if HF is down\n        get_statistics(kwargs.get(\"local_files_only\", False))\n\n        if dtype is None:\n            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16\n        elif os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n            if dtype == torch.float16:\n                dtype = torch.bfloat16\n        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:\n            logger.warning_once(\n                \"Device does not support bfloat16. Will change to float16.\"\n            )\n            dtype = torch.float16\n        assert dtype in (torch.float16, torch.bfloat16, torch.float32)\n\n        bnb_compute_dtype = dtype\n        do_forced_float32 = False\n        if os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\":\n            print(\n                f\"Unsloth: Using float16 precision for {model_type_arch} won't work! 
Using float32.\"\n            )\n            bnb_compute_dtype = torch.float16\n            do_forced_float32 = True\n\n        # Check for custom data-types\n        custom_datatype = None\n        correct_dtype = None\n        if os.environ.get(\"UNSLOTH_FORCE_CUSTOM_DTYPE\", \"\") != \"\":\n            custom_datatype = os.environ[\"UNSLOTH_FORCE_CUSTOM_DTYPE\"]\n            assert custom_datatype.count(\";\") >= 4\n            checker, _dtype, _bnb_compute_dtype, _custom_datatype, execute_code = (\n                custom_datatype.split(\";\", 4)\n            )\n            # Allow custom dtypes on all runs\n            allow_all_runs = checker == \"all\"\n            # Allow only on float16 datatypes\n            allow_float16_runs = (\n                checker == \"float16\" or checker == \"torch.float16\"\n            ) and (\n                dtype == torch.float16\n                or os.environ.get(\"UNSLOTH_FORCE_FLOAT32\", \"0\") == \"1\"\n            )\n            if allow_all_runs or allow_float16_runs:\n                if eval(_dtype) is not None:\n                    dtype = eval(_dtype)\n                if eval(_bnb_compute_dtype) is not None:\n                    bnb_compute_dtype = eval(_bnb_compute_dtype)\n                correct_dtype = bnb_compute_dtype\n                custom_datatype = _custom_datatype\n                # Execute code as well\n                if len(execute_code.strip()) != 0:\n                    exec(execute_code)\n            else:\n                custom_datatype = None\n                correct_dtype = None\n\n        # Stop SDPA for some archs like Pixtral / Mistral3\n        flex_attn_impl = None\n        if auto_config is None:\n            auto_config = AutoConfig.from_pretrained(\n                model_name,\n                token = token,\n                trust_remote_code = trust_remote_code,\n            )\n        try:\n            model_class = auto_model._model_mapping[auto_config.__class__]\n        except Exception:\n            model_class = None\n        flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)\n\n        # Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with\n        # FP8 weights. 
We just need to update it here for sanity.\n        auto_config.model_name = model_name\n        # Re-resolve model_class after potential config change\n        try:\n            model_class = auto_model._model_mapping[auto_config.__class__]\n        except Exception:\n            model_class = None\n\n        model_type = str(getattr(auto_config, \"model_type\", \"\")).lower()\n        if model_type.startswith(\"gemma3n\"):\n            # Gemma3N variants initialize timm-based vision towers which do\n            # not support flex_attention, so default to eager unless overridden.\n            default_attn_impl = \"eager\"\n        else:\n            default_attn_impl = \"flex_attention\" if flex_attn_impl else \"sdpa\"\n        if not (\"attn_implementation\" in kwargs):\n            kwargs[\"attn_implementation\"] = default_attn_impl\n        if not supports_sdpa and kwargs.get(\"attn_implementation\") == \"sdpa\":\n            if os.environ.get(\"UNSLOTH_ENABLE_FLEX_ATTENTION\", \"0\") == \"0\":\n                print(\n                    f\"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager.\"\n                )\n            del kwargs[\"attn_implementation\"]\n\n        bnb_config = None\n        user_quantization_config = kwargs.get(\"quantization_config\", None)\n        if full_finetuning and (load_in_4bit or load_in_8bit):\n            print(\n                \"Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.\"\n            )\n            load_in_4bit = False\n            load_in_8bit = False\n            load_in_16bit = False\n\n        if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:\n            raise RuntimeError(\n                \"Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\"\n            )\n        _skip_modules = SKIP_QUANTIZATION_MODULES.copy()\n        # Nemotron-H uses 'mixer' (not 'mamba') for Mamba layers.\n        # Mamba fused kernels pass out_proj.weight directly to F.linear,\n        # which fails with quantized Params4bit. Skip out_proj from quantization.\n        if any(mt == \"nemotron_h\" for mt in (model_types or [])):\n            _skip_modules.append(\"out_proj\")\n\n        if load_in_4bit:\n            bnb_config = BitsAndBytesConfig(\n                load_in_4bit = True,\n                bnb_4bit_use_double_quant = True,\n                bnb_4bit_quant_type = \"nf4\",\n                bnb_4bit_compute_dtype = bnb_compute_dtype,\n                llm_int8_skip_modules = _skip_modules,\n            )\n        elif load_in_8bit:\n            bnb_config = BitsAndBytesConfig(\n                load_in_8bit = True,\n                llm_int8_skip_modules = _skip_modules,\n            )\n        elif load_in_16bit:\n            bnb_config = None\n        elif not load_in_4bit and not load_in_8bit and not full_finetuning:\n            print(\n                \"Unsloth: QLoRA and full finetuning all not selected. 
Switching to 16bit LoRA.\"\n            )\n\n        if full_finetuning:\n            os.environ[\"UNSLOTH_ENABLE_FULL_FINETUNING\"] = \"1\"\n            if dtype == torch.bfloat16:\n                if float32_mixed_precision != True:\n                    print(\n                        f\"Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.\\n\"\n                        f\"To enable float32 training, use `float32_mixed_precision = True` during FastLanguageModel.from_pretrained\"\n                    )\n                else:\n                    print(\n                        f\"Unsloth: Using full float32 full finetuning. \"\n                        f\"To enable bfloat16 training to reduce VRAM usage by 50% albeit with a slightly higher loss, do:\\n\"\n                        \"use `float32_mixed_precision = False` during FastLanguageModel.from_pretrained\"\n                    )\n                    os.environ[\"UNSLOTH_BFLOAT16_MIXED_PRECISION\"] = \"1\"\n            else:\n                print(\n                    \"Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.\"\n                )\n        else:\n            os.environ[\"UNSLOTH_ENABLE_FULL_FINETUNING\"] = \"0\"\n\n        # Fix AttributeError: 'BitsAndBytesConfig' object has no attribute 'get_loading_attributes'\n        if bnb_config is not None and not hasattr(bnb_config, \"get_loading_attributes\"):\n            bnb_config.get_loading_attributes = lambda *args, **kwargs: {}\n\n        # Cannot be None, since HF now checks for the config\n        if load_in_4bit or load_in_8bit:\n            # Ignore load_in_4bit / load_in_8bit for MXFP4 - best to get config file\n            if (\n                \"gpt-oss-20b\" in model_name.lower()\n                or \"gpt-oss-120b\" in model_name.lower()\n            ):\n                pass\n            else:\n                if user_quantization_config is None:\n                    kwargs[\"quantization_config\"] = bnb_config\n        else:\n            if auto_config is None:\n                auto_config = AutoConfig.from_pretrained(\n                    model_name,\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n            if hasattr(auto_config, \"quantization_config\"):\n                from transformers.quantizers.auto import (\n                    AUTO_QUANTIZATION_CONFIG_MAPPING,\n                )\n\n                quantization_config = auto_config.quantization_config\n                quant_method = quantization_config[\"quant_method\"]\n                # Sometimes bitsandbytes_4bit + bitsandbytes_8bit is provided\n                if (\n                    quant_method == \"bitsandbytes\"\n                    and \"bitsandbytes\" not in AUTO_QUANTIZATION_CONFIG_MAPPING\n                ):\n                    if \"bitsandbytes_4bit\" not in AUTO_QUANTIZATION_CONFIG_MAPPING:\n                        raise KeyError(\n                            \"Unsloth: AUTO_QUANTIZATION_CONFIG_MAPPING does not have `bitsandbytes_4bit`\"\n                        )\n                    quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING[\"bitsandbytes_4bit\"]\n                else:\n                    quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]\n                quantizer_kwargs = {}\n                if quant_method == \"compressed-tensors\":\n                    # Ignore these\n                    pass\n                else:\n                    # We cannot 
dequantize since gpt-oss-20b MXFP4 will now be gpt-oss-20b-BF16\n                    if (\n                        load_in_16bit\n                        and \"dequantize\" in inspect.signature(quantizer).parameters\n                    ):\n                        quantizer_kwargs[\"dequantize\"] = True\n                    try:\n                        # Sometimes this fails so we wrap it in a try except\n                        quantization_config = quantizer.from_dict(\n                            quantization_config, **quantizer_kwargs\n                        )\n                    except:\n                        pass\n                    if user_quantization_config is None:\n                        kwargs[\"quantization_config\"] = quantization_config\n\n        # Check if using forced float32 - we load it in bfloat16, then cast to float16!\n        torch_dtype = dtype\n        if do_forced_float32:\n            torch_dtype = torch.bfloat16\n\n        kwargs = add_dtype_kwargs(torch_dtype, kwargs)\n\n        config_attn_impl = kwargs.get(\"attn_implementation\", None)\n        if config_attn_impl is None:\n            config_attn_impl = \"sdpa\" if supports_sdpa else \"eager\"\n        if auto_config is None:\n            auto_config = AutoConfig.from_pretrained(\n                model_name,\n                token = token,\n                trust_remote_code = trust_remote_code,\n            )\n        setattr(auto_config, \"_attn_implementation\", config_attn_impl)\n        if hasattr(auto_config, \"attn_implementation\"):\n            setattr(auto_config, \"attn_implementation\", config_attn_impl)\n        model_config = auto_config\n\n        verify_fp8_support_if_applicable(model_config)\n\n        raise_handler = RaiseUninitialized()\n        if not fast_inference:\n            # Prevent load_in_fp8 from being forwarded into HF internal model loading\n            load_in_fp8 = kwargs.pop(\"load_in_fp8\", None)\n            model = auto_model.from_pretrained(\n                model_name,\n                config = model_config,\n                device_map = device_map,\n                # torch_dtype           = torch_dtype, # Transformers removed torch_dtype\n                # quantization_config   = bnb_config,\n                token = token,\n                trust_remote_code = trust_remote_code,\n                # attn_implementation   = attn_implementation,\n                **kwargs,\n            )\n            if hasattr(model, \"generate\"):\n                model.fast_generate = make_fast_generate_wrapper(model.generate)\n                model.fast_generate_batches = error_out_no_vllm\n            if offload_embedding:\n                if bool(\n                    os.environ.get(\"WSL_DISTRO_NAME\") or os.environ.get(\"WSL_INTEROP\")\n                ):\n                    # WSL doesn't work with offloaded embeddings\n                    pass\n                elif os.name == \"nt\":\n                    # Windows doesn't work with offloaded embeddings\n                    pass\n                else:\n                    embed_tokens = model.get_input_embeddings()\n                    nbytes = embed_tokens.weight.numel() * embed_tokens.weight.itemsize\n                    ngb = round(nbytes / 1024 / 1024 / 1024, 2)\n                    print(f\"Unsloth: Offloading embeddings to RAM to save {ngb} GB.\")\n                    embed_tokens.to(\"cpu\")\n\n                    # Add hooks to move inputs to CPU and back to CUDA\n                    # [TODO] Doesn't seem to work!\n 
                   # def pre_hook(module, args):\n                    #     args[0]._old_device = args[0].device\n                    #     return (args[0].to(\"cpu\", non_blocking = True))\n                    # def post_hook(module, args, output):\n                    #     old_device = getattr(args[0], \"_old_device\", \"cuda\")\n                    #     return output.to(old_device, non_blocking = True)\n                    # embed_tokens.register_forward_pre_hook(pre_hook,  prepend = True)\n                    # embed_tokens.register_forward_hook    (post_hook, prepend = True)\n                    # Must free GPU memory otherwise will not free!\n                    torch.cuda.empty_cache()\n                    gc.collect()\n        else:\n            from unsloth_zoo.vllm_utils import (\n                load_vllm,\n                get_vllm_state_dict,\n                convert_vllm_to_huggingface,\n                generate_batches,\n                get_lora_supported_ranks,\n            )\n\n            if full_finetuning:\n                max_lora_rank = max(get_lora_supported_ranks())\n                raise NotImplementedError(\n                    \"Unsloth: `fast_inference=True` cannot be used together with `full_finetuning=True`.\\n\"\n                    \"Reason: fast_inference is optimized for inference-only workflows and \"\n                    \"does not currently support full fine-tuning.\\n\"\n                    \"Workaround: disable fast_inference, or use parameter-efficient fine-tuning \"\n                    f\"(e.g. LoRA with rank r={max_lora_rank}).\"\n                )\n\n            model_config.model_name = model_name\n\n            if fast_inference:\n                fast_inference, model_name = fast_inference_setup(\n                    model_name, model_config\n                )\n\n            fp8_mode = None\n            if load_in_fp8 != False:\n                fp8_mode = _get_fp8_mode_and_check_settings(\n                    load_in_fp8,\n                    fast_inference,\n                    full_finetuning,\n                    load_in_4bit,\n                    load_in_8bit,\n                    load_in_16bit,\n                )\n\n            allowed_args = inspect.getfullargspec(load_vllm).args\n            load_vllm_kwargs = dict(\n                model_name = model_name,\n                config = model_config,\n                gpu_memory_utilization = gpu_memory_utilization,\n                max_seq_length = max_seq_length,\n                dtype = dtype,\n                float8_kv_cache = float8_kv_cache,\n                enable_lora = vllm_enable_lora,\n                max_lora_rank = max_lora_rank,\n                disable_log_stats = disable_log_stats,\n                use_bitsandbytes = load_in_4bit,\n                unsloth_vllm_standby = unsloth_vllm_standby,\n                is_vision_model = is_vlm,\n                fp8_mode = fp8_mode,\n            )\n            for allowed_arg in allowed_args:\n                if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:\n                    load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]\n\n            # Load vLLM first\n            llm = load_vllm(**load_vllm_kwargs)\n\n            # Convert to HF format\n            _, quant_state_dict = get_vllm_state_dict(\n                llm,\n                config = model_config,\n                is_vision_model = is_vlm,\n                load_in_fp8 = load_in_fp8,\n            )\n            model = convert_vllm_to_huggingface(\n         
       quant_state_dict,\n                model_config,\n                dtype,\n                bnb_config,\n                is_vision_model = is_vlm,\n            )\n            model.vllm_engine = llm\n            model.fast_generate = model.vllm_engine.generate\n            model.fast_generate_batches = functools.partial(\n                generate_batches, model.vllm_engine\n            )\n\n        raise_handler.remove()\n\n        # Return old flag\n        os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = old_hf_transfer\n\n        # Check float32 norm weights\n        if os.environ.get(\"UNSLOTH_HIGH_PRECISION_LAYERNORM\", \"0\") == \"1\":\n            for jj, (name, module) in enumerate(model.named_modules()):\n                if (\n                    name.endswith((\"norm\", \"norm1\", \"norm2\", \"norm3\", \"norm4\"))\n                    or \"layernorm\" in name\n                    or \"layer_norm\" in name\n                ) and hasattr(module, \"weight\"):\n                    module._pre_set_compute_dtype = torch.float32\n        # Edit data-types\n        if custom_datatype is not None:\n            with torch.no_grad():\n                for jj, (name, module) in enumerate(model.named_modules()):\n                    exec(custom_datatype)\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE in (\"cuda\", \"hip\"):\n                torch.cuda.empty_cache()\n            elif DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n\n        # Counteract saved tokenizers\n        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name\n\n        # Fix _Unsloth_Patched_ prefix in local config files from old saves (issue #4085)\n        if os.path.isdir(tokenizer_name):\n            import json as _json\n\n            for _cfg_name in (\n                \"processor_config.json\",\n                \"preprocessor_config.json\",\n                \"tokenizer_config.json\",\n            ):\n                _cfg_path = os.path.join(tokenizer_name, _cfg_name)\n                if os.path.exists(_cfg_path):\n                    try:\n                        with open(_cfg_path, \"r\", encoding = \"utf-8\") as _f:\n                            _cfg = _json.load(_f)\n                        if _cfg.get(\"processor_class\", \"\").startswith(\n                            \"_Unsloth_Patched_\"\n                        ):\n                            _cfg[\"processor_class\"] = _cfg[\"processor_class\"][\n                                len(\"_Unsloth_Patched_\") :\n                            ]\n                            with open(_cfg_path, \"w\", encoding = \"utf-8\") as _f:\n                                _json.dump(_cfg, _f, indent = 2, ensure_ascii = False)\n                    except Exception:\n                        pass\n\n        if (whisper_language and whisper_task) or auto_model.__name__.endswith(\n            \"ForConditionalGeneration\"\n        ):\n            try:\n                tokenizer = auto_processor.from_pretrained(\n                    tokenizer_name,\n                    padding_side = \"left\",\n                    token = token,\n                    language = whisper_language,\n                    task = whisper_task,\n                    trust_remote_code = trust_remote_code,\n                )\n            except Exception:\n                tokenizer = None\n        else:\n            try:\n                tokenizer = auto_processor.from_pretrained(\n                    
tokenizer_name,\n                    padding_side = \"left\",\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n            except:\n                tokenizer = get_auto_processor(\n                    tokenizer_name,\n                    padding_side = \"left\",\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n\n        # If processor loading failed (e.g., tokenizer class not found),\n        # or if AutoProcessor silently degraded to a text-only tokenizer\n        # instead of returning a full VLM processor (issue #4085),\n        # try constructing the processor manually from separate components.\n        _processor_is_degraded = (\n            is_vlm\n            and tokenizer is not None\n            and not hasattr(tokenizer, \"image_processor\")\n        )\n        if (tokenizer is None or _processor_is_degraded) and is_vlm:\n            _fallback = _construct_vlm_processor_fallback(\n                tokenizer_name,\n                model_type_arch,\n                token,\n                trust_remote_code,\n            )\n            if _fallback is not None:\n                tokenizer = _fallback\n            if tokenizer is None:\n                import sys\n\n                print(\n                    f\"Unsloth: Warning - VLM processor fallback returned None for model_type={model_type_arch}\",\n                    file = sys.stderr,\n                )\n        if hasattr(tokenizer, \"tokenizer\"):\n            __tokenizer = tokenizer.tokenizer\n            # Add padding side as well\n            __tokenizer.padding_side = \"left\"\n            # Check bos, eos, pad tokens\n            if hasattr(__tokenizer, \"bos_token\"):\n                tokenizer.bos_token = __tokenizer.bos_token\n                tokenizer.bos_token_id = __tokenizer.bos_token_id\n            if hasattr(__tokenizer, \"eos_token\"):\n                tokenizer.eos_token = __tokenizer.eos_token\n                tokenizer.eos_token_id = __tokenizer.eos_token_id\n            if hasattr(__tokenizer, \"pad_token\"):\n                tokenizer.pad_token = __tokenizer.pad_token\n                tokenizer.pad_token_id = __tokenizer.pad_token_id\n        # Fix other stuff like BnB compute data types\n        model, tokenizer = patch_model_and_tokenizer(\n            model,\n            tokenizer,\n            downcast_rope = False,\n            fix_embeddings = False,\n            do_forced_float32 = do_forced_float32,\n            correct_dtype = correct_dtype,\n        )\n\n        try:\n            model, tokenizer = patch_tokenizer(model, tokenizer)\n        except Exception as _patch_err:\n            # Some VLM processors (e.g., ERNIE VL) may fail during tokenizer patching.\n            # Try loading tokenizer separately via AutoTokenizer as fallback.\n            try:\n                from transformers import AutoTokenizer as _AutoTokenizer\n\n                _fallback_tok = _AutoTokenizer.from_pretrained(\n                    tokenizer_name,\n                    padding_side = \"left\",\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n                model, _fallback_tok = patch_tokenizer(model, _fallback_tok)\n                # Re-attach as processor wrapper if original was a processor\n                if hasattr(tokenizer, \"image_processor\"):\n                    tokenizer.tokenizer = _fallback_tok\n          
      else:\n                    tokenizer = _fallback_tok\n            except Exception:\n                # If fallback also fails, raise the original error\n                raise _patch_err\n        model = post_patch_loss_function(model)\n\n        # Log Unsloth version for future fastpaths for inference\n        if hasattr(model, \"config\"):\n            model.config.update({\"unsloth_version\": __version__})\n        patch_saving_functions(model, vision = True)\n        if tokenizer is None:\n            # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast\n            try:\n                from transformers import AutoTokenizer as _AutoTokenizer\n\n                tokenizer = _AutoTokenizer.from_pretrained(\n                    tokenizer_name,\n                    padding_side = \"left\",\n                    token = token,\n                    trust_remote_code = trust_remote_code,\n                )\n            except Exception:\n                try:\n                    from transformers import PreTrainedTokenizerFast\n\n                    tokenizer = PreTrainedTokenizerFast.from_pretrained(\n                        tokenizer_name,\n                        padding_side = \"left\",\n                        token = token,\n                        trust_remote_code = trust_remote_code,\n                    )\n                except Exception:\n                    del model\n                    raise RuntimeError(\n                        \"Unsloth: The tokenizer is weirdly not loaded? Please check if there is one.\"\n                    )\n        patch_saving_functions(tokenizer, vision = True)\n\n        # Fix gradient accumulation\n        from transformers.trainer import Trainer\n\n        patch_gradient_accumulation_fix(Trainer)\n\n        # Save tokenizer for inference purposes\n        tokenizer.padding_side = \"left\"  # Force inference\n        if hasattr(tokenizer, \"tokenizer\"):\n            tokenizer.tokenizer.padding_side = \"left\"  # Force inference\n        m = model\n        while hasattr(m, \"model\"):\n            m.max_seq_length = max_seq_length\n            m._saved_temp_tokenizer = tokenizer\n            # Also set is_loaded_in_8bit to disable incorrect DDP\n            m.is_loaded_in_8bit = True if not full_finetuning else False\n            m = m.model\n        m.max_seq_length = max_seq_length\n        # Save to modules as well\n        for module in model.modules():\n            module.max_seq_length = max_seq_length\n        m._saved_temp_tokenizer = tokenizer\n        # Also set is_loaded_in_8bit to disable incorrect DDP\n        m.is_loaded_in_8bit = True if not full_finetuning else False\n\n        # Patch generate\n        if os.environ.get(\"UNSLOTH_DISABLE_FAST_GENERATION\", \"0\") == \"0\" and hasattr(\n            model, \"generate\"\n        ):\n            if model.generate.__name__ != \"unsloth_base_fast_generate\":\n                model._old_generate = model.generate\n                unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__\n                model.generate = types.MethodType(unsloth_base_fast_generate, model)\n        model._unsloth_trust_remote_code = trust_remote_code\n        # Post patches\n        model = FastBaseModel.post_patch_model(\n            model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n            trust_remote_code = trust_remote_code,\n            model_type = model_type_arch,\n            tokenizer = tokenizer,\n            
float32_mixed_precision = float32_mixed_precision,\n        )\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE in (\"cuda\", \"hip\"):\n                torch.cuda.empty_cache()\n            elif DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n        return model, tokenizer\n\n    @staticmethod\n    def get_peft_model(\n        model,\n        r = 16,\n        target_modules = None,\n        lora_alpha = 16,\n        lora_dropout = 0.0,\n        bias = \"none\",\n        finetune_vision_layers = True,\n        finetune_language_layers = True,\n        finetune_attention_modules = True,\n        finetune_mlp_modules = True,\n        layers_to_transform = None,\n        layers_pattern = None,\n        use_gradient_checkpointing = \"unsloth\",\n        random_state = 3407,\n        max_seq_length = 2048,  # not used anymore\n        use_rslora = False,\n        modules_to_save = None,\n        init_lora_weights = True,\n        loftq_config = {},\n        task_type = TaskType.CAUSAL_LM,\n        temporary_location = \"_unsloth_temporary_saved_buffers\",\n        qat_scheme = None,\n        target_parameters = None,  # For MoE expert layers (nn.Parameter)\n        ensure_weight_tying = False,  # [TODO] Add `ensure_weight_tying` for `modules_to_save` for vision models\n        **kwargs,\n    ):\n        if os.environ.get(\"UNSLOTH_ENABLE_FULL_FINETUNING\", \"0\") == \"1\":\n            print(\n                \"Unsloth: Full finetuning is enabled, so .get_peft_model has no effect\"\n            )\n            return model\n        transformers_set_seed(random_state)\n\n        if type(r) is not int:\n            raise TypeError(f\"Unsloth: Rank of {str(r)} must be an integer.\")\n        if r <= 0:\n            raise TypeError(f\"Unsloth: Rank of {str(r)} must be larger than 0.\")\n\n        if isinstance(model, PeftModelForCausalLM):\n            raise RuntimeError(\n                \"Unsloth: You already added LoRA adapters to your model!\"\n            )\n\n        if target_modules == \"all-linear\":\n            finetune_vision_layers = True\n            finetune_language_layers = True\n            finetune_attention_modules = True\n            finetune_mlp_modules = True\n        if target_modules is None or target_modules == \"all-linear\":\n            target_modules = get_peft_regex(\n                model,\n                finetune_vision_layers = finetune_vision_layers,\n                finetune_language_layers = finetune_language_layers,\n                finetune_attention_modules = finetune_attention_modules,\n                finetune_mlp_modules = finetune_mlp_modules,\n            )\n        else:\n            assert type(target_modules) in (\n                list,\n                tuple,\n                str,\n            )\n\n        if hasattr(model, \"vllm_engine\"):\n            if (\n                hasattr(model.vllm_engine, \"llm_engine\")\n                and hasattr(model.vllm_engine.llm_engine, \"vllm_config\")\n                and getattr(\n                    model.vllm_engine.llm_engine.vllm_config, \"lora_config\", None\n                )\n                is None\n            ):\n                # If vLLM is being used but lora is not enabled, throw an error\n                # Ref https://github.com/vllm-project/vllm/blob/51ba839555a5d122eadd91e9c16463ac288f5fa1/vllm/v1/engine/processor.py#L148-L151\n                raise RuntimeError(\"Unsloth: LoRA is not enabled for this 
model!\")\n            if finetune_vision_layers:\n                # vLLM does not support LoRA on vision layers\n                # https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L471-L477\n                # TODO: Update this once vLLM V1 supports LoRA on vision layers (possibly not happening)\n                raise RuntimeError(\n                    \"Unsloth: Finetuning vision layers is not supported for fast_inference. Only text layers are supported!\"\n                )\n            if model.config.model_type in VLLM_NON_LORA_VLM:\n                # mllama is still only in vllm v0 https://arc.net/l/quote/llwkfgmu\n                # https://docs.vllm.ai/en/stable/models/supported_models.html#text-generation_1\n                # vLLM V0 does not support LoRA on multi modal models.\n                # TODO: Update this once vLLM V1 supports Llama 3.2 aka mllama\n                raise RuntimeError(\n                    \"Unsloth: LoRA finetuning for Llama 3.2 aka mllama models is not supported with fast_inference!\"\n                )\n\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE in (\"cuda\", \"hip\"):\n                torch.cuda.empty_cache()\n            elif DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n        max_seq_length = model.max_seq_length\n        # If we pass loftq_config = None we will get an error\n        loftq_config = validate_loftq_config(\n            loftq_config, lora_dropout, bias, init_lora_weights, model\n        )\n\n        # Auto-detect MoE models and populate target_parameters for expert layers\n        if target_parameters is None:\n            target_parameters = get_moe_target_parameters(model, target_modules)\n\n        # Get only allowed parameters for LoraConfig\n        local_variables = {\n            **locals(),\n            **kwargs,\n        }\n        del local_variables[\"kwargs\"]\n        allowed_parameters = inspect.signature(LoraConfig).parameters.keys()\n        lora_config = LoraConfig(\n            **{k: v for k, v in local_variables.items() if k in allowed_parameters},\n        )\n        model = prepare_model_for_kbit_training(\n            model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n        )\n        model = _get_peft_model(model, lora_config)\n        # Apply QAT + LoRA if specified\n        if qat_scheme is not None:\n            print(\"Unsloth: Applying QAT to mitigate quantization degradation\")\n            model = _prepare_model_for_qat(model, qat_scheme)\n        # Fix LoraConfig.auto_mapping is None\n        fix_lora_auto_mapping(model)\n        # Enable gradients on modules which are trainable\n        requires_grad_for_gradient_checkpointing(model)\n        trust_remote_code = getattr(model, \"_unsloth_trust_remote_code\", False)\n        model = FastBaseModel.post_patch_model(\n            model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n            trust_remote_code = trust_remote_code,\n        )\n        model.max_seq_length = max_seq_length\n        # Save to modules as well\n        for module in model.modules():\n            module.max_seq_length = max_seq_length\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE in (\"cuda\", \"hip\"):\n                torch.cuda.empty_cache()\n            elif DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n        
patch_saving_functions(model, vision = True)\n        patch_peft_fast_inference(model)\n\n        # Add for_inference and for_training\n        model.for_training = functools.partial(FastBaseModel.for_training, model)\n        model.for_inference = functools.partial(FastBaseModel.for_inference, model)\n        m = model\n        while hasattr(m, \"model\"):\n            m.for_training = functools.partial(FastBaseModel.for_training, m)\n            m.for_inference = functools.partial(FastBaseModel.for_inference, m)\n            m = m.model\n        return model\n\n    @staticmethod\n    def post_patch_model(\n        model,\n        use_gradient_checkpointing = True,\n        trust_remote_code = False,\n        model_type = None,\n        tokenizer = None,\n        float32_mixed_precision = None,\n    ):\n        full_finetuning = os.environ.get(\"UNSLOTH_ENABLE_FULL_FINETUNING\", \"0\") == \"1\"\n\n        if type(float32_mixed_precision) is bool:\n            # Respect whatever it was set before\n            pass\n        else:\n            float32_mixed_precision = True\n            if (\n                _get_dtype(dtype_from_config(model.config)) == torch.bfloat16\n                and full_finetuning\n            ):\n                # Use bfloat16 precision for full finetuning\n                float32_mixed_precision = False\n\n        # VLMs can hit DDP \"marked ready twice\" with re-entrant checkpointing.\n        # See: https://github.com/unslothai/unsloth/issues/3713.\n        use_reentrant = not is_distributed()\n        if not use_reentrant:\n            # Under DDP, avoid the offloaded/re-entrant checkpoint patch.\n            unpatch_unsloth_gradient_checkpointing()\n            unpatch_unsloth_smart_gradient_checkpointing()\n            # Force native checkpoint to default to non-reentrant for downstream calls.\n            _orig_checkpoint = torch_checkpoint.checkpoint\n\n            def _nonre_checkpoint(function, *args, **kwargs):\n                kwargs[\"use_reentrant\"] = False\n                return _orig_checkpoint(function, *args, **kwargs)\n\n            torch_checkpoint.checkpoint = _nonre_checkpoint\n            hf_modeling_utils.checkpoint = _nonre_checkpoint\n\n        model = prepare_model_for_training(\n            model,\n            use_gradient_checkpointing = use_gradient_checkpointing,\n            use_reentrant = use_reentrant,\n            full_finetuning = full_finetuning,\n            train_layernorms = full_finetuning,\n            train_embedding = full_finetuning,\n            train_lm_head = full_finetuning,\n            float32_mixed_precision = float32_mixed_precision,\n            patch_modules_to_save = True,\n        )\n\n        from transformers.trainer import Trainer\n\n        if (\n            Trainer._inner_training_loop.__name__ != \"_fast_inner_training_loop\"\n            and trust_remote_code == False\n        ):\n            raise RuntimeError(\"Unsloth: Unsuccessfully patched inner_training_loop\")\n        patch_saving_functions(model, vision = True)\n\n        # Patch tokenizer to pad to the left\n        m = model\n        while hasattr(m, \"model\"):\n            if hasattr(m, \"_saved_temp_tokenizer\"):\n                if hasattr(m._saved_temp_tokenizer, \"tokenizer\"):\n                    m._saved_temp_tokenizer.tokenizer.padding_side = \"left\"\n            # Also set is_loaded_in_8bit to disable incorrect DDP\n            m.is_loaded_in_8bit = True if not full_finetuning else False\n            m = m.model\n        if 
hasattr(m, \"_saved_temp_tokenizer\"):\n            if hasattr(m._saved_temp_tokenizer, \"tokenizer\"):\n                m._saved_temp_tokenizer.tokenizer.padding_side = \"left\"\n        # Also set is_loaded_in_8bit to disable incorrect DDP\n        m.is_loaded_in_8bit = True if not full_finetuning else False\n\n        # Clear deleted GPU items\n        for _ in range(3):\n            gc.collect()\n            if DEVICE_TYPE in (\"cuda\", \"hip\"):\n                torch.cuda.empty_cache()\n            elif DEVICE_TYPE == \"xpu\":\n                torch.xpu.empty_cache()\n        # Add for_inference and for_training\n        model.for_training = functools.partial(FastBaseModel.for_training, model)\n        model.for_inference = functools.partial(FastBaseModel.for_inference, model)\n        m = model\n        while hasattr(m, \"model\"):\n            m.for_training = functools.partial(FastBaseModel.for_training, m)\n            m.for_inference = functools.partial(FastBaseModel.for_inference, m)\n            m = m.model\n        # Set weight[padding_idx] = 0 for embeddings that are NOT tied with the\n        # lm_head. When weights are tied, zeroing the padding row also zeros\n        # the corresponding lm_head row, forcing logit = 0 for the pad token.\n        # Only do this if tokenizer is defined since eos_token == pad_token sometimes!\n        pad_token_id = getattr(tokenizer, \"pad_token_id\", None)\n        lm_head = getattr(model, \"lm_head\", None)\n        lm_head_weight = (\n            getattr(lm_head, \"weight\", None) if lm_head is not None else None\n        )\n        if (\n            tokenizer is not None\n            and getattr(tokenizer, \"eos_token_id\", None) != pad_token_id\n        ):\n            with torch.no_grad():\n                for name, module in model.named_modules():\n                    if type(module) is torch.nn.Embedding:\n                        if (\n                            getattr(module, \"weight\", None) is not None\n                            and getattr(module, \"padding_idx\", None) is not None\n                        ):\n                            if (\n                                module.padding_idx == pad_token_id\n                                and module.padding_idx < module.weight.shape[0]\n                            ):\n                                # Skip if tied to lm_head\n                                if (\n                                    lm_head_weight is not None\n                                    and module.weight.data_ptr()\n                                    == lm_head_weight.data_ptr()\n                                ):\n                                    continue\n                                module.weight[module.padding_idx] = 0\n        return model\n\n    @staticmethod\n    def for_inference(model):\n        if not hasattr(model, \"parameters\"):\n            raise TypeError(\n                \"Unsloth: I think you're passing a tokenizer, not the model to for_inference!\"\n            )\n\n        def _for_inference(m):\n            if hasattr(m, \"gradient_checkpointing\"):\n                m.gradient_checkpointing = False\n            if hasattr(m, \"training\"):\n                m.training = False\n            # Pad tokenizer to the left\n            if hasattr(m, \"_saved_temp_tokenizer\"):\n                m._saved_temp_tokenizer.padding_side = \"left\"\n            # Set a flag for generation!\n            m._flag_for_generation = True\n\n        m = model\n        while hasattr(m, 
\"model\"):\n            _for_inference(m)\n            m = m.model\n        _for_inference(m)\n        model.eval()  # to turn off training on modules deeper in\n\n        # Since transformers 4.53, must turn off explicitly\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing\"):\n                module.gradient_checkpointing = False\n\n        # Also disable training for embeddings for NEFTune\n        if hasattr(model, \"get_input_embeddings\"):\n            embeddings = model.get_input_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = False\n        if hasattr(model, \"get_output_embeddings\"):\n            embeddings = model.get_output_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = False\n        # Must disable returning hidden states in the case for GRPO\n        os.environ[\"UNSLOTH_RETURN_HIDDEN_STATES\"] = \"0\"\n        # Must enable returning logits\n        os.environ[\"UNSLOTH_RETURN_LOGITS\"] = \"1\"\n        # Turn off skip guards and set stance to default\n        if torch_compiler_set_stance is not None:\n            torch_compiler_set_stance(stance = \"default\", skip_guard_eval_unsafe = False)\n        return model\n\n    @staticmethod\n    def for_training(model, use_gradient_checkpointing = True):\n        if not hasattr(model, \"parameters\"):\n            raise TypeError(\n                \"Unsloth: I think you're passing a tokenizer, not the model to for_training!\"\n            )\n\n        # Delete all fast inference loras\n        for param in model.parameters():\n            if hasattr(param, \"_fast_lora\"):\n                del param._fast_lora\n\n        def _for_training(m):\n            if hasattr(m, \"gradient_checkpointing\"):\n                m.gradient_checkpointing = use_gradient_checkpointing\n            if hasattr(m, \"training\"):\n                m.training = True\n            # Pad tokenizer to the left\n            if hasattr(m, \"_saved_temp_tokenizer\"):\n                m._saved_temp_tokenizer.padding_side = \"right\"\n            # Set a flag for generation!\n            if hasattr(m, \"_flag_for_generation\"):\n                try:\n                    # Weirdly sometimes cannot succeed so do a try except\n                    del m._flag_for_generation\n                except:\n                    pass\n\n        m = model\n        while hasattr(m, \"model\"):\n            _for_training(m)\n            m = m.model\n        _for_training(m)\n        model.train()  # to turn on training on modules deeper in\n\n        # Since transformers 4.53, must turn on explicitly\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing\"):\n                module.gradient_checkpointing = use_gradient_checkpointing\n\n        # Also re-enable training for embeddings for NEFTune\n        if hasattr(model, \"get_input_embeddings\"):\n            embeddings = model.get_input_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = True\n        if hasattr(model, \"get_output_embeddings\"):\n            embeddings = model.get_output_embeddings()\n            if hasattr(embeddings, \"training\"):\n                embeddings.training = True\n        # Can re-enable not returning logits\n        os.environ[\"UNSLOTH_RETURN_LOGITS\"] = \"0\"\n        # Turn off skip guards and set stance to default\n        if 
torch_compiler_set_stance is not None:\n            torch_compiler_set_stance(stance = \"default\", skip_guard_eval_unsafe = False)\n        return model\n"
  },
  {
    "path": "unsloth/ollama_template_mappers.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\n    \"OLLAMA_TEMPLATES\",\n    \"OLLAMA_TEMPLATE_TO_MODEL_MAPPER\",\n    \"MODEL_TO_OLLAMA_TEMPLATE_MAPPER\",\n]\n\nOLLAMA_TEMPLATES = {}\n\n# =========================================== Unsloth\n\nunsloth_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}{{ .System }}\n{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}\n{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}\n\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\nSYSTEM \"\"\"You are a helpful assistant to the user\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"unsloth\"] = unsloth_ollama\n\n# =========================================== Zephyr\n\nzephyr_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|system|>\n{{ .System }}{__EOS_TOKEN__}\n{{ end }}{{ if .Prompt }}<|user|>\n{{ .Prompt }}{__EOS_TOKEN__}\n{{ end }}<|assistant|>\n{{ .Response }}{__EOS_TOKEN__}\n\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"zephyr\"] = zephyr_ollama\n\n# =========================================== ChatML\nchatml_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ .Response }}<|im_end|>\n\"\"\"\nPARAMETER stop \"<|im_start|>\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"chatml\"] = chatml_ollama\n\n# =========================================== Mistral-1\n# Ollama from https://www.ollama.com/library/mistral\n# Mistral v0.1 https://ollama.com/library/mistral:v0.1/blobs/22e1b2e8dc2f\n# Mistral v0.2 https://ollama.com/library/mistral:v0.2/blobs/e6836092461f\nmistral_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]\"\"\"\nPARAMETER stop \"[INST]\"\nPARAMETER stop \"[/INST]\"\n'''\n\n# mistral:v0.3 https://ollama.com/library/mistral:v0.3/blobs/1ff5b64b61b9\n# mistral-large https://ollama.com/library/mistral-large:latest/blobs/96adabcf2c08\nmistral_v03_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if .Messages }}\n{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"user\" }}\n{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}\n\n{{ end }}{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}[TOOL_CALLS] [\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}]\n{{- end }}</s>\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS] {\"content\": {{ .Content }}} 
[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\n{{- else }}[INST] {{ if .System }}{{ .System }}\n\n{{ end }}{{ .Prompt }}[/INST]\n{{- end }}{{ .Response }}\n{{- if .Response }}</s>\n{{- end }}\"\"\"\nPARAMETER stop \"[INST]\"\nPARAMETER stop \"[/INST]\"\nPARAMETER stop \"</s>\"\n'''\n\n# Mistral-small https://ollama.com/library/mistral-small:latest/blobs/6db27cd4e277\nmistral_small_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"system\" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]\n{{- else if eq .Role \"user\" }}\n{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if .Content }}{{ .Content }}\n{{- if not (eq (len (slice $.Messages $index)) 1) }}</s>\n{{- end }}\n{{- else if .ToolCalls }}[TOOL_CALLS][\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}]</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]{\"content\": {{ .Content }}}[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER temperature 0.15\nSYSTEM \"\"\"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. When you're not sure about some information, you say that you don't have the information and don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")\"\"\"\n'''\n\n# mistral-small-3.1 https://ollama.com/library/mistral-small3.1:latest/blobs/6db27cd4e277\nmistral_small_31_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"system\" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]\n{{- else if eq .Role \"user\" }}\n{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if .Content }}{{ .Content }}\n{{- if not (eq (len (slice $.Messages $index)) 1) }}</s>\n{{- end }}\n{{- else if .ToolCalls }}[TOOL_CALLS][\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}]</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]{\"content\": {{ .Content }}}[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER num_ctx 4096\nSYSTEM \"\"\"You are Mistral Small 3.1, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYou power an AI assistant called Le Chat.\nYour knowledge base was last updated on 2023-10-01.\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. 
\"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").\nYou are always very attentive to dates, in particular you try to resolve dates (e.g. \"yesterday\" is {yesterday}) and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n# WEB BROWSING INSTRUCTIONS\n\nYou cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.\nYou cannot read nor transcribe audio files or videos.\"\"\"\n'''\n\n# mistral-small-3.2 https://ollama.com/library/mistral-small3.2:latest/blobs/706c4d1164f7\nmistral_small_32_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"system\" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]\n{{- else if eq .Role \"user\" }}\n{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if .Content }}{{ .Content }}\n{{- if not (eq (len (slice $.Messages $index)) 1) }}</s>\n{{- end }}\n{{- else if .ToolCalls }}\n{{- range $i, $_ := .ToolCalls }}[TOOL_CALLS]{{ .Function.Name }}[CALL_ID]{{ $i }}[ARGS]{{ .Function.Arguments }}\n{{- end }}</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]{\"content\": {{ .Content }}}[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER temperature 0.15\nSYSTEM \"\"\"You are Mistral Small 3.2, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYou power an AI assistant called Le Chat.\nYour knowledge base was last updated on 2023-10-01.\n\nWhen you're not sure about some information or when the user's request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don't have the information and avoid making up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").\nYou are always very attentive to dates, in particular you try to resolve dates and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n# WEB BROWSING INSTRUCTIONS\n\nYou cannot perform any web search or access internet to open URLs, links etc. 
If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.\nYou cannot read nor transcribe audio files or videos.\n\nTOOL CALLING INSTRUCTIONS\n\nYou may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:\n\n1. When the request requires up-to-date information.\n2. When the request requires specific data that you do not have in your knowledge base.\n3. When the request involves actions that you cannot perform without tools.\n\nAlways prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment.\"\"\"\n'''\n\n\n# https://ollama.com/library/mixtral:latest/blobs/53d74de0d84c\nmixtral_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}\"\"\"\nPARAMETER stop \"[INST]\"\nPARAMETER stop \"[/INST]\"\n'''\n\n# https://registry.ollama.ai/library/mistral-nemo:latest/blobs/438402ddac75\nmistral_nemo_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- range $i, $_ := .Messages }}\n{{- if eq .Role \"user\" }}\n{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ if and $.System (eq (len (slice $.Messages $i)) 1) }}{{ $.System }}\n\n{{ end }}{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if .Content }} {{ .Content }}{{ if not (eq (len (slice $.Messages $i)) 1) }}</s>{{ end }}\n{{- else if .ToolCalls }}[TOOL_CALLS][\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}]</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]{\"content\": {{ .Content }}}[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER stop \"[INST]\"\nPARAMETER stop \"[/INST]\"\n'''\n\n# https://ollama.com/library/codestral:latest/blobs/51707752a87c\ncodestral_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- if .Suffix }}[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}\n{{- else if .Messages }}\n{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"user\" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}\n\n{{ end }}{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }} {{ .Content }}</s>\n{{- end }}\n{{- end }}\n{{- else }}[INST] {{ if .System }}{{ .System }}\n\n{{ end }}{{ .Prompt }} [/INST]\n{{- end }} {{ .Response }}\n{{- if .Response }}</s>\n{{- end }}\n\"\"\"\nPARAMETER stop \"[INST]\"\nPARAMETER stop \"[/INST]\"\nPARAMETER stop \"[PREFIX]\"\nPARAMETER stop \"[MIDDLE]\"\nPARAMETER stop \"[SUFFIX]\"\n'''\n\n# https://ollama.com/library/devstral:latest/blobs/ea9ec42474e0\ndevstral_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- $lastUserIndex := -1 }}\n{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"user\" }}{{ $lastUserIndex = $index }}{{ end }}\n{{- end }}\n{{- range $index, $_ := .Messages }}\n{{- if eq .Role \"system\" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]\n{{- else if eq .Role \"user\" }}\n{{- if and (eq $lastUserIndex $index) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if 
.Content }}{{ .Content }}\n{{- if not (eq (len (slice $.Messages $index)) 1) }}</s>\n{{- end }}\n{{- else if .ToolCalls }}[TOOL_CALLS][\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}]</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]{\"content\": {{ .Content }}}[/TOOL_RESULTS]\n{{- end }}\n{{- end }}\"\"\"\nSYSTEM \"\"\"You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. You can interact with a computer to solve tasks.\n\n<ROLE>\nYour primary role is to assist users by executing commands, modifying code, and solving technical problems effectively. You should be thorough, methodical, and prioritize quality over speed.\n* If the user asks a question, like \"why is X happening\", don't try to fix the problem. Just give an answer to the question.\n</ROLE>\n\n<EFFICIENCY>\n* Each action you take is somewhat expensive. Wherever possible, combine multiple actions into a single action, e.g. combine multiple bash commands into one, using sed and grep to edit/view multiple files at once.\n* When exploring the codebase, use efficient tools like find, grep, and git commands with appropriate filters to minimize unnecessary operations.\n</EFFICIENCY>\n\n<FILE_SYSTEM_GUIDELINES>\n* When a user provides a file path, do NOT assume it's relative to the current working directory. First explore the file system to locate the file before working on it.\n* If asked to edit a file, edit the file directly, rather than creating a new file with a different filename.\n* For global search-and-replace operations, consider using `sed` instead of opening file editors multiple times.\n</FILE_SYSTEM_GUIDELINES>\n\n<CODE_QUALITY>\n* Write clean, efficient code with minimal comments. Avoid redundancy in comments: Do not repeat information that can be easily inferred from the code itself.\n* When implementing solutions, focus on making the minimal changes needed to solve the problem.\n* Before implementing any changes, first thoroughly understand the codebase through exploration.\n* If you are adding a lot of code to a function or file, consider splitting the function or file into smaller pieces when appropriate.\n</CODE_QUALITY>\n\n<VERSION_CONTROL>\n* When configuring git credentials, use \"openhands\" as the user.name and \"openhands@all-hands.dev\" as the user.email by default, unless explicitly instructed otherwise.\n* Exercise caution with git operations. Do NOT make potentially dangerous changes (e.g., pushing to main, deleting repositories) unless explicitly asked to do so.\n* When committing changes, use `git status` to see all modified files, and stage all files necessary for the commit. Use `git commit -a` whenever possible.\n* Do NOT commit files that typically shouldn't go into version control (e.g., node_modules/, .env files, build directories, cache files, large binaries) unless explicitly instructed by the user.\n* If unsure about committing certain files, check for the presence of .gitignore files or ask the user for clarification.\n</VERSION_CONTROL>\n\n<PULL_REQUESTS>\n* When creating pull requests, create only ONE per session/issue unless explicitly instructed otherwise.\n* When working with an existing PR, update it with new commits rather than creating additional PRs for the same issue.\n* When updating a PR, preserve the original PR title and purpose, updating description only when necessary.\n</PULL_REQUESTS>\n\n<PROBLEM_SOLVING_WORKFLOW>\n1. 
EXPLORATION: Thoroughly explore relevant files and understand the context before proposing solutions\n2. ANALYSIS: Consider multiple approaches and select the most promising one\n3. TESTING:\n   * For bug fixes: Create tests to verify issues before implementing fixes\n   * For new features: Consider test-driven development when appropriate\n   * If the repository lacks testing infrastructure and implementing tests would require extensive setup, consult with the user before investing time in building testing infrastructure\n   * If the environment is not set up to run tests, consult with the user first before investing time to install all dependencies\n4. IMPLEMENTATION: Make focused, minimal changes to address the problem\n5. VERIFICATION: If the environment is set up to run tests, test your implementation thoroughly, including edge cases. If the environment is not set up to run tests, consult with the user first before investing time to run tests.\n</PROBLEM_SOLVING_WORKFLOW>\n\n<SECURITY>\n* Only use GITHUB_TOKEN and other credentials in ways the user has explicitly requested and would expect.\n* Use APIs to work with GitHub or other platforms, unless the user asks otherwise or your task requires browsing.\n</SECURITY>\n\n<ENVIRONMENT_SETUP>\n* When user asks you to run an application, don't stop if the application is not installed. Instead, please install the application and run the command again.\n* If you encounter missing dependencies:\n  1. First, look around in the repository for existing dependency files (requirements.txt, pyproject.toml, package.json, Gemfile, etc.)\n  2. If dependency files exist, use them to install all dependencies at once (e.g., `pip install -r requirements.txt`, `npm install`, etc.)\n  3. Only install individual packages directly if no dependency files are found or if only specific packages are needed\n* Similarly, if you encounter missing dependencies for essential tools requested by the user, install them when possible.\n</ENVIRONMENT_SETUP>\n\n<TROUBLESHOOTING>\n* If you've made repeated attempts to solve a problem but tests still fail or the user reports it's still broken:\n  1. Step back and reflect on 5-7 different possible sources of the problem\n  2. Assess the likelihood of each possible cause\n  3. Methodically address the most likely causes, starting with the highest probability\n  4. Document your reasoning process\n* When you run into any major issue while executing a plan from the user, please don't try to directly work around it. 
Instead, propose a new plan and confirm with the user before proceeding.\n</TROUBLESHOOTING>\"\"\"\n'''\n\n# https://ollama.com/library/magistral:latest/blobs/35f7a1efc383\nmagistral_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1}}\n{{- if eq .Role \"system\" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]\n{{- else if eq .Role \"user\" }}\n{{- if and (le (len (slice $.Messages $i)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]\n{{- end }}[INST]{{ .Content }}[/INST]\n{{- else if eq .Role \"assistant\" }}\n{{- if and $.IsThinkSet (and $last .Thinking) -}}\n<think>\n{{ .Thinking }}\n</think>\n{{ end }}\n{{- if .Content }}{{ .Content }}\n{{- end }}\n{{- if .ToolCalls }}{{ range $i, $_ := .ToolCalls }}[TOOL_CALLS]{{ .Function.Name }}[CALL_ID]{{ $i }}[ARGS]{{ .Function.Arguments }}{{ end }}\n{{- end }}\n{{- if not (eq (len (slice $.Messages $i)) 1) }}</s>\n{{- end }}\n{{- else if eq .Role \"tool\" }}[TOOL_RESULTS]0[TOOL_CONTENT]{{ .Content }}[/TOOL_RESULTS]\n{{- end }}\n{{- if and $last (ne .Role \"assistant\") }}{{ if and $.IsThinkSet (not $.Think) -}}<think>\n</think>\n{{ end }}\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER temperature 0.7\nPARAMETER top_p 0.95\nSYSTEM \"\"\"A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown and Latex to format your response. Write both your thoughts and summary in the same language as the task posed by the user.\n\nYour thinking process must follow the template below:\n<think>\nYour thoughts or/and draft, like working through an exercise on scratch paper. 
Be as casual and as long as you want until you are confident to generate a correct answer.\n</think>\n\nHere, provide a concise summary that reflects your reasoning and presents a clear final answer to the user.\n\nProblem:\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"mistral\"] = mistral_ollama\nOLLAMA_TEMPLATES[\"mistral-v03\"] = mistral_v03_ollama\nOLLAMA_TEMPLATES[\"mistral-small\"] = mistral_small_ollama\nOLLAMA_TEMPLATES[\"mistral-small-31\"] = mistral_small_31_ollama\nOLLAMA_TEMPLATES[\"mistral-small-32\"] = mistral_small_32_ollama\nOLLAMA_TEMPLATES[\"mixtral\"] = mixtral_ollama\nOLLAMA_TEMPLATES[\"mistral-nemo\"] = mistral_nemo_ollama\nOLLAMA_TEMPLATES[\"devstral\"] = devstral_ollama\nOLLAMA_TEMPLATES[\"magistral\"] = magistral_ollama\nOLLAMA_TEMPLATES[\"codestral\"] = codestral_ollama\n\n\n# =========================================== Llama-2\n# Ollama from https://www.ollama.com/library/llama2\nllama_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"[INST] <<SYS>>{{ .System }}<</SYS>>\n\n{{ .Prompt }} [/INST]\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"llama\"] = llama_ollama\n\n# ===========================================  Vicuna\n# Ollama from https://www.ollama.com/library/vicuna\nvicuna_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"vicuna\"] = vicuna_ollama\n\n# =========================================== Vicuna Old\nvicuna_old_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}{{ .System }}\n{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}\n{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}\n\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\nSYSTEM \"\"\"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"vicuna_old\"] = vicuna_old_ollama\nOLLAMA_TEMPLATES[\"vicuna old\"] = OLLAMA_TEMPLATES[\"vicuna_old\"]\n\n# =========================================== Alpaca multi turn\nalpaca_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}{{ .System }}\n\n{{ end }}{{ if .Prompt }}### Instruction:\n{{ .Prompt }}{{ end }}\n\n### Response:\n{{ .Response }}{__EOS_TOKEN__}\n\n\"\"\"\nPARAMETER stop \"{__EOS_TOKEN__}\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\nSYSTEM \"\"\"Below are some instructions that describe some tasks. 
Write responses that appropriately complete each request.\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"alpaca\"] = alpaca_ollama\n\n# =========================================== Gemma\n# Ollama from https://www.ollama.com/library/gemma\ngemma_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"<start_of_turn>user\n{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>\n<start_of_turn>model\n{{ .Response }}<end_of_turn>\n\"\"\"\nPARAMETER repeat_penalty 1\nPARAMETER stop \"<start_of_turn>\"\nPARAMETER stop \"<end_of_turn>\"\nPARAMETER penalize_newline false\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"gemma\"] = gemma_ollama\n\n# =========================================== Gemma with ChatML instead\ngemma_chatml_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ .Response }}<|im_end|>\n\"\"\"\nPARAMETER repeat_penalty 1\nPARAMETER stop \"<|im_start|>\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER penalize_newline false\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"gemma_chatml\"] = gemma_chatml_ollama\n\n# =========================================== Gemma 2\n# Same as Gemma 1, but with sliding window attention!\n# https://ollama.com/library/gemma2/blobs/6522ca797f47\ngemma2_ollama = gemma_ollama + \"PARAMETER num_ctx 4096\\n\"\nOLLAMA_TEMPLATES[\"gemma2\"] = gemma2_ollama\n\n# =========================================== Gemma 2 with ChatML instead\ngemma2_chatml_ollama = gemma_chatml_ollama + \"PARAMETER num_ctx 4096\\n\"\nOLLAMA_TEMPLATES[\"gemma2_chatml\"] = gemma2_chatml_ollama\n\n# =========================================== Llama-3\n# Ollama from https://www.ollama.com/library/llama3\nllama3_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>\"\"\"\nPARAMETER num_keep 24\nPARAMETER stop \"<|start_header_id|>\"\nPARAMETER stop \"<|end_header_id|>\"\nPARAMETER stop \"<|eot_id|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"llama-3\"] = llama3_ollama\nOLLAMA_TEMPLATES[\"llama3\"] = llama3_ollama\n\n\n# =========================================== Phi-3\n# Ollama from https://www.ollama.com/library/phi3\nphi3_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|system|>\n{{ .System }}<|end|>\n{{ end }}{{ if .Prompt }}<|user|>\n{{ .Prompt }}<|end|>\n{{ end }}<|assistant|>\n{{ .Response }}<|end|>\n\"\"\"\nPARAMETER stop \"<|end|>\"\nPARAMETER stop \"<|user|>\"\nPARAMETER stop \"<|assistant|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"phi-3\"] = phi3_ollama\nOLLAMA_TEMPLATES[\"phi-35\"] = OLLAMA_TEMPLATES[\"phi-3\"]\nOLLAMA_TEMPLATES[\"phi-3.5\"] = OLLAMA_TEMPLATES[\"phi-3\"]\n\n# =========================================== Llama-3.1\n\"\"\"\nNo trimming in Llama 3.1 Instruct!\nAlso an extra newline for Cutting Knowledge Date\nSee https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing\n\nThe chat template call should also be\n\nfrom datetime import datetime\ntokenizer.apply_chat_template(\n    messages,\n    add_generation_prompt = True,\n    tokenize = False,\n    date_string = datetime.today().strftime(\"%d %B %Y\"),\n)\n\"\"\"\n\n# 
Ollama from https://ollama.com/library/llama3.1 (needs updating!)\nllama31_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .Messages }}\n{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>\n{{- if .System }}\n\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\nYou are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original use question.\n{{- end }}\n{{- end }}<|eot_id|>\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 }}\n{{- if eq .Role \"user\" }}<|start_header_id|>user<|end_header_id|>\n{{- if and $.Tools $last }}\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}. Do not use variables.\n\n{{ $.Tools }}\n{{- end }}\n\n{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}\n{{- else if eq .Role \"assistant\" }}<|start_header_id|>assistant<|end_header_id|>\n{{- if .ToolCalls }}\n\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"parameters\": {{ .Function.Arguments }}}{{ end }}\n{{- else }}\n\n{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}\n{{- end }}\n{{- else if eq .Role \"tool\" }}<|start_header_id|>ipython<|end_header_id|>\n\n{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}\n{{- end }}\n{{- end }}\n{{- else }}\n{{- if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}\"\"\"\nPARAMETER stop \"<|start_header_id|>\"\nPARAMETER stop \"<|end_header_id|>\"\nPARAMETER stop \"<|eot_id|>\"\nPARAMETER stop \"<|eom_id|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\n# https://ollama.com/ajindal/llama3.1-storm:8b/blobs/1970553b62f4\nllama_31_storm_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{ if .Messages }}\n{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>\n{{- if .System }}\n\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\nYou are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. 
The user may use the terms function calling or tool use interchangeably.\n\nHere are the available functions:\n<tools>{{ json .Tools }}</tools>\n\nFor each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags in the format:\n<tool_call>{\"tool_name\": <function-name>, \"tool_arguments\": <args-dict>}</tool_call>\n{{- end }}\n{{- end }}<|eot_id|>\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 }}\n{{- if eq .Role \"user\" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>\n{{ end }}\n{{- else if eq .Role \"assistant\" }}<|start_header_id|>assistant<|end_header_id|>\n{{- if .ToolCalls }}\n\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"parameters\": {{ .Function.Arguments }}}{{ end }}\n{{- else }}\n\n{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}\n{{- end }}\n{{- else if eq .Role \"tool\" }}<|start_header_id|>ipython<|end_header_id|>\n\n{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>\n{{ end }}\n{{- end }}\n{{- end }}\n{{- else }}\n{{- if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}\n\"\"\"\nPARAMETER stop \"<|start_header_id|>\"\nPARAMETER stop \"<|end_header_id|>\"\nPARAMETER stop \"<|eot_id|>\"\n'''\n\n# https://ollama.com/library/nemotron:latest/blobs/4863fe3335f3\nllama_31_nemotron_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"<|start_header_id|>system<|end_header_id|>\n\n{{ if .Tools }}You have access to the following functions. To call a function, please respond with JSON for a function call. Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}. Do not use variables.\n\n{{ range .Tools }}{{ . 
}}\n\n{{ end }}\n{{- end }}{{ .System }}<|eot_id|>\n{{- range $i, $_ := .Messages }}\n{{- $isLastMessage := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"system\" }}\n{{- else if eq .Role \"assistant\" }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"parameters\": {{ .Function.Arguments }} }\n{{- end }}\n{{- end }}\n{{- if not $isLastMessage }}<|eot_id|>\n{{- end }}\n{{- else if eq .Role \"tool\" }}<|start_header_id|>ipython<|end_header_id|>\n\n{{ .Content }}<|eot_id|>\n{{- if $isLastMessage }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}\n{{- else }}<|start_header_id|>{{ .Role }}<|end_header_id|>\n\n{{ .Content }}<|eot_id|>\n{{- if $isLastMessage }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}\n{{- end }}\n{{- end }}\n\"\"\"\nPARAMETER stop \"<|start_header_id|>\"\nPARAMETER stop \"<|end_header_id|>\"\nPARAMETER stop \"<|eot_id|>\"\n'''\n\n# https://ollama.com/library/llama3.2-vision:latest/blobs/715415638c895a1f8e8c6\nllama_32_vision_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $index, $_ := .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>\n\n{{ .Content }}\n{{- if gt (len (slice $.Messages $index)) 1 }}<|eot_id|>\n{{- else if ne .Role \"assistant\" }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{{ end }}\n{{- end }}\"\"\"\nPARAMETER temperature 0.6\nPARAMETER top_p 0.9\n'''\n\nOLLAMA_TEMPLATES[\"llama-3.1\"] = llama31_ollama\nOLLAMA_TEMPLATES[\"llama-31\"] = llama31_ollama\nOLLAMA_TEMPLATES[\"llama-31-nemotron\"] = llama_31_nemotron_ollama\nOLLAMA_TEMPLATES[\"llama-31-storm\"] = llama_31_storm_ollama\nOLLAMA_TEMPLATES[\"llama-32-vision\"] = llama_32_vision_ollama\n\nfor version in (\"llama-3.2\", \"llama-3.3\", \"llama-32\", \"llama-33\"):\n    OLLAMA_TEMPLATES[version] = OLLAMA_TEMPLATES[\"llama-3.1\"]\n\n# =========================================== tinyllama\n# tinyllama-chat https://ollama.com/library/tinyllama:latest/blobs/af0ddbdaaa26\ntinyllama_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"<|system|>\n{{ .System }}</s>\n<|user|>\n{{ .Prompt }}</s>\n<|assistant|>\"\"\"\nPARAMETER stop \"<|system|>\"\nPARAMETER stop \"<|user|>\"\nPARAMETER stop \"<|assistant|>\"\nPARAMETER stop \"</s>\"\nSYSTEM \"\"\"You are a helpful AI assistant.\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"tinyllama\"] = tinyllama_ollama\n\n\n# =========================================== Qwen 2/2.5\n# Qwen2 https://ollama.com/library/qwen2:latest/blobs/77c91b422cc9\n# Qwen2.5 from https://ollama.com/library/qwen2.5/blobs/eb4402837c78\nqwen25_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if .Messages }}\n{{- if or .System .Tools }}<|im_start|>system\n{{- if .System }}\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{{- range .Tools }}\n{\"type\": \"function\", \"function\": {{ .Function }}}\n{{- end }}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n{{- end }}<|im_end|>\n{{ end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"user\" }}<|im_start|>user\n{{ .Content }}<|im_end|>\n{{ else if eq .Role \"assistant\" 
}}<|im_start|>assistant\n{{ if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}<tool_call>\n{{ range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{ end }}</tool_call>\n{{- end }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- else if eq .Role \"tool\" }}<|im_start|>user\n<tool_response>\n{{ .Content }}\n</tool_response><|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_start|>assistant\n{{ end }}\n{{- end }}\n{{- else }}\n{{- if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}\"\"\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER stop \"<|endoftext|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\nSYSTEM \"\"\"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\"\"\"\n'''\n\n# https://ollama.com/library/qwen2.5-coder:latest/blobs/1e65450c3067\nqwen_25_coder_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>\n{{- else if .Messages }}\n{{- if or .System .Tools }}<|im_start|>system\n{{- if .System }}\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools>:\n<tools>\n{{- range .Tools }}\n{\"type\": \"function\", \"function\": {{ .Function }}}\n{{- end }}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> with NO other text. Do not include any backticks or ```json.\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n{{- end }}<|im_end|>\n{{ end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"user\" }}<|im_start|>user\n{{ .Content }}<|im_end|>\n{{ else if eq .Role \"assistant\" }}<|im_start|>assistant\n{{ if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}<tool_call>\n{{ range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{ end }}</tool_call>\n{{- end }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- else if eq .Role \"tool\" }}<|im_start|>user\n<tool_response>\n{{ .Content }}\n</tool_response><|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_start|>assistant\n{{ end }}\n{{- end }}\n{{- else }}\n{{- if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}\"\"\"\nSYSTEM \"\"\"You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.\"\"\"\n'''\n\n# https://ollama.com/library/qwen2.5vl:latest/blobs/a242d8dfdc8f\nqwen_25_vl_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if .System -}}\n<|im_start|>system\n{{ .System }}<|im_end|>\n{{- end -}}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"user\" }}\n<|im_start|>user\n{{ .Content }}<|im_end|>\n{{- else if eq .Role \"assistant\" }}\n<|im_start|>assistant\n{{ if .Content }}{{ .Content }}{{ if not $last }}<|im_end|>\n{{- else -}}<|im_end|>{{- end -}}\n{{- end -}}\n{{- end -}}\n{{- if and (ne .Role \"assistant\") $last }}\n<|im_start|>assistant\n{{ end -}}\n{{- end }}\"\"\"\nPARAMETER temperature 0.0001\nSYSTEM \"\"\"You are a helpful assistant.\"\"\"\n'''\n\n# https://ollama.com/library/openthinker:latest/blobs/32695b892af8\nopenthinker_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n<|im_start|>{{ .Role }}<|im_sep|>\n{{ .Content }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_end|>\n<|im_start|>assistant<|im_sep|>\n{{ end }}\n{{- end }}\"\"\"\n'''\n\n\nOLLAMA_TEMPLATES[\"qwen-25\"] = qwen25_ollama\nOLLAMA_TEMPLATES[\"qwen-2.5\"] = qwen25_ollama\nOLLAMA_TEMPLATES[\"qwen-25-coder\"] = qwen_25_coder_ollama\nOLLAMA_TEMPLATES[\"qwen-25-vl\"] = qwen_25_vl_ollama\nOLLAMA_TEMPLATES[\"openthinker\"] = openthinker_ollama\nOLLAMA_TEMPLATES[\"qwen-2\"] = qwen25_ollama\n\n# =========================================== Phi-4\n_phi4_ollama_template = (\n    \"{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}\"\n    \"{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}\"\n    \"<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>\"\n)\n\n# Ollama from https://www.ollama.com/library/phi4 is different\nphi_4_ollama = f'''\nFROM {{__FILE_LOCATION__}}\nTEMPLATE \"\"\"{_phi4_ollama_template}\"\"\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER stop \"<|im_start|>\"\nPARAMETER stop \"<|im_sep|>\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\n# https://ollama.com/library/phi4-reasoning:latest/blobs/32695b892af8\nphi_4_reasoning_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n<|im_start|>{{ .Role }}<|im_sep|>\n{{ .Content }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_end|>\n<|im_start|>assistant<|im_sep|>\n{{ end }}\n{{- end }}\"\"\"\nPARAMETER stop \"<|im_start|>\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER stop \"<|im_sep|>\"\n'''\n\n# https://ollama.com/library/phi4-mini:latest/blobs/813f53fdc6e5\nphi_4_mini_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if or .System .Tools }}<|system|>{{ if .System }}{{ .System }}{{ end }}\n{{- if .Tools }}{{ if not .System }}You are a helpful assistant with some tools.{{ end }}<|tool|>{{ .Tools }}<|/tool|><|end|>\n{{- end }}\n{{- end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if ne .Role \"system\" }}<|{{ .Role }}|>{{ .Content }}\n{{- if .ToolCalls }}<|tool_call|>[{{ range .ToolCalls }}{\"name\":\"{{ .Function.Name }}\",\"arguments\":{{ .Function.Arguments }}{{ end }}]<|/tool_call|>\n{{- end }}\n{{- if not $last }}<|end|>\n{{- end }}\n{{- if and (ne .Role \"assistant\") $last }}<|end|><|assistant|>{{ end }}\n{{- end }}\n{{- end }}\"\"\"\n'''\n\n# 
https://ollama.com/library/phi4-mini-reasoning:latest/blobs/c895a1f8e8c6\nphi_4_mini_reasoning_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- if .System }}<|system|>{{ .System }}\n{{- end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if ne .Role \"system\" }}<|{{ .Role }}|>{{ .Content }}\n{{- if not $last }}<|end|>\n{{- end }}\n{{- if and (ne .Role \"assistant\") $last }}<|end|><|assistant|>{{ end }}\n{{- end }}\n{{- end }}\"\"\"\nSYSTEM \"\"\"Your name is Phi, an AI math expert developed by Microsoft.\"\"\"\n'''\nOLLAMA_TEMPLATES[\"phi-4\"] = phi_4_ollama\nOLLAMA_TEMPLATES[\"phi-4-reasoning\"] = phi_4_reasoning_ollama\nOLLAMA_TEMPLATES[\"phi-4-mini\"] = phi_4_mini_ollama\nOLLAMA_TEMPLATES[\"phi-4-mini-reasoning\"] = phi_4_mini_reasoning_ollama\n\n\n# =========================================== Gemma-3\n# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802\ngemma3_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 }}\n{{- if or (eq .Role \"user\") (eq .Role \"system\") }}<start_of_turn>user\n{{ .Content }}<end_of_turn>\n{{ if $last }}<start_of_turn>model\n{{ end }}\n{{- else if eq .Role \"assistant\" }}<start_of_turn>model\n{{ .Content }}{{ if not $last }}<end_of_turn>\n{{ end }}\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER stop \"<end_of_turn>\"\nPARAMETER stop \"<eos>\"\nPARAMETER temperature 1.0\nPARAMETER min_p 0.0\nPARAMETER top_k 64\nPARAMETER top_p 0.95\nPARAMETER num_predict 32768\n'''\n\n# https://ollama.com/library/gemma3:270m/blobs/4b19ac7dd2fb\ngemma3_270m_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- $systemPromptAdded := false }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 }}\n{{- if eq .Role \"user\" }}<start_of_turn>user\n{{- if (and (not $systemPromptAdded) $.System) }}\n{{- $systemPromptAdded = true }}\n{{ $.System }}\n{{ end }}\n{{ .Content }}<end_of_turn>\n{{ if $last }}<start_of_turn>model\n{{ end }}\n{{- else if eq .Role \"assistant\" }}<start_of_turn>model\n{{ .Content }}{{ if not $last }}<end_of_turn>\n{{ end }}\n{{- end }}\n{{- end }}\n\"\"\"\nPARAMETER stop \"<end_of_turn>\"\nPARAMETER top_k 64\nPARAMETER top_p 0.95\n'''\n\nOLLAMA_TEMPLATES[\"gemma-3\"] = gemma3_ollama\nOLLAMA_TEMPLATES[\"gemma3\"] = gemma3_ollama\nOLLAMA_TEMPLATES[\"gemma3-270m\"] = gemma3_270m_ollama\n\n\n# =========================================== Qwen-3\n# Ollama template for Qwen-3 (see https://ollama.com/library/qwen3/blobs/eb4402837c78)\nqwen3_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- if .Messages }}\n{{- if or .System .Tools }}<|im_start|>system\n{{- if .System }}\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{{- range .Tools }}\n{\"type\": \"function\", \"function\": {{ .Function }}}\n{{- end }}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n{{- end }}<|im_end|>\n{{ end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"user\" }}<|im_start|>user\n{{ .Content }}<|im_end|>\n{{ else if eq .Role \"assistant\" }}<|im_start|>assistant\n{{ if .Content }}{{ .Content }}\n{{- else if 
.ToolCalls }}<tool_call>\n{{ range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{ end }}</tool_call>\n{{- end }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- else if eq .Role \"tool\" }}<|im_start|>user\n<tool_response>\n{{ .Content }}\n</tool_response><|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_start|>assistant\n{{ end }}\n{{- end }}\n{{- else }}\n{{- if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}\"\"\"\nPARAMETER stop \"<|im_end|>\"\nPARAMETER stop \"<|im_start|>\"\nPARAMETER temperature 0.6\nPARAMETER min_p 0.0\nPARAMETER top_k 20\nPARAMETER top_p 0.95\nPARAMETER repeat_penalty 1\n'''\n\nqwen3_template_eos_token = \"<|im_end|>\"\nOLLAMA_TEMPLATES[\"qwen-3\"] = qwen3_ollama\nOLLAMA_TEMPLATES[\"qwen3\"] = qwen3_ollama\n\n\n# =========================================== Gemma-3n\n# Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802\ngemma3n_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 }}\n{{- if or (eq .Role \"user\") (eq .Role \"system\") }}<start_of_turn>user\n{{ .Content }}<end_of_turn>\n{{ if $last }}<start_of_turn>model\n{{ end }}\n{{- else if eq .Role \"assistant\" }}<start_of_turn>model\n{{ .Content }}{{ if not $last }}<end_of_turn>\n{{ end }}\n{{- end }}\n{{- end }}\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"gemma-3n\"] = gemma3n_ollama\nOLLAMA_TEMPLATES[\"gemma3n\"] = gemma3n_ollama\n\n# =========================================== GPT-OSS\n\n# Ollama from https://ollama.com/library/gpt-oss:latest/blobs/fa6710a93d78\ngptoss_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: {{ currentDate }}\n{{- if and .IsThinkSet .Think (ne .ThinkLevel \"\") }}\n\nReasoning: {{ .ThinkLevel }}\n{{- else if or (not .IsThinkSet) (and .IsThinkSet .Think) }}\n\nReasoning: medium\n{{- end }}\n\n{{- $hasNonBuiltinTools := false }}\n{{- if .Tools -}}\n{{- $hasBrowserSearch := false }}\n{{- $hasBrowserOpen := false }}\n{{- $hasBrowserFind := false }}\n{{- $hasPython := false }}\n  {{- range .Tools }}\n    {{- if eq .Function.Name \"browser.search\" -}}{{- $hasBrowserSearch = true -}}\n    {{- else if eq .Function.Name \"browser.open\" -}}{{- $hasBrowserOpen = true -}}\n    {{- else if eq .Function.Name \"browser.find\" -}}{{- $hasBrowserFind = true -}}\n    {{- else if eq .Function.Name \"python\" -}}{{- $hasPython = true -}}\n    {{- else }}{{ $hasNonBuiltinTools = true -}}\n    {{- end }}\n  {{- end }}\n{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind $hasPython }}\n\n# Tools\n{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind }}\n\n## browser\n\n// Tool for browsing.\n// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n// Cite information from the tool using the following format:\n// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n// Do not quote more than 10 words directly from the tool output.\n// sources=web (default: web)\nnamespace browser {\n{{- if $hasBrowserSearch }}\n\n// Searches for information related to `query` and displays `topn` results.\ntype search = (_: {\nquery: string,\ntopn?: number, // default: 10\nsource?: string,\n}) => any;\n{{- 
end }}\n{{- if $hasBrowserOpen }}\n\n// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n// If `cursor` is not provided, the most recent page is implied.\n// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n// Use this function without `id` to scroll to a new location of an opened page.\ntype open = (_: {\nid?: number | string, // default: -1\ncursor?: number, // default: -1\nloc?: number, // default: -1\nnum_lines?: number, // default: -1\nview_source?: boolean, // default: false\nsource?: string,\n}) => any;\n{{- end }}\n{{- if $hasBrowserFind }}\n\n// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\ntype find = (_: {\npattern: string,\ncursor?: number, // default: -1\n}) => any;\n{{- end }}\n\n} // namespace browser\n{{- end }}{{/* end if has browser tools */}}\n{{- if $hasPython }}\n\n## python\n\nUse this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n{{- end }}{{/* end if hasPython */}}\n{{- end }}{{/* end if has any built-in tools */}}\n{{- end }}{{/* end if .Tools */}}\n\n# Valid channels: analysis, commentary, final. 
Channel must be included for every message.{{ if $hasNonBuiltinTools }}\nCalls to these tools must go to the commentary channel: 'functions'.\n{{- end -}}<|end|>{{/* end of system */ -}}\n{{- if or $hasNonBuiltinTools .System -}}\n<|start|>developer<|message|>{{- if $hasNonBuiltinTools }}# Tools\n\n## functions\n\nnamespace functions {\n{{- range .Tools }}\n{{- if not (or (eq .Function.Name \"browser.search\") (eq .Function.Name \"browser.open\") (eq .Function.Name \"browser.find\") (eq .Function.Name \"python\")) }}\n{{if .Function.Description }}\n// {{ .Function.Description }}\n{{- end }}\n{{- if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0) }}\ntype {{ .Function.Name }} = (_: {\n{{- range $name, $prop := .Function.Parameters.Properties }}\n{{- if $prop.Description }}\n  // {{ $prop.Description }}\n{{- end }}\n  {{ $name }}: {{ if gt (len $prop.Type) 1 }}{{ range $i, $t := $prop.Type }}{{ if $i }} | {{ end }}{{ $t }}{{ end }}{{ else }}{{ index $prop.Type 0 }}{{ end }},\n{{- end }}\n}) => any;\n{{- else }}\ntype {{ .Function.Name }} = () => any;\n{{- end }}\n{{- end }}{{/* end if not browser tool */}}\n{{- end }}{{/* end of range .Tools */}}\n\n} // namespace functions\n{{- end }}{{/* end if hasNonBuiltinTools */}}\n{{- if .System}}\n\n# Instructions\n\n{{ .System }}\n{{- end -}}\n<|end|>\n{{- end -}}\n{{- /* Find the index of the last user message */ -}}\n{{- $lastUserIdx := -1 }}\n{{- $prefillingContent := false }}\n{{- $prefillingThinkingOnly := false }}\n{{- range $i, $msg := .Messages }}\n  {{- $last := eq (len (slice $.Messages $i)) 1 -}}\n  {{- if eq $msg.Role \"user\" }}\n    {{- $lastUserIdx = $i }}\n  {{- end -}}\n  {{- if and $last (eq $msg.Role \"assistant\") (gt (len $msg.Content) 0) }}\n    {{- $prefillingContent = true }}\n  {{- else if and $last (eq $msg.Role \"assistant\") (gt (len $msg.Thinking) 0) }}\n    {{- $prefillingThinkingOnly = true }}\n  {{- end }}\n{{- end -}}\n{{- /* Now render messages */ -}}\n{{- range $i, $msg := .Messages }}\n  {{- $last := eq (len (slice $.Messages $i)) 1 -}}\n  {{- if (ne $msg.Role \"system\") -}}\n    {{- if eq $msg.Role \"tool\" -}}\n      {{- if or (eq $msg.ToolName \"python\") (eq $msg.ToolName \"browser.search\") (eq $msg.ToolName \"browser.open\") (eq $msg.ToolName \"browser.find\") -}}\n        <|start|>{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>\n      {{- else -}}\n        <|start|>functions.{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>\n      {{- end -}}\n    {{- else if eq $msg.Role \"assistant\" -}}\n      {{- if and $msg.Thinking (gt $i $lastUserIdx) -}}{{- /* Show thinking only after last user message */ -}}\n      <|start|>assistant<|channel|>analysis<|message|>{{ $msg.Thinking }}{{- if not $prefillingThinkingOnly -}}<|end|>{{- end -}}\n      {{- end -}}\n      {{- if gt (len $msg.Content) 0 -}}\n        <|start|>assistant<|channel|>final<|message|>{{ $msg.Content }}{{- if not $prefillingContent -}}<|end|>{{- end -}}\n      {{- end -}}\n      {{- if gt (len $msg.ToolCalls) 0 -}}\n        {{- range $j, $toolCall := $msg.ToolCalls -}}\n          {{- $isBuiltin := or (eq $toolCall.Function.Name \"python\") (eq $toolCall.Function.Name \"browser.search\") (eq $toolCall.Function.Name \"browser.open\") (eq $toolCall.Function.Name \"browser.find\") -}}\n          <|start|>assistant<|channel|>{{ if $isBuiltin }}analysis{{ else }}commentary{{ end }} to={{ if not $isBuiltin}}functions.{{end}}{{ $toolCall.Function.Name }} <|constrain|>json<|message|>{{ 
$toolCall.Function.Arguments }}<|call|>\n        {{- end -}}\n      {{- end -}}\n    {{- else if eq $msg.Role \"user\" -}}\n      <|start|>{{ $msg.Role }}<|message|>{{ $msg.Content }}<|end|>\n    {{- end }}\n  {{- else }}\n  {{- end }}\n{{- end -}}\n{{- if not (or $prefillingContent $prefillingThinkingOnly) -}}\n<|start|>assistant\n{{- end -}}\"\"\"\nPARAMETER temperature 1.0\nPARAMETER top_k 0\nPARAMETER top_p 1.0\n'''\n\nOLLAMA_TEMPLATES[\"gpt-oss\"] = gptoss_ollama\nOLLAMA_TEMPLATES[\"gptoss\"] = gptoss_ollama\n\n\n# =========================================== Qwen3\n\n# Ollama from https://ollama.com/library/qwen3/blobs/53e4ea15e8f5\nqwen3_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"\n{{- $lastUserIdx := -1 -}}\n{{- range $idx, $msg := .Messages -}}\n{{- if eq $msg.Role \"user\" }}{{ $lastUserIdx = $idx }}{{ end -}}\n{{- end }}\n{{- if or .System .Tools }}<|im_start|>system\n{{ if .System }}\n{{ .System }}\n{{- end }}\n{{- if .Tools }}\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{{- range .Tools }}\n{\"type\": \"function\", \"function\": {{ .Function }}}\n{{- end }}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n{{- end -}}\n<|im_end|>\n{{ end }}\n{{- range $i, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $i)) 1 -}}\n{{- if eq .Role \"user\" }}<|im_start|>user\n{{ .Content }}<|im_end|>\n{{ else if eq .Role \"assistant\" }}<|im_start|>assistant\n{{ if (and $.IsThinkSet (and .Thinking (or $last (gt $i $lastUserIdx)))) -}}\n<think>{{ .Thinking }}</think>\n{{ end -}}\n{{ if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}<tool_call>\n{{ range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{ end }}</tool_call>\n{{- end }}{{ if not $last }}<|im_end|>\n{{ end }}\n{{- else if eq .Role \"tool\" }}<|im_start|>user\n<tool_response>\n{{ .Content }}\n</tool_response><|im_end|>\n{{ end }}\n{{- if and (ne .Role \"assistant\") $last }}<|im_start|>assistant\n{{ end }}\n{{- end }}\n\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"qwen3-instruct\"] = qwen3_ollama\nOLLAMA_TEMPLATES[\"qwen3-thinking\"] = qwen3_ollama\n\n\n# =========================================== Starling-LM\n\n\n# Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4\nstarling_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>\n{{ end }}{{ if .Prompt }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>\n{{ end }}GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>\"\"\"\nPARAMETER stop \"<|end_of_turn|>\"\nPARAMETER stop \"GPT4 Correct User:\"\nPARAMETER stop \"GPT4 Correct Assistant:\"\nPARAMETER stop \"GPT4 Correct System:\"\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1\n'''\n\nOLLAMA_TEMPLATES[\"starling\"] = starling_ollama\n\n\n# =========================================== Yi-chat\n\n\n# Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093\nyi_chat_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{ if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ end }}{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n{{ end }}<|im_start|>assistant\n{{ .Response }}<|im_end|>\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"yi-chat\"] = yi_chat_ollama\n\n# 
=========================================== Granite\n\n# Ollama from https://ollama.com/library/granite3.2:latest/blobs/3e7ca51acd6e\ngranite_32_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- /*\n\n------ MESSAGE PARSING ------\n\n*/}}\n{{- /*\nDeclare the prompt structure variables to be filled in from messages\n*/}}\n{{- $system := \"\" }}\n{{- $documents := \"\" }}\n{{- $documentCounter := 0 }}\n{{- $thinking := false }}\n{{- $citations := false }}\n{{- $hallucinations := false }}\n{{- $length := \"\" }}\n\n{{- /*\nLoop over messages and look for a user-provided system message and documents\n*/ -}}\n{{- range .Messages }}\n\n    {{- /* User defined system prompt(s) */}}\n    {{- if (eq .Role \"system\")}}\n        {{- if (ne $system \"\") }}\n            {{- $system = print $system \" \" }}\n        {{- end}}\n        {{- $system = print $system .Content }}\n    {{- end}}\n\n    {{- /*\n    NOTE: Since Ollama collates consecutive roles, for control and documents, we\n        work around this by allowing the role to contain a qualifier after the\n        role string.\n    */ -}}\n\n    {{- /* Role specified thinking */ -}}\n    {{- if (and (ge (len .Role) 7) (eq (slice .Role 0 7) \"control\")) }}\n        {{- if (eq .Content \"thinking\")}}{{- $thinking = true }}{{- end}}\n        {{- if (eq .Content \"citations\")}}{{- $citations = true }}{{- end}}\n        {{- if (eq .Content \"hallucinations\")}}{{- $hallucinations = true }}{{- end}}\n        {{- if (and (ge (len .Content) 7) (eq (slice .Content 0 7) \"length \"))}}\n            {{- $length = print ` {\"length\": \"` (slice .Content 7) `\"}` }}\n        {{- end}}\n    {{- end}}\n\n    {{- /* Role specified document */ -}}\n    {{- if (and (ge (len .Role) 8) (eq (slice .Role 0 8) \"document\")) }}\n        {{- if (ne $documentCounter 0)}}\n            {{- $documents = print $documents \" \"}}\n        {{- end}}\n        {{- $identifier := $documentCounter}}\n        {{- if (ge (len .Role) 9) }}\n            {{- $identifier = (slice .Role 8)}}\n        {{- end}}\n        {{- $documents = print $documents \"Document \" $identifier \"\" .Content}}\n        {{- $documentCounter = len (printf \"a%*s\" $documentCounter \"\")}}\n    {{- end}}\n{{- end}}\n\n{{- /*\nIf no user message provided, build the default system message\n*/ -}}\n{{- if eq $system \"\" }}\n    {{- $system = \"Knowledge Cutoff Date: April 2024.You are Granite, developed by IBM.\"}}\n\n    {{- /* Add Tools prompt */}}\n    {{- if .Tools }}\n        {{- $system = print $system \" You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\" }}\n    {{- end}}\n\n    {{- /* Add documents prompt */}}\n    {{- if $documents }}\n        {{- if .Tools }}\n            {{- $system = print $system \" \"}}\n        {{- else }}\n            {{- $system = print $system \" \"}}\n        {{- end}}\n        {{- $system = print $system \"Write the response to the user's input by strictly aligning with the facts in the provided documents. 
If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\" }}\n        {{- if $citations}}\n            {{- $system = print $system \" In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.\"}}\n        {{- end}}\n        {{- if $hallucinations}}\n            {{- $system = print $system \"Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.\"}}\n        {{- end}}\n    {{- end}}\n\n    {{- /* Prompt without tools or documents */}}\n    {{- if (and (not .Tools) (not $documents)) }}\n        {{- $system = print $system \" You are a helpful AI assistant.\"}}\n        {{- if $thinking}}\n            {{- $system = print $system \"Respond to every user query in a comprehensive and detailed way. You can write down your thought process before responding. Write your thoughts after 'Here is my thought process:' and write your response after 'Here is my response:' for each user query.\"}}\n        {{- end}}\n    {{- end}}\n\n    {{- /* Add thinking prompt if no tools or documents */}}\n    {{- if (and $thinking (not .Tools) (not $documents)) }}\n        {{- $system = print $system \" You are a helpful AI assistant.Respond to every user query in a comprehensive and detailed way. You can write down your thought process before responding. Write your thoughts after 'Here is my thought process:' and write your response after 'Here is my response:' for each user query.\"}}\n    {{- end}}\n\n{{- end}}\n{{- /*\n\n------ TEMPLATE EXPANSION ------\n\n*/}}\n{{- /* System Prompt */ -}}\n<|start_of_role|>system<|end_of_role|>{{- $system }}<|end_of_text|>\n\n{{- /* Tools */ -}}\n{{- if .Tools }}\n<|start_of_role|>tools<|end_of_role|>[\n{{- range $index, $_ := .Tools }}\n{{ . }}\n{{- if and (ne (len (slice $.Tools $index)) 1) (gt (len $.Tools) 1) }},\n{{- end}}\n{{- end }}\n]\n{{- end}}\n\n{{- /* Documents */ -}}\n{{- if $documents }}\n<|start_of_role|>documents<|end_of_role|>\n{{ $documents }}<|end_of_text|>\n{{- end}}\n\n{{- /* Standard Messages */}}\n{{- range $index, $_ := .Messages }}\n{{- if (and\n    (ne .Role \"system\")\n    (or (lt (len .Role) 7) (ne (slice .Role 0 7) \"control\"))\n    (or (lt (len .Role) 8) (ne (slice .Role 0 8) \"document\"))\n)}}\n<|start_of_role|>\n{{- if eq .Role \"tool\" }}tool_response\n{{- else }}{{ .Role }}\n{{- end }}<|end_of_role|>\n{{- if .Content }}{{ .Content }}\n{{- else if .ToolCalls }}<|tool_call|>\n{{- range .ToolCalls }}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\n{{- end }}\n{{- end }}\n{{- if eq (len (slice $.Messages $index)) 1 }}\n{{- if eq .Role \"assistant\" }}\n{{- else }}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n{{- end -}}\n{{- else }}<|end_of_text|>\n{{- end }}\n{{- end }}\n{{- end }}\n\"\"\"\n'''\n\n# granite-3.2-vision https://ollama.com/library/granite3.2-vision:latest/blobs/579046ba1157\ngranite_32_vision_ollama = '''\nFROM {__FILE_LOCATION__}\nTEMPLATE \"\"\"{{- /* Tools */ -}}\n{{- if .Tools -}}\n<|start_of_role|>available_tools<|end_of_role|>\n{{- range $index, $_ := .Tools }}\n{{- $last := eq (len (slice $.Tools $index)) 1 }}\n{{ . 
}}\n{{- if not $last }}\n{{ end}}\n{{- end -}}\n<|end_of_text|>\n{{ end }}\n\n{{- /* System Prompt */ -}}\n{{- if and (gt (len .Messages) 0) (eq (index .Messages 0).Role \"system\") -}}\n<|system|>\n{{(index .Messages 0).Content}}\n{{- else -}}\n<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n{{- end }}\n\n{{- /*Main message loop*/ -}}\n{{- range $index, $_ := .Messages }}\n{{- $last := eq (len (slice $.Messages $index)) 1 }}\n{{- if eq .Role \"system\" }}\n\n{{- else if eq .Role \"user\" }}\n<|user|>\n{{.Content}}\n\n{{- else if eq .Role \"assistant\" }}\n<|assistant|>\n{{- if .Content }}\n{{.Content}}\n<|end_of_text|>\n{{ end }}\n\n{{- else if eq .Role \"assistant_tool_call\" }}\n<|start_of_role|>assistant<|end_of_role|><|tool_call|>{{.Content}}<|end_of_text|>\n\n{{- else if eq .Role \"tool_response\" }}\n<|start_of_role|>tool_response<|end_of_role|>{{.Content}}<|end_of_text|>\n{{- end }}\n\n{{- /* Add generation prompt */ -}}\n{{ if $last }}\n{{- if eq .Role \"assistant\" }}\n{{- else }}\n<|assistant|>\n{{- end }}\n{{- end }}\n{{- end }}\"\"\"\nPARAMETER num_ctx 16384\nPARAMETER temperature 0\nSYSTEM \"\"\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\"\"\"\n'''\n\nOLLAMA_TEMPLATES[\"granite-32\"] = granite_32_ollama\nOLLAMA_TEMPLATES[\"granite-32-vision\"] = granite_32_vision_ollama\n\n\nOLLAMA_TEMPLATE_TO_MODEL_MAPPER = {\n    \"phi-3.5\": (\n        \"unsloth/Phi-3.5-mini-instruct-bnb-4bit\",\n        \"unsloth/Phi-3.5-mini-instruct\",\n        \"microsoft/Phi-3.5-mini-instruct\",\n    ),\n    \"phi-3\": (\n        \"unsloth/Phi-3-mini-4k-instruct-bnb-4bit\",\n        \"unsloth/Phi-3-mini-4k-instruct\",\n        \"microsoft/Phi-3-mini-4k-instruct\",\n        \"unsloth/Phi-3-medium-4k-instruct-bnb-4bit\",\n        \"unsloth/Phi-3-medium-4k-instruct\",\n        \"microsoft/Phi-3-medium-4k-instruct\",\n        \"unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit\",\n        \"unsloth/Phi-3-mini-4k-instruct-v0\",\n    ),\n    \"phi-4\": (\n        \"unsloth/phi-4-unsloth-bnb-4bit\",\n        \"unsloth/phi-4\",\n        \"microsoft/phi-4\",\n        \"unsloth/phi-4-bnb-4bit\",\n    ),\n    \"phi-4-reasoning\": (\n        \"unsloth/phi-4-reasoning-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-reasoning\",\n        \"microsoft/Phi-4-reasoning\",\n        \"unsloth/phi-4-reasoning-bnb-4bit\",\n        \"unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-reasoning-plus\",\n        \"microsoft/Phi-4-reasoning-plus\",\n        \"unsloth/phi-4-reasoning-plus-bnb-4bit\",\n    ),\n    \"phi-4-mini\": (\n        \"unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit\",\n        \"unsloth/Phi-4-mini-instruct\",\n        \"microsoft/Phi-4-mini-instruct\",\n        \"unsloth/Phi-4-mini-instruct-bnb-4bit\",\n    ),\n    \"phi-4-mini-reasoning\": (\n        \"unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit\",\n        \"unsloth/phi-4-mini-reasoning\",\n        \"microsoft/Phi-4-mini-reasoning\",\n        \"unsloth/phi-4-mini-reasoning-bnb-4bit\",\n    ),\n    \"mistral\": (\n        \"unsloth/mistral-7b-instruct-v0.1-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.1\",\n        \"mistralai/Mistral-7B-Instruct-v0.1\",\n        \"unsloth/mistral-7b-instruct-v0.2-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.2\",\n        \"mistralai/Mistral-7B-Instruct-v0.2\",\n  
  ),\n    \"mistral-v03\": (\n        \"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\",\n        \"unsloth/mistral-7b-instruct-v0.3\",\n        \"mistralai/Mistral-7B-Instruct-v0.3\",\n        \"unsloth/Mistral-Large-Instruct-2407-bnb-4bit\",\n        \"mistralai/Mistral-Large-Instruct-2407\",\n    ),\n    \"mistral-small\": (\n        \"unsloth/Mistral-Small-Instruct-2409-bnb-4bit\",\n        \"unsloth/Mistral-Small-Instruct-2409\",\n        \"mistralai/Mistral-Small-Instruct-2409\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501\",\n        \"mistralai/Mistral-Small-24B-Instruct-2501\",\n        \"unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit\",\n    ),\n    \"mistral-small-31\": (\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503\",\n        \"mistralai/Mistral-Small-3.1-24B-Instruct-2503\",\n        \"unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit\",\n    ),\n    \"mistral-small-32\": (\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit\",\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506\",\n        \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n        \"unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit\",\n    ),\n    \"mixtral\": (\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1\",\n        \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n        \"unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit\",\n    ),\n    \"mistral-nemo\": (\n        \"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit\",\n        \"unsloth/Mistral-Nemo-Instruct-2407\",\n        \"mistralai/Mistral-Nemo-Instruct-2407\",\n    ),\n    \"codestral\": (\n        \"mistralai/Codestral-22B-v0.1\",\n        \"mistral-community/Codestral-22B-v0.1\",\n    ),\n    \"devstral\": (\n        \"unsloth/Devstral-Small-2505-unsloth-bnb-4bit\",\n        \"unsloth/Devstral-Small-2505\",\n        \"mistralai/Devstral-Small-2505\",\n        \"unsloth/Devstral-Small-2505-bnb-4bit\",\n        \"unsloth/Devstral-Small-2507-unsloth-bnb-4bit\",\n        \"unsloth/Devstral-Small-2507\",\n        \"mistralai/Devstral-Small-2507\",\n        \"unsloth/Devstral-Small-2507-bnb-4bit\",\n    ),\n    \"magistral\": (\n        \"unsloth/Magistral-Small-2506-unsloth-bnb-4bit\",\n        \"unsloth/Magistral-Small-2506\",\n        \"mistralai/Magistral-Small-2506\",\n        \"unsloth/Magistral-Small-2506-bnb-4bit\",\n        \"unsloth/Magistral-Small-2507-unsloth-bnb-4bit\",\n        \"unsloth/Magistral-Small-2507\",\n        \"mistralai/Magistral-Small-2507\",\n        \"unsloth/Magistral-Small-2507-bnb-4bit\",\n        \"unsloth/Magistral-Small-2509-unsloth-bnb-4bit\",\n        \"unsloth/Magistral-Small-2509\",\n        \"mistralai/Magistral-Small-2509\",\n        \"unsloth/Magistral-Small-2509-bnb-4bit\",\n    ),\n    \"tinyllama\": (\n        \"unsloth/tinyllama-chat-bnb-4bit\",\n        \"unsloth/tinyllama-chat\",\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n    ),\n    \"llama\": (\n        \"unsloth/llama-2-7b-bnb-4bit\",\n        \"unsloth/llama-2-7b\",\n        \"meta-llama/Llama-2-7b-hf\",\n        \"unsloth/llama-2-13b-bnb-4bit\",\n        \"unsloth/llama-2-13b\",\n        \"meta-llama/Llama-2-13b-hf\",\n        \"unsloth/llama-2-7b-chat-bnb-4bit\",\n        \"unsloth/llama-2-7b-chat\",\n        \"meta-llama/Llama-2-7b-chat-hf\",\n    ),\n    \"llama3\": (\n        
\"unsloth/llama-3-8b-Instruct-bnb-4bit\",\n        \"unsloth/llama-3-8b-Instruct\",\n        \"meta-llama/Meta-Llama-3-8B-Instruct\",\n        \"unsloth/llama-3-70b-Instruct-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3-70B-Instruct\",\n    ),\n    \"llama-3.1\": (\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct\",\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.1-8B-Instruct\",\n        \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"unsloth/Llama-3.1-8B-Instruct-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit\",\n        \"meta-llama/Meta-Llama-3.1-405B-Instruct\",\n        \"unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit\",\n        \"unsloth/Meta-Llama-3.1-70B-Instruct\",\n        \"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n        \"unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit\",\n        \"unsloth/Hermes-3-Llama-3.1-8B\",\n        \"NousResearch/Hermes-3-Llama-3.1-8B\",\n        \"unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit\",\n        \"unsloth/Hermes-3-Llama-3.1-70B\",\n        \"NousResearch/Hermes-3-Llama-3.1-70B\",\n        \"unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit\",\n        \"NousResearch/Hermes-3-Llama-3.1-405B\",\n        \"unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Tulu-3-8B\",\n        \"allenai/Llama-3.1-Tulu-3-8B\",\n        \"unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Tulu-3-70B\",\n        \"allenai/Llama-3.1-Tulu-3-70B\",\n    ),\n    \"llama-31-storm\": (\n        \"unsloth/Llama-3.1-Storm-8B-bnb-4bit\",\n        \"unsloth/Llama-3.1-Storm-8B\",\n        \"akjindal53244/Llama-3.1-Storm-8B\",\n    ),\n    \"llama-31-nemotron\": (\n        \"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.1-Nemotron-70B-Instruct\",\n        \"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF\",\n    ),\n    \"llama-3.2\": (\n        \"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-1B-Instruct\",\n        \"meta-llama/Llama-3.2-1B-Instruct\",\n        \"unsloth/Llama-3.2-1B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-3B-Instruct\",\n        \"meta-llama/Llama-3.2-3B-Instruct\",\n        \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\",\n    ),\n    \"llama-32-vision\": (\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.2-90B-Vision-Instruct\",\n        \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n    ),\n    \"llama-3.3\": (\n        \"unsloth/Llama-3.3-70B-Instruct-bnb-4bit\",\n        \"unsloth/Llama-3.3-70B-Instruct\",\n        \"meta-llama/Llama-3.3-70B-Instruct\",\n    ),\n    \"gemma\": (\n        \"unsloth/gemma-7b-it-bnb-4bit\",\n        \"unsloth/gemma-7b-it\",\n        \"google/gemma-7b-it\",\n        \"google/gemma-2b-it\",\n        \"unsloth/gemma-1.1-2b-it-bnb-4bit\",\n        \"unsloth/gemma-1.1-2b-it\",\n        \"google/gemma-1.1-2b-it\",\n        \"unsloth/gemma-1.1-7b-it-bnb-4bit\",\n        \"unsloth/gemma-1.1-7b-it\",\n        \"google/gemma-1.1-7b-it\",\n    ),\n    \"gemma2\": (\n        
\"unsloth/gemma-2-9b-it-bnb-4bit\",\n        \"unsloth/gemma-2-9b-it\",\n        \"google/gemma-2-9b-it\",\n        \"unsloth/gemma-2-27b-it-bnb-4bit\",\n        \"unsloth/gemma-2-27b-it\",\n        \"google/gemma-2-27b-it\",\n        \"unsloth/gemma-2-2b-it-bnb-4bit\",\n        \"unsloth/gemma-2-2b-it\",\n        \"google/gemma-2-2b-it\",\n    ),\n    \"gemma-3\": (\n        \"unsloth/gemma-3-1b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-1b-it\",\n        \"google/gemma-3-1b-it\",\n        \"unsloth/gemma-3-1b-it-bnb-4bit\",\n        \"unsloth/gemma-3-4b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-4b-it\",\n        \"google/gemma-3-4b-it\",\n        \"unsloth/gemma-3-4b-it-bnb-4bit\",\n        \"unsloth/gemma-3-12b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-12b-it\",\n        \"google/gemma-3-12b-it\",\n        \"unsloth/gemma-3-12b-it-bnb-4bit\",\n        \"unsloth/gemma-3-27b-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-27b-it\",\n        \"google/gemma-3-27b-it\",\n        \"unsloth/gemma-3-27b-it-bnb-4bit\",\n        \"unsloth/medgemma-4b-it-unsloth-bnb-4bit\",\n        \"unsloth/medgemma-4b-it\",\n        \"google/medgemma-4b-it\",\n        \"unsloth/medgemma-4b-it-bnb-4bit\",\n        \"unsloth/medgemma-27b-text-it-unsloth-bnb-4bit\",\n        \"unsloth/medgemma-27b-text-it\",\n        \"google/medgemma-27b-text-it\",\n        \"unsloth/medgemma-27b-text-it-bnb-4bit\",\n    ),\n    \"gemma3n\": (\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E4B-it\",\n        \"google/gemma-3n-E4B-it\",\n        \"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3n-E2B-it\",\n        \"google/gemma-3n-E2B-it\",\n        \"unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit\",\n    ),\n    \"gemma3-270m\": (\n        \"unsloth/gemma-3-270m-it-unsloth-bnb-4bit\",\n        \"unsloth/gemma-3-270m-it\",\n        \"google/gemma-3-270m-it\",\n        \"unsloth/gemma-3-270m-it-bnb-4bit\",\n    ),\n    \"qwen-25\": (\n        \"unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-0.5B-Instruct\",\n        \"unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-3B-Instruct\",\n        \"Qwen/Qwen2.5-3B-Instruct\",\n        \"unsloth/Qwen2.5-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-7B-Instruct\",\n        \"Qwen/Qwen2.5-7B-Instruct\",\n        \"unsloth/Qwen2.5-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-14B-Instruct\",\n        \"Qwen/Qwen2.5-14B-Instruct\",\n        \"unsloth/Qwen2.5-14B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-32B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-32B-Instruct\",\n        \"Qwen/Qwen2.5-32B-Instruct\",\n        \"unsloth/Qwen2.5-72B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-72B-Instruct\",\n        \"Qwen/Qwen2.5-72B-Instruct\",\n        \"unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Math-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit\",\n        
\"unsloth/Qwen2.5-Math-7B-Instruct\",\n        \"Qwen/Qwen2.5-Math-7B-Instruct\",\n        \"unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Math-72B-Instruct\",\n        \"Qwen/Qwen2.5-Math-72B-Instruct\",\n    ),\n    \"qwen-25-coder\": (\n        \"unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-0.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-0.5B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-1.5B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-3B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-3B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-7B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-14B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n        \"unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-Coder-32B-Instruct\",\n        \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n    ),\n    \"qwen-25-vl\": (\n        \"unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct\",\n        \"Qwen/Qwen2.5-VL-3B-Instruct\",\n        \"unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct\",\n        \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        \"unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct\",\n        \"Qwen/Qwen2.5-VL-32B-Instruct\",\n        \"unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct\",\n        \"Qwen/Qwen2.5-VL-72B-Instruct\",\n        \"unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit\",\n    ),\n    \"openthinker\": (\n        \"unsloth/OpenThinker-7B-unsloth-bnb-4bit\",\n        \"unsloth/OpenThinker-7B\",\n        \"open-thoughts/OpenThinker-7B\",\n        \"unsloth/OpenThinker-7B-bnb-4bit\",\n    ),\n    \"qwen-2\": (\n        \"unsloth/Qwen2-0.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2-0.5B-Instruct\",\n        \"Qwen/Qwen2-0.5B-Instruct\",\n        \"unsloth/Qwen2-1.5B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2-1.5B-Instruct\",\n        \"Qwen/Qwen2-1.5B-Instruct\",\n        \"unsloth/Qwen2-7B-Instruct-bnb-4bit\",\n        \"unsloth/Qwen2-7B-Instruct\",\n        \"Qwen/Qwen2-7B-Instruct\",\n        \"unsloth/Qwen2-70B-Instruct-bnb-4bit\",\n        \"Qwen/Qwen2-70B-Instruct\",\n    ),\n    \"qwen3\": (\n        \"unsloth/Qwen3-0.6B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-0.6B\",\n        \"Qwen/Qwen3-0.6B\",\n        \"unsloth/Qwen3-0.6B-bnb-4bit\",\n        \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-1.7B\",\n        \"Qwen/Qwen3-1.7B\",\n        \"unsloth/Qwen3-1.7B-bnb-4bit\",\n        \"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B\",\n        \"Qwen/Qwen3-4B\",\n        \"unsloth/Qwen3-4B-bnb-4bit\",\n        \"unsloth/Qwen3-8B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-8B\",\n        \"Qwen/Qwen3-8B\",\n        \"unsloth/Qwen3-8B-bnb-4bit\",\n        \"unsloth/Qwen3-14B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-14B\",\n        \"Qwen/Qwen3-14B\",\n        \"unsloth/Qwen3-14B-bnb-4bit\",\n        
\"unsloth/Qwen3-32B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-32B\",\n        \"Qwen/Qwen3-32B\",\n        \"unsloth/Qwen3-32B-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B\",\n        \"Qwen/Qwen3-30B-A3B\",\n        \"unsloth/Qwen3-30B-A3B-bnb-4bit\",\n    ),\n    \"qwen3-instruct\": (\n        \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Instruct-2507\",\n        \"Qwen/Qwen3-4B-Instruct-2507\",\n        \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-Instruct-2507\",\n        \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n        \"unsloth/Qwen3-Coder-30B-A3B-Instruct\",\n        \"Qwen/Qwen3-Coder-30B-A3B-Instruct\",\n        \"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Instruct-2507\",\n        \"Qwen/Qwen3-4B-Instruct-2507\",\n        \"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit\",\n    ),\n    \"qwen3-thinking\": (\n        \"unsloth/QwQ-32B-Preview-bnb-4bit\",\n        \"unsloth/QwQ-32B-Preview\",\n        \"Qwen/QwQ-32B-Preview\",\n        \"unsloth/QwQ-32B-unsloth-bnb-4bit\",\n        \"unsloth/QwQ-32B\",\n        \"Qwen/QwQ-32B\",\n        \"unsloth/QwQ-32B-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit\",\n        \"unsloth/Qwen3-4B-Thinking-2507\",\n        \"Qwen/Qwen3-4B-Thinking-2507\",\n        \"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit\",\n        \"unsloth/Qwen3-30B-A3B-Thinking-2507\",\n        \"Qwen/Qwen3-30B-A3B-Thinking-2507\",\n    ),\n    \"zephyr\": (\n        \"unsloth/zephyr-sft-bnb-4bit\",\n        \"unsloth/zephyr-sft\",\n        \"HuggingFaceH4/mistral-7b-sft-beta\",\n    ),\n    \"chatml\": (\n        \"unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit\",\n        \"unsloth/Hermes-2-Pro-Mistral-7B\",\n        \"NousResearch/Hermes-2-Pro-Mistral-7B\",\n        \"unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit\",\n        \"unsloth/OpenHermes-2.5-Mistral-7B\",\n        \"teknium/OpenHermes-2.5-Mistral-7B\",\n    ),\n    \"gpt-oss\": (\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-20b\",\n        \"openai/gpt-oss-20b\",\n        \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n        \"unsloth/gpt-oss-120b\",\n        \"openai/gpt-oss-120b\",\n        \"unsloth/gpt-oss-120b-unsloth-bnb-4bit\",\n    ),\n    \"starling\": (\n        \"unsloth/Starling-LM-7B-beta-bnb-4bit\",\n        \"unsloth/Starling-LM-7B-beta\",\n        \"Nexusflow/Starling-LM-7B-beta\",\n    ),\n    \"yi-chat\": (\n        \"unsloth/yi-34b-chat-bnb-4bit\",\n        \"01-ai/Yi-6B-Chat\",\n        \"01-ai/Yi-34B-Chat\",\n    ),\n    \"granite-32\": (\n        \"unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit\",\n        \"unsloth/granite-3.2-2b-instruct\",\n        \"ibm-granite/granite-3.2-2b-instruct\",\n        \"unsloth/granite-3.2-2b-instruct-bnb-4bit\",\n        \"unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit\",\n        \"unsloth/granite-3.2-8b-instruct\",\n        \"ibm-granite/granite-3.2-8b-instruct\",\n        \"unsloth/granite-3.2-8b-instruct-bnb-4bit\",\n    ),\n    \"granite-32-vision\": (\n        \"unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit\",\n        \"unsloth/granite-vision-3.2-2b\",\n        \"ibm-granite/granite-vision-3.2-2b\",\n        \"unsloth/granite-vision-3.2-2b-bnb-4bit\",\n    ),\n}\n\nMODEL_TO_OLLAMA_TEMPLATE_MAPPER = {}\n\nfor key, values in OLLAMA_TEMPLATE_TO_MODEL_MAPPER.items():\n    for value in values:\n        
MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value] = key\n\n    # Get lowercased\n    lowered_key = key.lower()\n    for value in values:\n        MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value.lower()] = lowered_key\n"
  },
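The mapper above is what `save.py` consumes when exporting Ollama `Modelfile`s (it imports `OLLAMA_TEMPLATES` and `MODEL_TO_OLLAMA_TEMPLATE_MAPPER`). Below is a minimal lookup sketch; it assumes `unsloth` is importable in your environment, and the `get_ollama_template` helper is hypothetical (not part of the library), shown only to illustrate how the two tables fit together.

```python
# Hypothetical helper: resolve a Hugging Face model id to its Ollama Modelfile template.
# Mirrors how the table is built above: exact match first, then a lowercase fallback.
from unsloth.ollama_template_mappers import (
    MODEL_TO_OLLAMA_TEMPLATE_MAPPER,
    OLLAMA_TEMPLATES,
)


def get_ollama_template(model_id: str) -> str | None:
    key = (
        MODEL_TO_OLLAMA_TEMPLATE_MAPPER.get(model_id)
        or MODEL_TO_OLLAMA_TEMPLATE_MAPPER.get(model_id.lower())
    )
    return OLLAMA_TEMPLATES.get(key) if key is not None else None


template = get_ollama_template("unsloth/granite-3.2-2b-instruct")
print(template[:200] if template else "No Ollama template registered for this model")
```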
  {
    "path": "unsloth/registry/REGISTRY.md",
    "content": "## Model Registry\n\n### Structure\n```\nunsloth\n    -registry\n        __init__.py\n        registry.py\n        _llama.py\n        _mistral.py\n        _phi.py\n        ...\n```\n\nEach model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`).\n\nWithin each model registration file, a high-level `ModelMeta` is created for each model version, with the following structure:\n```python\n@dataclass\nclass ModelMeta:\n    org: str\n    base_name: str\n    model_version: str\n    model_info_cls: type[ModelInfo]\n    model_sizes: list[str] = field(default_factory=list)\n    instruct_tags: list[str] = field(default_factory=list)\n    quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list)\n    is_multimodal: bool = False\n```\n\nEach model then instantiates a global `ModelMeta` for its specific model version, defining how the model path (e.g. `unsloth/Llama-3.1-8B-Instruct`) is constructed since each model type has a different naming convention.\n```python\nLlamaMeta_3_1 = ModelMeta(\n    org=\"meta-llama\",\n    base_name=\"Llama\",\n    instruct_tags=[None, \"Instruct\"],\n    model_version=\"3.1\",\n    model_sizes=[\"8\"],\n    model_info_cls=LlamaModelInfo,\n    is_multimodal=False,\n    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n```\n\n`LlamaModelInfo` is a subclass of `ModelInfo` that defines the model path for each model size and quant type.\n```python\nclass LlamaModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}-{size}B\"\n        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)\n```\n\nOnce these constructs are defined, the model is registered by writing a register_xx_models function.\n```python\ndef register_llama_3_1_models(include_original_model: bool = False):\n    global _IS_LLAMA_3_1_REGISTERED\n    if _IS_LLAMA_3_1_REGISTERED:\n        return\n    _register_models(LlamaMeta_3_1, include_original_model=include_original_model)\n    _IS_LLAMA_3_1_REGISTERED = True\n```\n\n`_register_models` is a helper function that registers the model with the registry.  
The global `_IS_XX_REGISTERED` is used to prevent duplicate registration.\n\nOnce a model is registered, `registry.registry.MODEL_REGISTRY` is updated with the model info and can be searched with `registry.search_models`.\n\n### Tests\n\nThe `tests/test_model_registry.py` file contains tests for the model registry.\n\nAlso, each model registration file is an executable module that checks that all registered models are available on `huggingface_hub`.\n```bash\npython -m unsloth.registry._llama\n```\n\nPrints the following (abridged) output:\n```bash\n✓ unsloth/Llama-3.1-8B\n✓ unsloth/Llama-3.1-8B-bnb-4bit\n✓ unsloth/Llama-3.1-8B-unsloth-bnb-4bit\n✓ meta-llama/Llama-3.1-8B\n✓ unsloth/Llama-3.1-8B-Instruct\n✓ unsloth/Llama-3.1-8B-Instruct-bnb-4bit\n✓ unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit\n✓ meta-llama/Llama-3.1-8B-Instruct\n✓ unsloth/Llama-3.2-1B\n✓ unsloth/Llama-3.2-1B-bnb-4bit\n✓ unsloth/Llama-3.2-1B-unsloth-bnb-4bit\n✓ meta-llama/Llama-3.2-1B\n...\n```\n\n### TODO\n- Model Collections\n    - [x] Gemma3\n    - [ ] Llama3.1\n    - [x] Llama3.2\n    - [x] MistralSmall\n    - [x] Qwen2.5\n    - [x] Qwen2.5-VL\n    - [ ] Qwen2.5 Coder\n    - [x] QwenQwQ-32B\n    - [x] Deepseek v3\n    - [x] Deepseek R1\n    - [x] Phi-4\n    - [ ] Unsloth 4-bit Dynamic Quants\n    - [ ] Vision/multimodal models\n- Sync model uploads with registry\n- Add utility methods for tracking model stats"
  },
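To make the registration pattern described in `REGISTRY.md` concrete, here is a sketch of adding a new, purely illustrative model family following the same steps; `Example`, `example-org`, and the module it would live in (e.g. `registry/_example.py`) are placeholders, not real models.

```python
# Illustrative registration module following the pattern in REGISTRY.md.
# "Example" / "example-org" are placeholders, not real model repos.
from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models

_IS_EXAMPLE_REGISTERED = False


class ExampleModelInfo(ModelInfo):
    @classmethod
    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
        # Produces e.g. "Example-1.0-7B-Instruct-unsloth-bnb-4bit"
        key = f"{base_name}-{version}-{size}B"
        return super().construct_model_name(
            base_name, version, size, quant_type, instruct_tag, key
        )


ExampleMeta = ModelMeta(
    org = "example-org",
    base_name = "Example",
    instruct_tags = [None, "Instruct"],
    model_version = "1.0",
    model_sizes = ["7"],
    model_info_cls = ExampleModelInfo,
    is_multimodal = False,
    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)


def register_example_models(include_original_model: bool = False):
    global _IS_EXAMPLE_REGISTERED
    if _IS_EXAMPLE_REGISTERED:
        return
    _register_models(ExampleMeta, include_original_model = include_original_model)
    _IS_EXAMPLE_REGISTERED = True
```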
  {
    "path": "unsloth/registry/__init__.py",
    "content": "from ._deepseek import register_deepseek_models as _register_deepseek_models\nfrom ._gemma import register_gemma_models as _register_gemma_models\nfrom ._llama import register_llama_models as _register_llama_models\nfrom ._mistral import register_mistral_models as _register_mistral_models\nfrom ._phi import register_phi_models as _register_phi_models\nfrom ._qwen import register_qwen_models as _register_qwen_models\nfrom .registry import MODEL_REGISTRY, ModelInfo, QuantType\n\n_ARE_MODELS_REGISTERED = False\n\n\ndef register_models():\n    global _ARE_MODELS_REGISTERED\n\n    if _ARE_MODELS_REGISTERED:\n        return\n    _register_deepseek_models()\n    _register_gemma_models()\n    _register_llama_models()\n    _register_mistral_models()\n    _register_phi_models()\n    _register_qwen_models()\n\n    _ARE_MODELS_REGISTERED = True\n\n\ndef search_models(\n    org: str = None,\n    base_name: str = None,\n    version: str = None,\n    size: str = None,\n    quant_types: list[QuantType] = None,\n    search_pattern: str = None,\n) -> list[ModelInfo]:\n    \"\"\"\n    Get model info from the registry.\n\n    See registry.ModelInfo for more fields.\n\n    If search_pattern is provided, the full model path will be matched against the pattern, where the model path is the model_id on huggingface hub.\n\n    \"\"\"\n    if not _ARE_MODELS_REGISTERED:\n        register_models()\n\n    model_infos = MODEL_REGISTRY.values()\n    if org:\n        model_infos = [\n            model_info for model_info in model_infos if model_info.org == org\n        ]\n    if base_name:\n        model_infos = [\n            model_info\n            for model_info in model_infos\n            if model_info.base_name == base_name\n        ]\n    if version:\n        model_infos = [\n            model_info for model_info in model_infos if model_info.version == version\n        ]\n    if size:\n        model_infos = [\n            model_info for model_info in model_infos if model_info.size == size\n        ]\n    if quant_types:\n        model_infos = [\n            model_info\n            for model_info in model_infos\n            if any(model_info.quant_type == quant_type for quant_type in quant_types)\n        ]\n    if search_pattern:\n        model_infos = [\n            model_info\n            for model_info in model_infos\n            if search_pattern in model_info.model_path\n        ]\n\n    return model_infos\n"
  },
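A short usage sketch for `search_models`, assuming `unsloth` is installed. The filters below mirror the Llama 3.1 entries the registry itself registers; the printed paths are examples of what to expect rather than an exhaustive list.

```python
# search_models lazily calls register_models() on first use, so no setup is needed.
from unsloth.registry import QuantType, search_models

# All Unsloth dynamic 4-bit quants registered for version 3.1 at the 8B size.
for info in search_models(org = "unsloth", version = "3.1", size = "8",
                          quant_types = [QuantType.UNSLOTH]):
    print(info.model_path)  # e.g. unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit

# Or filter by a substring of the Hugging Face model id.
for info in search_models(search_pattern = "Instruct-unsloth-bnb-4bit"):
    print(info.model_path)
```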
  {
    "path": "unsloth/registry/_deepseek.py",
    "content": "from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_DEEPSEEK_V3_REGISTERED = False\n_IS_DEEPSEEK_V3_0324_REGISTERED = False\n_IS_DEEPSEEK_R1_REGISTERED = False\n_IS_DEEPSEEK_R1_ZERO_REGISTERED = False\n_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = False\n_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = False\n\n\nclass DeepseekV3ModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-V{version}\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\nclass DeepseekR1ModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}\" if version else base_name\n        if size:\n            key = f\"{key}-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\n# Deepseek V3 Model Meta\nDeepseekV3Meta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek\",\n    instruct_tags = [None],\n    model_version = \"3\",\n    model_sizes = [\"\"],\n    model_info_cls = DeepseekV3ModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BF16],\n)\n\nDeepseekV3_0324Meta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek\",\n    instruct_tags = [None],\n    model_version = \"3-0324\",\n    model_sizes = [\"\"],\n    model_info_cls = DeepseekV3ModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.GGUF],\n)\n\nDeepseekR1Meta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek-R1\",\n    instruct_tags = [None],\n    model_version = \"\",\n    model_sizes = [\"\"],\n    model_info_cls = DeepseekR1ModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],\n)\n\nDeepseekR1ZeroMeta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek-R1\",\n    instruct_tags = [None],\n    model_version = \"Zero\",\n    model_sizes = [\"\"],\n    model_info_cls = DeepseekR1ModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.GGUF],\n)\n\nDeepseekR1DistillLlamaMeta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek-R1-Distill\",\n    instruct_tags = [None],\n    model_version = \"Llama\",\n    model_sizes = [\"8\", \"70\"],\n    model_info_cls = DeepseekR1ModelInfo,\n    is_multimodal = False,\n    quant_types = {\"8\": [QuantType.UNSLOTH, QuantType.GGUF], \"70\": [QuantType.GGUF]},\n)\n\n# Deepseek R1 Distill Qwen Model Meta\nDeepseekR1DistillQwenMeta = ModelMeta(\n    org = \"deepseek-ai\",\n    base_name = \"DeepSeek-R1-Distill\",\n    instruct_tags = [None],\n    model_version = \"Qwen\",\n    model_sizes = [\"1.5\", \"7\", \"14\", \"32\"],\n    model_info_cls = DeepseekR1ModelInfo,\n    is_multimodal = False,\n    quant_types = {\n        \"1.5\": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],\n        \"7\": [QuantType.UNSLOTH, QuantType.BNB],\n        \"14\": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],\n        \"32\": [QuantType.GGUF, QuantType.BNB],\n    },\n)\n\n\ndef register_deepseek_v3_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_V3_REGISTERED\n    if _IS_DEEPSEEK_V3_REGISTERED:\n        return\n    _register_models(DeepseekV3Meta, 
include_original_model = include_original_model)\n    _IS_DEEPSEEK_V3_REGISTERED = True\n\n\ndef register_deepseek_v3_0324_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_V3_0324_REGISTERED\n    if _IS_DEEPSEEK_V3_0324_REGISTERED:\n        return\n    _register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)\n    _IS_DEEPSEEK_V3_0324_REGISTERED = True\n\n\ndef register_deepseek_r1_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_R1_REGISTERED\n    if _IS_DEEPSEEK_R1_REGISTERED:\n        return\n    _register_models(DeepseekR1Meta, include_original_model = include_original_model)\n    _IS_DEEPSEEK_R1_REGISTERED = True\n\n\ndef register_deepseek_r1_zero_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_R1_ZERO_REGISTERED\n    if _IS_DEEPSEEK_R1_ZERO_REGISTERED:\n        return\n    _register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)\n    _IS_DEEPSEEK_R1_ZERO_REGISTERED = True\n\n\ndef register_deepseek_r1_distill_llama_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED\n    if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:\n        return\n    _register_models(\n        DeepseekR1DistillLlamaMeta, include_original_model = include_original_model\n    )\n    _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True\n\n\ndef register_deepseek_r1_distill_qwen_models(include_original_model: bool = False):\n    global _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED\n    if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:\n        return\n    _register_models(\n        DeepseekR1DistillQwenMeta, include_original_model = include_original_model\n    )\n    _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True\n\n\ndef register_deepseek_models(include_original_model: bool = False):\n    register_deepseek_v3_models(include_original_model = include_original_model)\n    register_deepseek_v3_0324_models(include_original_model = include_original_model)\n    register_deepseek_r1_models(include_original_model = include_original_model)\n    register_deepseek_r1_zero_models(include_original_model = include_original_model)\n    register_deepseek_r1_distill_llama_models(\n        include_original_model = include_original_model\n    )\n    register_deepseek_r1_distill_qwen_models(\n        include_original_model = include_original_model\n    )\n\n\ndef _list_deepseek_r1_distill_models():\n    from unsloth.utils.hf_hub import ModelInfo as HfModelInfo\n    from unsloth.utils.hf_hub import list_models\n\n    models: list[HfModelInfo] = list_models(\n        author = \"unsloth\", search = \"Distill\", limit = 1000\n    )\n    distill_models = []\n    for model in models:\n        model_id = model.id\n        model_name = model_id.split(\"/\")[-1]\n        # parse out only the version\n        version = model_name.removeprefix(\"DeepSeek-R1-Distill-\")\n        distill_models.append(version)\n\n    return distill_models\n\n\nregister_deepseek_models(include_original_model = True)\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_deepseek_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n    # distill_models = _list_deepseek_r1_distill_models()\n    # 
for model in sorted(distill_models):\n    #     if \"qwen\" in model.lower():\n    #         print(model)\n"
  },
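Because `DeepseekR1ModelInfo` drops empty version and size components, the names it produces can be sanity-checked offline, without touching the Hugging Face Hub. A small sketch, assuming `unsloth` is importable:

```python
# Offline check of the names the DeepSeek ModelInfo subclasses compose.
from unsloth.registry._deepseek import DeepseekR1ModelInfo, DeepseekV3ModelInfo
from unsloth.registry.registry import QuantType

print(DeepseekR1ModelInfo.construct_model_name(
    base_name = "DeepSeek-R1-Distill", version = "Llama", size = "8",
    quant_type = QuantType.UNSLOTH, instruct_tag = None,
))  # -> DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit

print(DeepseekV3ModelInfo.construct_model_name(
    base_name = "DeepSeek", version = "3-0324", size = "",
    quant_type = QuantType.GGUF, instruct_tag = None,
))  # -> DeepSeek-V3-0324-GGUF
```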
  {
    "path": "unsloth/registry/_gemma.py",
    "content": "from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_GEMMA_3_BASE_REGISTERED = False\n_IS_GEMMA_3_INSTRUCT_REGISTERED = False\n\n\nclass GemmaModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\n# Gemma3 Base Model Meta\nGemmaMeta3Base = ModelMeta(\n    org = \"google\",\n    base_name = \"gemma\",\n    instruct_tags = [\"pt\"],  # pt = base\n    model_version = \"3\",\n    model_sizes = [\"1\", \"4\", \"12\", \"27\"],\n    model_info_cls = GemmaModelInfo,\n    is_multimodal = True,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Gemma3 Instruct Model Meta\nGemmaMeta3Instruct = ModelMeta(\n    org = \"google\",\n    base_name = \"gemma\",\n    instruct_tags = [\"it\"],  # it = instruction tuned\n    model_version = \"3\",\n    model_sizes = [\"1\", \"4\", \"12\", \"27\"],\n    model_info_cls = GemmaModelInfo,\n    is_multimodal = True,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],\n)\n\n\ndef register_gemma_3_base_models(include_original_model: bool = False):\n    global _IS_GEMMA_3_BASE_REGISTERED\n    if _IS_GEMMA_3_BASE_REGISTERED:\n        return\n    _register_models(GemmaMeta3Base, include_original_model = include_original_model)\n    _IS_GEMMA_3_BASE_REGISTERED = True\n\n\ndef register_gemma_3_instruct_models(include_original_model: bool = False):\n    global _IS_GEMMA_3_INSTRUCT_REGISTERED\n    if _IS_GEMMA_3_INSTRUCT_REGISTERED:\n        return\n    _register_models(GemmaMeta3Instruct, include_original_model = include_original_model)\n    _IS_GEMMA_3_INSTRUCT_REGISTERED = True\n\n\ndef register_gemma_models(include_original_model: bool = False):\n    register_gemma_3_base_models(include_original_model = include_original_model)\n    register_gemma_3_instruct_models(include_original_model = include_original_model)\n\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_gemma_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n"
  },
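For reference, the name `GemmaModelInfo` constructs can be checked offline in the same way; the output comment reflects exactly what the classmethod returns (assuming `unsloth` is importable).

```python
from unsloth.registry._gemma import GemmaModelInfo
from unsloth.registry.registry import QuantType

print(GemmaModelInfo.construct_model_name(
    base_name = "gemma", version = "3", size = "4",
    quant_type = QuantType.UNSLOTH, instruct_tag = "it",
))  # -> gemma-3-4B-it-unsloth-bnb-4bit
```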
  {
    "path": "unsloth/registry/_llama.py",
    "content": "from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_LLAMA_3_1_REGISTERED = False\n_IS_LLAMA_3_2_REGISTERED = False\n_IS_LLAMA_3_2_VISION_REGISTERED = False\n\n\nclass LlamaModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\nclass LlamaVisionModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}-{size}B-Vision\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\n# Llama 3.1\nLlamaMeta_3_1 = ModelMeta(\n    org = \"meta-llama\",\n    base_name = \"Llama\",\n    instruct_tags = [None, \"Instruct\"],\n    model_version = \"3.1\",\n    model_sizes = [\"8\"],\n    model_info_cls = LlamaModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Llama 3.2 Base Models\nLlamaMeta_3_2_Base = ModelMeta(\n    org = \"meta-llama\",\n    base_name = \"Llama\",\n    instruct_tags = [None],\n    model_version = \"3.2\",\n    model_sizes = [\"1\", \"3\"],\n    model_info_cls = LlamaModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Llama 3.2 Instruction Tuned Models\nLlamaMeta_3_2_Instruct = ModelMeta(\n    org = \"meta-llama\",\n    base_name = \"Llama\",\n    instruct_tags = [\"Instruct\"],\n    model_version = \"3.2\",\n    model_sizes = [\"1\", \"3\"],\n    model_info_cls = LlamaModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],\n)\n\n# Llama 3.2 Vision\nLlamaMeta_3_2_Vision = ModelMeta(\n    org = \"meta-llama\",\n    base_name = \"Llama\",\n    instruct_tags = [None, \"Instruct\"],\n    model_version = \"3.2\",\n    model_sizes = [\"11\", \"90\"],\n    model_info_cls = LlamaVisionModelInfo,\n    is_multimodal = True,\n    quant_types = {\n        \"11\": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n        \"90\": [QuantType.NONE],\n    },\n)\n\n\ndef register_llama_3_1_models(include_original_model: bool = False):\n    global _IS_LLAMA_3_1_REGISTERED\n    if _IS_LLAMA_3_1_REGISTERED:\n        return\n    _register_models(LlamaMeta_3_1, include_original_model = include_original_model)\n    _IS_LLAMA_3_1_REGISTERED = True\n\n\ndef register_llama_3_2_models(include_original_model: bool = False):\n    global _IS_LLAMA_3_2_REGISTERED\n    if _IS_LLAMA_3_2_REGISTERED:\n        return\n    _register_models(LlamaMeta_3_2_Base, include_original_model = include_original_model)\n    _register_models(\n        LlamaMeta_3_2_Instruct, include_original_model = include_original_model\n    )\n    _IS_LLAMA_3_2_REGISTERED = True\n\n\ndef register_llama_3_2_vision_models(include_original_model: bool = False):\n    global _IS_LLAMA_3_2_VISION_REGISTERED\n    if _IS_LLAMA_3_2_VISION_REGISTERED:\n        return\n    _register_models(\n        LlamaMeta_3_2_Vision, include_original_model = include_original_model\n    )\n    _IS_LLAMA_3_2_VISION_REGISTERED = True\n\n\ndef register_llama_models(include_original_model: bool = False):\n    register_llama_3_1_models(include_original_model = include_original_model)\n    
register_llama_3_2_models(include_original_model = include_original_model)\n    register_llama_3_2_vision_models(include_original_model = include_original_model)\n\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_llama_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n"
  },
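Text and vision Llama 3.2 variants use different `ModelInfo` subclasses, which changes where `-Vision` lands in the name. An offline sketch, assuming `unsloth` is importable:

```python
from unsloth.registry._llama import LlamaModelInfo, LlamaVisionModelInfo
from unsloth.registry.registry import QuantType

print(LlamaModelInfo.construct_model_name(
    base_name = "Llama", version = "3.2", size = "3",
    quant_type = QuantType.BNB, instruct_tag = "Instruct",
))  # -> Llama-3.2-3B-Instruct-bnb-4bit

print(LlamaVisionModelInfo.construct_model_name(
    base_name = "Llama", version = "3.2", size = "11",
    quant_type = QuantType.NONE, instruct_tag = "Instruct",
))  # -> Llama-3.2-11B-Vision-Instruct
```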
  {
    "path": "unsloth/registry/_mistral.py",
    "content": "import copy\n\nfrom unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_MISTRAL_SMALL_REGISTERED = False\n\n_MISTRAL_SMALL_03_25_VERSION = \"2503\"\n_MISTRAL_SMALL_01_25_VERSION = \"2501\"\n_MISTRAL_SMALL_09_24_VERSION = \"2409\"  # Not uploaded to unsloth\n\n\nclass MistralSmallModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        if version == _MISTRAL_SMALL_03_25_VERSION:\n            key = f\"{base_name}-3.1-{size}B-{instruct_tag}\"\n        else:\n            key = f\"{base_name}-{size}B-{instruct_tag}\"\n        key += f\"-{version}\"\n        key = cls.append_quant_type(key, quant_type)\n\n        return key\n\n\nMistralSmall_2503_Base_Meta = ModelMeta(\n    org = \"mistralai\",\n    base_name = \"Mistral-Small\",\n    instruct_tags = [\"Base\"],\n    model_version = _MISTRAL_SMALL_03_25_VERSION,\n    model_sizes = [\"24\"],\n    model_info_cls = MistralSmallModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],\n)\n\nMistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)\nMistralSmall_2503_Instruct_Meta.instruct_tags = [\"Instruct\"]\nMistralSmall_2503_Instruct_Meta.quant_types = [\n    QuantType.NONE,\n    QuantType.UNSLOTH,\n    QuantType.BNB,\n    QuantType.GGUF,\n]\n\nMistralSmall_2501_Base_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)\nMistralSmall_2501_Base_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION\n\nMistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta)\nMistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION\n\n\ndef register_mistral_small_models(include_original_model: bool = False):\n    global _IS_MISTRAL_SMALL_REGISTERED\n    if _IS_MISTRAL_SMALL_REGISTERED:\n        return\n    _register_models(\n        MistralSmall_2503_Base_Meta, include_original_model = include_original_model\n    )\n    _register_models(\n        MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model\n    )\n    _register_models(\n        MistralSmall_2501_Base_Meta, include_original_model = include_original_model\n    )\n    _register_models(\n        MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model\n    )\n\n    _IS_MISTRAL_SMALL_REGISTERED = True\n\n\ndef register_mistral_models(include_original_model: bool = False):\n    register_mistral_small_models(include_original_model = include_original_model)\n\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_mistral_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n"
  },
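`MistralSmallModelInfo` is the one subclass here that does not delegate to the base `construct_model_name`: the instruct tag sits between the size and the release date, and the 2503 release additionally inserts `3.1`. A quick offline check, assuming `unsloth` is importable:

```python
from unsloth.registry._mistral import MistralSmallModelInfo
from unsloth.registry.registry import QuantType

print(MistralSmallModelInfo.construct_model_name(
    base_name = "Mistral-Small", version = "2503", size = "24",
    quant_type = QuantType.UNSLOTH, instruct_tag = "Instruct",
))  # -> Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit

print(MistralSmallModelInfo.construct_model_name(
    base_name = "Mistral-Small", version = "2501", size = "24",
    quant_type = QuantType.NONE, instruct_tag = "Instruct",
))  # -> Mistral-Small-24B-Instruct-2501
```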
  {
    "path": "unsloth/registry/_phi.py",
    "content": "from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_PHI_4_REGISTERED = False\n_IS_PHI_4_INSTRUCT_REGISTERED = False\n\n\nclass PhiModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{version}\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\n# Phi Model Meta\nPhiMeta4 = ModelMeta(\n    org = \"microsoft\",\n    base_name = \"phi\",\n    instruct_tags = [None],\n    model_version = \"4\",\n    model_sizes = [\"1\"],  # Assuming only one size\n    model_info_cls = PhiModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Phi Instruct Model Meta\nPhiInstructMeta4 = ModelMeta(\n    org = \"microsoft\",\n    base_name = \"phi\",\n    instruct_tags = [\"mini-instruct\"],\n    model_version = \"4\",\n    model_sizes = [\"1\"],  # Assuming only one size\n    model_info_cls = PhiModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],\n)\n\n\ndef register_phi_4_models(include_original_model: bool = False):\n    global _IS_PHI_4_REGISTERED\n    if _IS_PHI_4_REGISTERED:\n        return\n    _register_models(PhiMeta4, include_original_model = include_original_model)\n    _IS_PHI_4_REGISTERED = True\n\n\ndef register_phi_4_instruct_models(include_original_model: bool = False):\n    global _IS_PHI_4_INSTRUCT_REGISTERED\n    if _IS_PHI_4_INSTRUCT_REGISTERED:\n        return\n    _register_models(PhiInstructMeta4, include_original_model = include_original_model)\n    _IS_PHI_4_INSTRUCT_REGISTERED = True\n\n\ndef register_phi_models(include_original_model: bool = False):\n    register_phi_4_models(include_original_model = include_original_model)\n    register_phi_4_instruct_models(include_original_model = include_original_model)\n\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_phi_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n"
  },
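`PhiModelInfo` never uses the `size` argument in the key (hence the `# Assuming only one size` placeholder above), so the version and instruct tag fully determine the name. Offline sketch, assuming `unsloth` is importable:

```python
from unsloth.registry._phi import PhiModelInfo
from unsloth.registry.registry import QuantType

print(PhiModelInfo.construct_model_name(
    base_name = "phi", version = "4", size = "1",
    quant_type = QuantType.GGUF, instruct_tag = "mini-instruct",
))  # -> phi-4-mini-instruct-GGUF

print(PhiModelInfo.construct_model_name(
    base_name = "phi", version = "4", size = "1",
    quant_type = QuantType.NONE, instruct_tag = None,
))  # -> phi-4
```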
  {
    "path": "unsloth/registry/_qwen.py",
    "content": "from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models\n\n_IS_QWEN_2_5_REGISTERED = False\n_IS_QWEN_2_5_VL_REGISTERED = False\n_IS_QWEN_QWQ_REGISTERED = False\n\n\nclass QwenModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}{version}-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\nclass QwenVLModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}{version}-VL-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\nclass QwenQwQModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{size}B\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\nclass QwenQVQPreviewModelInfo(ModelInfo):\n    @classmethod\n    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):\n        key = f\"{base_name}-{size}B-Preview\"\n        return super().construct_model_name(\n            base_name, version, size, quant_type, instruct_tag, key\n        )\n\n\n# Qwen2.5 Model Meta\nQwen_2_5_Meta = ModelMeta(\n    org = \"Qwen\",\n    base_name = \"Qwen\",\n    instruct_tags = [None, \"Instruct\"],\n    model_version = \"2.5\",\n    model_sizes = [\"3\", \"7\"],\n    model_info_cls = QwenModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Qwen2.5 VL Model Meta\nQwen_2_5_VLMeta = ModelMeta(\n    org = \"Qwen\",\n    base_name = \"Qwen\",\n    instruct_tags = [\"Instruct\"],  # No base, only instruction tuned\n    model_version = \"2.5\",\n    model_sizes = [\"3\", \"7\", \"32\", \"72\"],\n    model_info_cls = QwenVLModelInfo,\n    is_multimodal = True,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],\n)\n\n# Qwen QwQ Model Meta\nQwenQwQMeta = ModelMeta(\n    org = \"Qwen\",\n    base_name = \"QwQ\",\n    instruct_tags = [None],\n    model_version = \"\",\n    model_sizes = [\"32\"],\n    model_info_cls = QwenQwQModelInfo,\n    is_multimodal = False,\n    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],\n)\n\n# Qwen QVQ Preview Model Meta\nQwenQVQPreviewMeta = ModelMeta(\n    org = \"Qwen\",\n    base_name = \"QVQ\",\n    instruct_tags = [None],\n    model_version = \"\",\n    model_sizes = [\"72\"],\n    model_info_cls = QwenQVQPreviewModelInfo,\n    is_multimodal = True,\n    quant_types = [QuantType.NONE, QuantType.BNB],\n)\n\n\ndef register_qwen_2_5_models(include_original_model: bool = False):\n    global _IS_QWEN_2_5_REGISTERED\n    if _IS_QWEN_2_5_REGISTERED:\n        return\n    _register_models(Qwen_2_5_Meta, include_original_model = include_original_model)\n    _IS_QWEN_2_5_REGISTERED = True\n\n\ndef register_qwen_2_5_vl_models(include_original_model: bool = False):\n    global _IS_QWEN_2_5_VL_REGISTERED\n    if _IS_QWEN_2_5_VL_REGISTERED:\n        return\n    _register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)\n    _IS_QWEN_2_5_VL_REGISTERED = True\n\n\ndef register_qwen_qwq_models(include_original_model: bool = False):\n    
global _IS_QWEN_QWQ_REGISTERED\n    if _IS_QWEN_QWQ_REGISTERED:\n        return\n    _register_models(QwenQwQMeta, include_original_model = include_original_model)\n    _register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)\n    _IS_QWEN_QWQ_REGISTERED = True\n\n\ndef register_qwen_models(include_original_model: bool = False):\n    register_qwen_2_5_models(include_original_model = include_original_model)\n    register_qwen_2_5_vl_models(include_original_model = include_original_model)\n    register_qwen_qwq_models(include_original_model = include_original_model)\n\n\nif __name__ == \"__main__\":\n    from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info\n\n    MODEL_REGISTRY.clear()\n\n    register_qwen_models(include_original_model = True)\n\n    for model_id, model_info in MODEL_REGISTRY.items():\n        model_info = _check_model_info(model_id)\n        if model_info is None:\n            print(f\"\\u2718 {model_id}\")\n        else:\n            print(f\"\\u2713 {model_id}\")\n"
  },
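The Qwen subclasses differ only in where the version, the `-VL-` infix, or the `-Preview` suffix is spliced into the key. Offline sketch, assuming `unsloth` is importable:

```python
from unsloth.registry._qwen import QwenModelInfo, QwenVLModelInfo
from unsloth.registry.registry import QuantType

print(QwenModelInfo.construct_model_name(
    base_name = "Qwen", version = "2.5", size = "7",
    quant_type = QuantType.UNSLOTH, instruct_tag = "Instruct",
))  # -> Qwen2.5-7B-Instruct-unsloth-bnb-4bit

print(QwenVLModelInfo.construct_model_name(
    base_name = "Qwen", version = "2.5", size = "7",
    quant_type = QuantType.BNB, instruct_tag = "Instruct",
))  # -> Qwen2.5-VL-7B-Instruct-bnb-4bit
```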
  {
    "path": "unsloth/registry/registry.py",
    "content": "import warnings\nfrom dataclasses import dataclass, field\nfrom enum import Enum\n\n\nclass QuantType(Enum):\n    BNB = \"bnb\"\n    UNSLOTH = \"unsloth\"  # dynamic 4-bit quantization\n    GGUF = \"GGUF\"\n    NONE = \"none\"\n    BF16 = \"bf16\"  # only for Deepseek V3\n\n\n# Tags for Hugging Face model paths\nBNB_QUANTIZED_TAG = \"bnb-4bit\"\nUNSLOTH_DYNAMIC_QUANT_TAG = \"unsloth\" + \"-\" + BNB_QUANTIZED_TAG\nGGUF_TAG = \"GGUF\"\nBF16_TAG = \"bf16\"\n\nQUANT_TAG_MAP = {\n    QuantType.BNB: BNB_QUANTIZED_TAG,\n    QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG,\n    QuantType.GGUF: GGUF_TAG,\n    QuantType.NONE: None,\n    QuantType.BF16: BF16_TAG,\n}\n\n\n# NOTE: models registered with org=\"unsloth\" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH\n@dataclass\nclass ModelInfo:\n    org: str\n    base_name: str\n    version: str\n    size: int\n    name: str = None  # full model name, constructed from base_name, version, and size unless provided\n    is_multimodal: bool = False\n    instruct_tag: str = None\n    quant_type: QuantType = None\n    description: str = None\n\n    def __post_init__(self):\n        self.name = self.name or self.construct_model_name(\n            self.base_name,\n            self.version,\n            self.size,\n            self.quant_type,\n            self.instruct_tag,\n        )\n\n    @staticmethod\n    def append_instruct_tag(key: str, instruct_tag: str = None):\n        if instruct_tag:\n            key = \"-\".join([key, instruct_tag])\n        return key\n\n    @staticmethod\n    def append_quant_type(key: str, quant_type: QuantType = None):\n        if quant_type != QuantType.NONE:\n            key = \"-\".join([key, QUANT_TAG_MAP[quant_type]])\n        return key\n\n    @classmethod\n    def construct_model_name(\n        cls, base_name, version, size, quant_type, instruct_tag, key = \"\"\n    ):\n        key = cls.append_instruct_tag(key, instruct_tag)\n        key = cls.append_quant_type(key, quant_type)\n        return key\n\n    @property\n    def model_path(\n        self,\n    ) -> str:\n        return f\"{self.org}/{self.name}\"\n\n\n@dataclass\nclass ModelMeta:\n    org: str\n    base_name: str\n    model_version: str\n    model_info_cls: type[ModelInfo]\n    model_sizes: list[str] = field(default_factory = list)\n    instruct_tags: list[str] = field(default_factory = list)\n    quant_types: list[QuantType] | dict[str, list[QuantType]] = field(\n        default_factory = list\n    )\n    is_multimodal: bool = False\n\n\nMODEL_REGISTRY: dict[str, ModelInfo] = {}\n\n\ndef register_model(\n    model_info_cls: ModelInfo,\n    org: str,\n    base_name: str,\n    version: str,\n    size: int,\n    instruct_tag: str = None,\n    quant_type: QuantType = None,\n    is_multimodal: bool = False,\n    name: str = None,\n):\n    name = name or model_info_cls.construct_model_name(\n        base_name = base_name,\n        version = version,\n        size = size,\n        quant_type = quant_type,\n        instruct_tag = instruct_tag,\n    )\n    key = f\"{org}/{name}\"\n\n    if key in MODEL_REGISTRY:\n        raise ValueError(\n            f\"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}\"\n        )\n\n    MODEL_REGISTRY[key] = model_info_cls(\n        org = org,\n        base_name = base_name,\n        version = version,\n        size = size,\n        is_multimodal = is_multimodal,\n        instruct_tag = instruct_tag,\n        quant_type = quant_type,\n        name = name,\n    )\n\n\ndef 
_check_model_info(model_id: str, properties: list[str] = [\"lastModified\"]):\n    from huggingface_hub import HfApi\n    from huggingface_hub import ModelInfo as HfModelInfo\n    from huggingface_hub.utils import RepositoryNotFoundError\n\n    api = HfApi()\n\n    try:\n        model_info: HfModelInfo = api.model_info(model_id, expand = properties)\n    except Exception as e:\n        if isinstance(e, RepositoryNotFoundError):\n            warnings.warn(f\"{model_id} not found on Hugging Face\")\n            model_info = None\n        else:\n            raise e\n    return model_info\n\n\ndef _register_models(model_meta: ModelMeta, include_original_model: bool = False):\n    org = model_meta.org\n    base_name = model_meta.base_name\n    instruct_tags = model_meta.instruct_tags\n    model_version = model_meta.model_version\n    model_sizes = model_meta.model_sizes\n    is_multimodal = model_meta.is_multimodal\n    quant_types = model_meta.quant_types\n    model_info_cls = model_meta.model_info_cls\n\n    for size in model_sizes:\n        for instruct_tag in instruct_tags:\n            # Handle quant types per model size\n            if isinstance(quant_types, dict):\n                _quant_types = quant_types[size]\n            else:\n                _quant_types = quant_types\n            for quant_type in _quant_types:\n                # NOTE: models registered with org=\"unsloth\" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH\n                _org = \"unsloth\"  # unsloth models -- these are all quantized versions of the original model\n                register_model(\n                    model_info_cls = model_info_cls,\n                    org = _org,\n                    base_name = base_name,\n                    version = model_version,\n                    size = size,\n                    instruct_tag = instruct_tag,\n                    quant_type = quant_type,\n                    is_multimodal = is_multimodal,\n                )\n            # include original model from releasing organization\n            if include_original_model:\n                register_model(\n                    model_info_cls = model_info_cls,\n                    org = org,\n                    base_name = base_name,\n                    version = model_version,\n                    size = size,\n                    instruct_tag = instruct_tag,\n                    quant_type = QuantType.NONE,\n                    is_multimodal = is_multimodal,\n                )\n"
  },
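`register_model` can also be called directly, bypassing `_register_models`. The sketch below inlines a throwaway `DemoModelInfo` subclass (hypothetical, for illustration only) so it stands alone, and shows that the registry key is simply `org/name`.

```python
from unsloth.registry.registry import (
    MODEL_REGISTRY, ModelInfo, QuantType, register_model,
)


class DemoModelInfo(ModelInfo):  # hypothetical subclass, for illustration only
    @classmethod
    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
        key = f"{base_name}-{version}-{size}B"
        return super().construct_model_name(
            base_name, version, size, quant_type, instruct_tag, key
        )


register_model(
    model_info_cls = DemoModelInfo,
    org = "unsloth",
    base_name = "Demo",
    version = "1.0",
    size = "7",
    instruct_tag = "Instruct",
    quant_type = QuantType.UNSLOTH,
)

info = MODEL_REGISTRY["unsloth/Demo-1.0-7B-Instruct-unsloth-bnb-4bit"]
print(info.model_path)  # unsloth/Demo-1.0-7B-Instruct-unsloth-bnb-4bit
# Registering the same org/name key a second time raises ValueError.
```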
  {
    "path": "unsloth/save.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom unsloth_zoo.utils import Version\nfrom importlib.metadata import version as importlib_version\nfrom unsloth_zoo.hf_utils import dtype_from_config, HAS_TORCH_DTYPE\nfrom unsloth_zoo.llama_cpp import (\n    convert_to_gguf,\n    quantize_gguf,\n    use_local_gguf,\n    install_llama_cpp,\n    check_llama_cpp,\n    _download_convert_hf_to_gguf,\n)\n\n# H4: Defensive imports -- these were added in unsloth-zoo PR #526\n# and may not exist on older versions\ntry:\n    from unsloth_zoo.llama_cpp import LLAMA_CPP_DEFAULT_DIR, IS_WINDOWS\nexcept ImportError:\n    import sys\n\n    IS_WINDOWS = sys.platform == \"win32\"\n    LLAMA_CPP_DEFAULT_DIR = \"llama.cpp\"\nfrom bitsandbytes.nn import Linear4bit as Bnb_Linear4bit\nfrom peft.tuners.lora import Linear4bit as Peft_Linear4bit\nfrom peft.tuners.lora import Linear as Peft_Linear\nfrom typing import Optional, Callable, Union, List\nimport sys\nimport requests\nimport torch\nimport os\nimport shutil\nimport pickle\nimport gc\nfrom transformers.models.llama.modeling_llama import logger\nfrom .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias\nimport subprocess\nimport psutil\nimport re\nfrom transformers.models.llama.modeling_llama import logger\nfrom .tokenizer_utils import fix_sentencepiece_gguf\nfrom .models.loader_utils import get_model_name\nfrom .models._utils import _convert_torchao_model\nfrom .ollama_template_mappers import OLLAMA_TEMPLATES, MODEL_TO_OLLAMA_TEMPLATE_MAPPER\nfrom transformers import ProcessorMixin\nfrom huggingface_hub import HfApi\n\ntry:\n    from huggingface_hub import get_token\nexcept:\n    try:\n        from huggingface_hub.utils import get_token\n    except:\n        # For older versions of huggingface_hub\n        from huggingface_hub.utils._token import get_token\nfrom pathlib import Path\nfrom peft import PeftModelForCausalLM, PeftModel\n\n__all__ = [\n    \"print_quantization_methods\",\n    \"unsloth_save_model\",\n    \"save_to_gguf\",\n    \"patch_saving_functions\",\n    \"create_huggingface_repo\",\n]\n\n# llama.cpp specific targets - all takes 90s. 
Below takes 60s\nLLAMA_CPP_TARGETS = [\n    \"llama-quantize\",\n    \"llama-cli\",\n    \"llama-server\",\n]\n\n# Check environments\nkeynames = \"\\n\" + \"\\n\".join(os.environ.keys())\nIS_COLAB_ENVIRONMENT = \"\\nCOLAB_\" in keynames\nIS_KAGGLE_ENVIRONMENT = \"\\nKAGGLE_\" in keynames\nKAGGLE_TMP = \"/tmp\"\ndel keynames\n\n# Weights\nLLAMA_WEIGHTS = (\n    \"self_attn.q_proj\",\n    \"self_attn.k_proj\",\n    \"self_attn.v_proj\",\n    \"self_attn.o_proj\",\n    \"mlp.gate_proj\",\n    \"mlp.up_proj\",\n    \"mlp.down_proj\",\n)\nLLAMA_LAYERNORMS = (\n    \"input_layernorm\",\n    \"post_attention_layernorm\",\n    \"pre_feedforward_layernorm\",\n    \"post_feedforward_layernorm\",\n    \"self_attn.q_norm\",\n    \"self_attn.k_norm\",\n)\n\n# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19\n# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html\nALLOWED_QUANTS = {\n    \"not_quantized\": \"Recommended. Fast conversion. Slow inference, big files.\",\n    \"fast_quantized\": \"Recommended. Fast conversion. OK inference, OK file size.\",\n    \"quantized\": \"Recommended. Slow conversion. Fast inference, small files.\",\n    \"f32\": \"Not recommended. Retains 100% accuracy, but super slow and memory hungry.\",\n    \"bf16\": \"Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.\",\n    \"f16\": \"Float16  - Fastest conversion + retains 100% accuracy. Slow and memory hungry.\",\n    \"q8_0\": \"Fast conversion. High resource use, but generally acceptable.\",\n    \"q4_k_m\": \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K\",\n    \"q5_k_m\": \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K\",\n    \"q2_k\": \"Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.\",\n    \"q3_k_l\": \"Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_m\": \"Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_s\": \"Uses Q3_K for all tensors\",\n    \"q4_0\": \"Original quant method, 4-bit.\",\n    \"q4_1\": \"Higher accuracy than q4_0 but not as high as q5_0. 
However has quicker inference than q5 models.\",\n    \"q4_k_s\": \"Uses Q4_K for all tensors\",\n    \"q4_k\": \"alias for q4_k_m\",\n    \"q5_k\": \"alias for q5_k_m\",\n    \"q5_0\": \"Higher accuracy, higher resource usage and slower inference.\",\n    \"q5_1\": \"Even higher accuracy, resource usage and slower inference.\",\n    \"q5_k_s\": \"Uses Q5_K for all tensors\",\n    \"q6_k\": \"Uses Q8_K for all tensors\",\n    # \"iq2_xxs\" : \"2.06 bpw quantization\", # Not supported sadly\n    # \"iq2_xs\"  : \"2.31 bpw quantization\",\n    # \"iq3_xxs\" : \"3.06 bpw quantization\",\n    \"q3_k_xs\": \"3-bit extra small quantization\",\n}\n\n\ndef has_curl():\n    return shutil.which(\"curl\") is not None\n\n\nCURL_FLAG = \"-DLLAMA_CURL=ON\" if has_curl() else \"-DLLAMA_CURL=OFF\"\n\n\ndef print_quantization_methods():\n    for key, value in ALLOWED_QUANTS.items():\n        print(f'\"{key}\"  ==> {value}')\n\n\ndef check_if_sentencepiece_model(\n    model, temporary_location = \"_unsloth_sentencepiece_temp\"\n):\n    if not hasattr(model, \"_saved_temp_tokenizer\"):\n        return False\n\n    temp_tokenizer = model._saved_temp_tokenizer\n    sentencepiece_model = False\n    file_location = os.path.join(temporary_location, temp_tokenizer.name_or_path)\n    created_folder = False\n    if not os.path.exists(file_location):\n        created_folder = True\n        os.makedirs(file_location)\n    temp_tokenizer.save_pretrained(file_location)\n    if os.path.isfile(f\"{file_location}/tokenizer.model\"):\n        sentencepiece_model = True\n    if created_folder:\n        shutil.rmtree(file_location, ignore_errors = True)\n    return sentencepiece_model\n\n\ndef _free_cached_model(model):\n    from huggingface_hub import scan_cache_dir\n\n    cached_repos = list(scan_cache_dir().repos)\n\n    # Go through every cached repo, and delete the one that matches the model we want to save.\n    # Can save 4GB of disk space - useful for Kaggle systems.\n    for cached_repo in cached_repos:\n        if cached_repo.repo_id == model.config._name_or_path:\n            remove_cache_commit = list(cached_repo.revisions)[0].commit_hash\n            delete_strategy = scan_cache_dir().delete_revisions(\n                remove_cache_commit,\n            )\n\n            logger.warning_once(\n                \"Unsloth: Will remove a cached repo with size \"\n                + delete_strategy.expected_freed_size_str,\n            )\n\n            delete_strategy.execute()\n\n\ndef _merge_lora(layer, name):\n    bias = getattr(layer, \"bias\", None)\n    if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):\n        # Is LoRA so we need to merge!\n        W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)\n        if quant_state is not None:\n            dtype = (\n                quant_state.dtype if type(quant_state) is not list else quant_state[2]\n            )\n            W = fast_dequantize(W, quant_state)\n        else:\n            dtype = W.dtype\n        W = W.to(torch.float32).t()\n        # W = W.t()\n\n        if A is not None:\n            # sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))\n            # W += sAB\n            W.addmm_(A.t().to(torch.float32), B.t().to(torch.float32), alpha = s)\n            # W.addmm_(A.t().to(W.dtype), B.t().to(W.dtype), alpha = s)\n            # if not torch.isfinite(W).all():\n            maximum_element = torch.max(W.min().abs(), W.max())\n            if not torch.isfinite(maximum_element).item():\n                raise 
ValueError(\n                    f\"Unsloth: Merge failed.\\n{name} has some elements = infinity.\"\n                )\n        W = W.t().to(dtype)\n    else:\n        W = layer.weight\n    return W, bias\n\n\ndef fast_save_pickle(shard, name):\n    # Use this if # CPUs is <= 2\n    print(f\"Unsloth: Saving {name}...\")\n    torch.save(\n        shard,\n        name,\n        # HIGHEST_PROTOCOL seems to not work with Pytorch!\n        # pickle_module   = pickle,\n        # pickle_protocol = pickle.HIGHEST_PROTOCOL,\n    )\n    return\n\n\n@torch.inference_mode\ndef unsloth_save_model(\n    model,\n    tokenizer,\n    save_directory: Union[str, os.PathLike],\n    save_method: str = \"lora\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n    is_main_process: bool = True,\n    state_dict: Optional[dict] = None,\n    save_function: Callable = torch.save,\n    max_shard_size: Union[int, str] = \"5GB\",\n    safe_serialization: bool = True,\n    variant: Optional[str] = None,\n    save_peft_format: bool = True,\n    # Push to hub\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Trained with Unsloth\",\n    private: Optional[bool] = None,\n    create_pr: bool = False,\n    revision: str = None,\n    commit_description: str = \"Upload model trained with Unsloth 2x faster\",\n    tags: List[str] = None,\n    # Our functions\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.9,\n    datasets: Optional[List[str]] = None,\n):\n    if token is None:\n        token = get_token()\n\n    if commit_message is None:\n        commit_message = \"\"\n    if \"Unsloth\" not in commit_message:\n        commit_message += \" (Trained with Unsloth)\"\n    commit_message = commit_message.lstrip()\n\n    if commit_description is None:\n        commit_description = \"Upload model trained with Unsloth 2x faster\"\n    elif \"Unsloth 2x faster\" not in commit_description:\n        commit_description += \" (Trained with Unsloth 2x faster)\"\n\n    if save_method == \"merged_4bit\":\n        raise RuntimeError(\n            \"Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\\n\"\n            \"to merge to GGUF or others later on. 
I suggest you to do this as a final step\\n\"\n            \"if you're planning to do multiple saves.\\n\"\n            \"If you are certain, change `save_method` to `merged_4bit_forced`.\"\n        )\n    elif save_method == \"merged_4bit_forced\":\n        save_method = \"merged_4bit\"\n\n    save_pretrained_settings = dict(locals())\n    for deletion in (\n        \"model\",\n        \"tokenizer\",\n        \"save_method\",\n        \"temporary_location\",\n        \"maximum_memory_usage\",\n        \"datasets\",\n    ):\n        del save_pretrained_settings[deletion]\n\n    # First check for a token!\n    if push_to_hub:\n        from huggingface_hub import whoami\n\n        try:\n            username = whoami(token = token)[\"name\"]\n        except:\n            raise RuntimeError(\n                \"Unsloth: Please supply a token!\\n\"\n                \"Go to https://huggingface.co/settings/tokens\"\n            )\n\n    assert maximum_memory_usage > 0 and maximum_memory_usage <= 0.95\n\n    # Clean memory up first\n    for _ in range(3):\n        torch.cuda.empty_cache()\n        gc.collect()\n\n    save_method = save_method.lower().replace(\" \", \"_\")\n    if (\n        save_method != \"lora\"\n        and save_method != \"merged_16bit\"\n        and save_method != \"merged_4bit\"\n    ):\n        raise RuntimeError(\n            \"Unsloth: You must select one of 3 options when saving models:\\n\"\n            '\"lora\"         ==> This is the fastest and easiet. Just saves LoRA modules.\\n'\n            '\"merged_16bit\" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\\n'\n            '\"merged_4bit\"  ==> This merges LoRA weights and saves to 4bit. Useful for DPO / inference.'\n        )\n\n    if save_method == \"merged_4bit\":\n        print(\"Unsloth: Merging 4bit and LoRA weights to 4bit...\")\n        print(\"This might take 5 minutes...\")\n\n        # Counteract no LoRA adapters!\n        if hasattr(model, \"merge_and_unload\"):\n            model = model.merge_and_unload()\n        print(\"Done.\")\n\n    if tags is not None:\n        assert isinstance(tags, (list, tuple))\n        tags = list(tags) + [\n            \"unsloth\",\n        ]\n    else:\n        tags = [\n            \"unsloth\",\n        ]\n    save_pretrained_settings[\"tags\"] = tags\n\n    if ((save_method == \"lora\") or (save_method == \"merged_4bit\")) and push_to_hub:\n        if token is None:\n            raise RuntimeError(\n                \"Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\\n\"\n                \"Go to https://huggingface.co/settings/tokens.\"\n            )\n\n        if save_method == \"lora\":\n            print(\"Unsloth: Saving LoRA adapters. Please wait...\")\n        elif save_method == \"merged_4bit\":\n            print(\"Unsloth: Saving 4bit Bitsandbytes model. 
Please wait...\")\n\n        # Update model tag\n        _ = upload_to_huggingface(\n            model,\n            save_directory,\n            token,\n            \"finetuned\",\n            \"trl\",\n            file_location = None,\n            old_username = None,\n            private = private,\n            datasets = datasets,\n        )\n\n        getattr(model, \"original_push_to_hub\", model.push_to_hub)(\n            repo_id = save_directory,\n            use_temp_dir = use_temp_dir,\n            commit_message = commit_message,\n            private = private,\n            token = token,\n            max_shard_size = max_shard_size,\n            create_pr = create_pr,\n            safe_serialization = safe_serialization,\n            revision = revision,\n            commit_description = commit_description,\n            tags = tags,\n        )\n        if tokenizer is not None:\n            # Set padding side to left for inference\n            old_padding_side = tokenizer.padding_side\n            tokenizer.padding_side = \"left\"\n\n            getattr(tokenizer, \"original_push_to_hub\", tokenizer.push_to_hub)(\n                repo_id = save_directory,\n                use_temp_dir = use_temp_dir,\n                commit_message = commit_message,\n                private = private,\n                token = token,\n                max_shard_size = max_shard_size,\n                create_pr = create_pr,\n                safe_serialization = safe_serialization,\n                revision = revision,\n                commit_description = commit_description,\n                tags = tags,\n            )\n\n            # Revert back padding side\n            tokenizer.padding_side = old_padding_side\n\n        if hasattr(model, \"config\"):\n            print(\n                f\"Saved {save_method} model to https://huggingface.co/\" + save_directory\n            )\n        return save_directory, None\n\n    # Tokenizer has different saving arguments\n    tokenizer_save_settings = {\n        \"save_directory\": save_pretrained_settings[\"save_directory\"],\n        \"legacy_format\": None,\n        \"filename_prefix\": None,\n        \"push_to_hub\": save_pretrained_settings[\"push_to_hub\"],\n        \"private\": save_pretrained_settings[\"private\"],\n        \"token\": save_pretrained_settings[\"token\"],\n    }\n\n    # Check if PEFT Model or not - if yes, 3 levels. 
If not 2 levels.\n    from peft import PeftModelForCausalLM\n\n    if isinstance(model, PeftModelForCausalLM):\n        internal_model = model.model\n    else:\n        internal_model = model\n\n    # Cannot be converted properly!\n    if (\n        (save_method == \"merged_4bit\")\n        or (save_method == \"lora\")\n        or (not hasattr(model, \"model\") or not hasattr(internal_model.model, \"layers\"))\n    ):\n        # Do general saving\n        # Edit save_pretrained_settings\n        # [TODO] _create_repo has errors due to **kwargs getting accepted\n        # commit_description does not seem to work?\n        what_to_delete = (\n            (\n                \"use_temp_dir\",\n                \"commit_message\",\n                \"create_pr\",\n                \"revision\",\n                \"commit_description\",\n                \"tags\",\n            )\n            if save_pretrained_settings[\"push_to_hub\"] is False\n            else (\n                \"use_temp_dir\",\n                \"create_pr\",\n                \"revision\",\n                \"tags\",\n                \"commit_description\",\n            )\n        )\n        for deletion in what_to_delete:\n            del save_pretrained_settings[deletion]\n        if hasattr(model, \"add_model_tags\"):\n            model.add_model_tags(\n                [\n                    \"unsloth\",\n                ]\n            )\n\n        # Update model tag\n        if push_to_hub:\n            _ = upload_to_huggingface(\n                model,\n                save_pretrained_settings[\"save_directory\"],\n                token,\n                \"finetuned\",\n                \"trl\",\n                file_location = None,\n                old_username = None,\n                private = private,\n                datasets = datasets,\n            )\n\n        if tokenizer is not None:\n            print(\"Unsloth: Saving tokenizer...\", end = \"\")\n\n            # Set padding side to left for inference\n            old_padding_side = tokenizer.padding_side\n            tokenizer.padding_side = \"left\"\n\n            tokenizer.save_pretrained(**tokenizer_save_settings)\n\n            # Revert back padding side\n            tokenizer.padding_side = old_padding_side\n\n            print(\" Done.\")\n        else:\n            print()\n\n        print(\"Unsloth: Saving model...\", end = \"\")\n        if save_method != \"lora\":\n            print(\" This might take 10 minutes for Llama-7b...\", end = \"\")\n\n        # [TODO] Is this correct?\n        if save_method == \"lora\":\n            save_pretrained_settings[\"selected_adapters\"] = None\n\n        model.save_pretrained(**save_pretrained_settings)\n\n        if push_to_hub and hasattr(model, \"config\"):\n            print(\n                \"Saved to https://huggingface.co/\"\n                + save_pretrained_settings[\"save_directory\"]\n            )\n\n        print(\" Done.\")\n        return save_directory, None\n\n    # If push_to_hub, we must remove the .../ part of a repo\n    username = None\n    if push_to_hub and \"/\" in save_directory:\n        # +1 solves absolute path issues\n        new_save_directory = save_directory\n        username = new_save_directory[: new_save_directory.find(\"/\")]\n        new_save_directory = new_save_directory[new_save_directory.find(\"/\") + 1 :]\n        if IS_KAGGLE_ENVIRONMENT:\n            new_save_directory = os.path.join(\n                KAGGLE_TMP, new_save_directory[new_save_directory.find(\"/\") + 1 :]\n 
           )\n            logger.warning_once(\n                \"Unsloth: You are pushing to hub in Kaggle environment.\\n\"\n                f\"To save memory, we shall move {save_directory} to {new_save_directory}\"\n            )\n        else:\n            logger.warning_once(\n                f\"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\\n\"\n                f\"We shall truncate {save_directory} to {new_save_directory}\"\n            )\n\n        save_pretrained_settings[\"save_directory\"] = new_save_directory\n        tokenizer_save_settings[\"save_directory\"] = new_save_directory\n        save_directory = new_save_directory\n\n    print(\"Unsloth: Merging 4bit and LoRA weights to 16bit...\")\n\n    # Determine max RAM usage minus sharding\n    max_ram = psutil.virtual_memory().available\n    sharded_ram_usage = 5 * 1024 * 1024 * 1024\n    if type(max_shard_size) is str:\n        gb_found = re.match(\n            r\"([0-9]{1,})[\\s]{0,}GB\", max_shard_size, flags = re.IGNORECASE\n        )\n        mb_found = re.match(\n            r\"([0-9]{1,})[\\s]{0,}MB\", max_shard_size, flags = re.IGNORECASE\n        )\n        if gb_found:\n            sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024\n        elif mb_found:\n            sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024\n    elif type(max_shard_size) is int:\n        sharded_ram_usage = max_shard_size\n\n    # Switch to our fast saving modules if it's a slow PC!\n    n_cpus = psutil.cpu_count(logical = False)\n    if n_cpus is None:\n        n_cpus = psutil.cpu_count()\n    if n_cpus is None:\n        n_cpus = 1\n\n    if safe_serialization is None:\n        safe_serialization = True\n        save_pretrained_settings[\"safe_serialization\"] = safe_serialization\n\n    elif safe_serialization and (n_cpus <= 2):\n        logger.warning_once(\n            f\"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\\n\"\n            f\"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\\n\"\n            f\"To force `safe_serialization`, set it to `None` instead.\",\n        )\n        safe_serialization = False\n        save_function = fast_save_pickle\n        save_pretrained_settings[\"safe_serialization\"] = safe_serialization\n        save_pretrained_settings[\"save_function\"] = save_function\n\n    # Only safe_serialization uses more RAM\n    if safe_serialization:\n        max_ram -= sharded_ram_usage\n    else:\n        max_ram -= sharded_ram_usage * 0.25  # Uses much less\n\n    max_ram = int(max(0, max_ram) * maximum_memory_usage)\n    print(\n        f\"Unsloth: Will use up to \"\n        f\"{round(max_ram/1024/1024/1024, 2)} out of \"\n        f\"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.\"\n    )\n\n    # Move temporary_location to /tmp in Kaggle\n    if IS_KAGGLE_ENVIRONMENT:\n        temporary_location = os.path.join(KAGGLE_TMP, temporary_location)\n\n    # Max directory for disk saving\n    if not os.path.exists(temporary_location):\n        os.makedirs(temporary_location)\n\n    # Check if Kaggle or Colab, since only 20GB of Disk space allowed.\n    if IS_KAGGLE_ENVIRONMENT or IS_COLAB_ENVIRONMENT:\n        # We free up 4GB of space\n        logger.warning_once(\n            \"Unsloth: Kaggle/Colab has limited disk space. 
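Kaggle's working directory only allows about 20GB. 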
We need to delete the downloaded\\n\"\n            \"model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.\"\n        )\n        _free_cached_model(internal_model)\n\n    # HF also uses a OrderedDict\n    from collections import OrderedDict\n\n    state_dict = OrderedDict()\n\n    torch_dtype = dtype_from_config(internal_model.config)\n    if type(torch_dtype) is str:\n        if torch_dtype == \"float16\":\n            torch_dtype = torch.float16\n        elif torch_dtype == \"bfloat16\":\n            torch_dtype = torch.bfloat16\n\n    # Check modules to save float32 dtype\n    state_dict[\"model.embed_tokens.weight\"] = (\n        internal_model.model.embed_tokens.weight.data.to(torch_dtype)\n    )\n\n    max_vram = int(\n        torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage\n    )\n\n    print(\"Unsloth: Saving model... This might take 5 minutes ...\")\n\n    from tqdm import tqdm as ProgressBar\n\n    for j, layer in enumerate(ProgressBar(internal_model.model.layers)):\n        for item in LLAMA_WEIGHTS:\n            proj = eval(f\"layer.{item}\")\n            name = f\"model.layers.{j}.{item}.weight\"\n            W, bias = _merge_lora(proj, name)\n\n            # Bias term\n            if bias is not None:\n                state_dict[f\"model.layers.{j}.{item}.bias\"] = bias\n\n            if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:\n                # Save to GPU memory\n                state_dict[name] = W\n            # [TODO] Saving to RAM seems to leak memory???\n            # elif (max_ram - W.nbytes) > 0:\n            #     # Save to CPU memory\n            #     logger.warning_once(f\"We will save to RAM and not VRAM now.\")\n            #     state_dict[name] = W.to(\"cpu\", non_blocking = True, copy = True)\n            #     max_ram = max(max_ram - W.nbytes, 0)\n            else:\n                # Save to Disk\n                logger.warning_once(\"\\nWe will save to Disk and not RAM now.\")\n                filename = os.path.join(temporary_location, f\"{name}.pt\")\n                torch.save(\n                    W,\n                    filename,\n                    pickle_module = pickle,\n                    pickle_protocol = pickle.HIGHEST_PROTOCOL,\n                )\n                # weights_only = True weirdly fails?\n                state_dict[name] = torch.load(\n                    filename, map_location = \"cpu\", mmap = True, weights_only = False\n                )\n        for item in LLAMA_LAYERNORMS:\n            try:\n                # Skip for Gemma 2\n                state_dict[f\"model.layers.{j}.{item}.weight\"] = eval(\n                    f\"layer.{item}.weight.data\"\n                )\n            except:\n                continue\n\n    state_dict[\"model.norm.weight\"] = internal_model.model.norm.weight.data\n    # Check for modules_to_save float32 dtype\n\n    # Check for tied weights\n    if (\n        internal_model.model.embed_tokens.weight.data_ptr()\n        != internal_model.lm_head.weight.data_ptr()\n    ):\n        state_dict[\"lm_head.weight\"] = internal_model.lm_head.weight.data.to(\n            torch_dtype\n        )\n\n    # All tensors MUST be type torch.Tensor and not torch.nn.parameter.Parameter\n    for key, value in state_dict.items():\n        if hasattr(value, \"data\"):\n            state_dict[key] = value = value.data\n        if type(value) is not torch.Tensor:\n            logger.warning_once(f\"Unsloth: {key} is not a Tensor but a 
{type(value)}.\")\n\n    # Edit save_pretrained_settings\n    # [TODO] _create_repo has errors due to **kwargs getting accepted\n    save_pretrained_settings[\"state_dict\"] = state_dict\n\n    # commit_description does not seem to work?\n    what_to_delete = (\n        (\n            \"use_temp_dir\",\n            \"commit_message\",\n            \"create_pr\",\n            \"revision\",\n            \"commit_description\",\n            \"tags\",\n        )\n        if not push_to_hub\n        else (\n            \"use_temp_dir\",\n            \"create_pr\",\n            \"revision\",\n            \"tags\",\n            \"commit_description\",\n        )\n    )\n    for deletion in what_to_delete:\n        del save_pretrained_settings[deletion]\n    if hasattr(model, \"add_model_tags\"):\n        model.add_model_tags(\n            [\n                \"unsloth\",\n            ]\n        )\n\n    # Update model tag\n    if push_to_hub:\n        _ = upload_to_huggingface(\n            model,\n            save_pretrained_settings[\"save_directory\"],\n            token,\n            \"finetuned\",\n            \"trl\",\n            file_location = None,\n            old_username = username,\n            private = private,\n            datasets = datasets,\n        )\n\n    # First check if we're pushing to an organization!\n    save_directory = save_pretrained_settings[\"save_directory\"]\n\n    if save_pretrained_settings[\"push_to_hub\"]:\n        new_save_directory, new_username = _determine_username(\n            save_directory, username, token\n        )\n\n        if token is not None:\n            from huggingface_hub import whoami\n\n            actual_username = whoami(token = token)[\"name\"]\n        else:\n            actual_username = username\n\n    # Check if pushing to an organization\n    if save_pretrained_settings[\"push_to_hub\"] and (username != actual_username):\n        print(f\"Unsloth: Saving to organization with address {new_save_directory}\")\n        # We upload everything at the end!\n        tokenizer_save_settings[\"push_to_hub\"] = False\n        tokenizer_save_settings[\"save_directory\"] = new_save_directory\n\n    # Save tokenizer\n    if tokenizer is not None:\n        print(\"Unsloth: Saving tokenizer...\", end = \"\")\n\n        # Set padding side to left for inference\n        old_padding_side = tokenizer.padding_side\n        tokenizer.padding_side = \"left\"\n\n        tokenizer.save_pretrained(**tokenizer_save_settings)\n\n        # Revert back padding side\n        tokenizer.padding_side = old_padding_side\n\n        print(\" Done.\")\n    else:\n        print()\n\n    # Since merged, edit quantization_config\n    old_config = model.config\n    new_config = model.config.to_dict()\n    if \"quantization_config\" in new_config:\n        del new_config[\"quantization_config\"]\n    original_model = model\n    new_config = type(model.config).from_dict(new_config)\n    while hasattr(original_model, \"model\"):\n        original_model = original_model.model\n        original_model.config = new_config\n    model.config = new_config\n\n    # Save!\n    # [TODO] --> is this correct?\n    # save_pretrained_settings[\"selected_adapters\"] = None\n\n    # Check if pushing to an organization\n    if save_pretrained_settings[\"push_to_hub\"] and (username != actual_username):\n        print(f\"Unsloth: Saving to organization with address {new_save_directory}\")\n        # Pushing to organization!\n        # Sadly .save_pretrained doesn't work :(\n        # We 
first save it via .save_pretrained, then upload manually!\n        save_pretrained_settings[\"save_directory\"] = new_save_directory\n        save_pretrained_settings[\"push_to_hub\"] = False\n        internal_model.save_pretrained(**save_pretrained_settings)\n\n        # Now manually go through each file and upload them manually!\n        filenames = os.listdir(new_save_directory)\n\n        hf_api = HfApi(token = save_pretrained_settings[\"token\"])\n\n        print(\"Unsloth: Uploading all files... Please wait...\")\n        hf_api.upload_folder(\n            folder_path = new_save_directory,\n            path_in_repo = \".\",\n            repo_id = new_save_directory,\n            repo_type = \"model\",\n            commit_message = \"(Trained with Unsloth)\",\n            ignore_patterns = \"*.md\",\n        )\n    else:\n        internal_model.save_pretrained(**save_pretrained_settings)\n\n    # Revert config back\n    original_model = model\n    while hasattr(original_model, \"model\"):\n        original_model = original_model.model\n        original_model.config = old_config\n    model.config = old_config\n    print(\"Done.\")\n\n    if push_to_hub and hasattr(model, \"config\"):\n        print(\n            f\"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}\"\n        )\n\n    save_pretrained_settings[\"state_dict\"] = None\n\n    for j, (key, value) in enumerate(state_dict.items()):\n        state_dict[key] = None\n        if j % 10 == 0:\n            torch.cuda.empty_cache()\n            gc.collect()\n    state_dict = None\n    del state_dict\n    torch.cuda.empty_cache()\n    gc.collect()\n\n    # Remove temporary location\n    import shutil\n\n    shutil.rmtree(temporary_location, ignore_errors = True)\n\n    for _ in range(3):\n        torch.cuda.empty_cache()\n        gc.collect()\n    return save_directory, username\n\n\ndef install_llama_cpp_clone_non_blocking():\n    full_command = [\n        \"git\",\n        \"clone\",\n        \"--recursive\",\n        \"https://github.com/ggerganov/llama.cpp\",\n    ]\n    run_installer = subprocess.Popen(\n        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT\n    )\n    return run_installer\n\n\ndef install_llama_cpp_make_non_blocking():\n    # https://github.com/ggerganov/llama.cpp/issues/7062\n    # Weirdly GPU conversion for GGUF breaks??\n    # env = { **os.environ, \"LLAMA_CUDA\": \"1\", }\n    # Force make clean\n    check = os.system(\"make clean -C llama.cpp\")\n    IS_CMAKE = False\n    if check == 0:\n        # Uses old MAKE\n        n_jobs = max(int((psutil.cpu_count() or 1) * 1.5), 1)\n        full_command = [\"make\", \"all\", \"-j\" + str(n_jobs), \"-C\", \"llama.cpp\"]\n        IS_CMAKE = False\n    else:\n        # Uses new CMAKE\n        n_jobs = max(int(psutil.cpu_count() or 1), 1)  # Use less CPUs since 1.5x faster\n        check = os.system(\n            f\"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}\"\n        )\n\n        if check != 0:\n            raise RuntimeError(\n                f\"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. 
Please report this ASAP!\"\n            )\n        # f\"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}\",\n        full_command = [\n            \"cmake\",\n            \"--build\",\n            \"llama.cpp/build\",\n            \"--config\",\n            \"Release\",\n            \"-j\" + str(n_jobs),\n            \"--clean-first\",\n            \"--target\",\n        ] + LLAMA_CPP_TARGETS\n        IS_CMAKE = True\n    # https://github.com/ggerganov/llama.cpp/issues/7062\n    # Weirdly GPU conversion for GGUF breaks??\n    # run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)\n    run_installer = subprocess.Popen(\n        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT\n    )\n    return run_installer, IS_CMAKE\n\n\ndef install_python_non_blocking(packages = []):\n    full_command = [\"pip\", \"install\"] + packages\n    run_installer = subprocess.Popen(\n        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT\n    )\n    return run_installer\n\n\ndef try_execute(commands, force_complete = False):\n    for command in commands:\n        with subprocess.Popen(\n            command,\n            shell = True,\n            stdout = subprocess.PIPE,\n            stderr = subprocess.STDOUT,\n            bufsize = 1,\n        ) as sp:\n            for line in sp.stdout:\n                line = line.decode(\"utf-8\", errors = \"replace\")\n                if \"undefined reference\" in line:\n                    raise RuntimeError(\n                        f\"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!\"\n                    )\n                elif \"deprecated\" in line:\n                    return \"CMAKE\"\n                elif \"Unknown argument\" in line:\n                    raise RuntimeError(\n                        f\"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!\"\n                    )\n                elif \"***\" in line:\n                    raise RuntimeError(\n                        f\"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!\"\n                    )\n                print(line, flush = True, end = \"\")\n            if force_complete and sp.returncode is not None and sp.returncode != 0:\n                raise subprocess.CalledProcessError(sp.returncode, sp.args)\n    return None\n\n\ndef install_llama_cpp_old(version = -10):\n    # Download the 10th latest release since the latest might be broken!\n    # FALLBACK mechanism\n    releases = subprocess.check_output(\n        [\"git\", \"ls-remote\", \"--tags\", \"https://github.com/ggerganov/llama.cpp.git\"]\n    )\n    releases = releases.decode(\"utf-8\").replace(\"\\t\", \" \").split(\"\\n\")\n    for i, x in enumerate(releases):\n        if \"refs/tags/b\" not in x:\n            break\n    releases = releases[:i]\n    latest = releases[-1]\n    version = releases[version].split(\" \")[0]\n\n    # Check if the llama.cpp exists\n    if os.path.exists(\"llama.cpp\"):\n        print(\n            \"**[WARNING]** You have a llama.cpp directory which is broken.\\n\"\n            \"Unsloth will DELETE the broken directory and install a new one.\\n\"\n            \"Press CTRL + C / cancel this if this is wrong. 
We shall wait 30 seconds.\\n\"\n        )\n        import time\n\n        for i in range(30):\n            print(f\"**[WARNING]** Deleting llama.cpp directory... {30-i} seconds left.\")\n            time.sleep(1)\n        import shutil\n\n        shutil.rmtree(\"llama.cpp\", ignore_errors = True)\n\n    # Clone a specific commit\n    # Also don't use the GPU!\n    commands = [\n        \"git clone --recursive https://github.com/ggerganov/llama.cpp\",\n        f\"cd llama.cpp && git reset --hard {version} && git clean -df\",\n    ]\n    try_execute(commands)\n\n    # Try using MAKE\n    commands = [\n        \"make clean -C llama.cpp\",\n        f\"make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp\",\n    ]\n    if try_execute(commands) == \"CMAKE\":\n        # Instead use CMAKE\n        commands = [\n            f\"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}\",\n            f\"cmake --build llama.cpp/build --config Release -j{(psutil.cpu_count() or 1)*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}\",\n            \"cp llama.cpp/build/bin/llama-* llama.cpp\",\n            \"rm -rf llama.cpp/build\",\n        ]\n\n        try_execute(commands)\n\n    # Check if successful\n    if not (\n        os.path.exists(\"llama.cpp/llama-quantize.exe\")\n        or os.path.exists(\"llama.cpp/llama-quantize\")\n        or os.path.exists(\"llama.cpp/quantize.exe\")\n        or os.path.exists(\"llama.cpp/quantize\")\n        or os.path.exists(\"llama.cpp/build/bin/llama-quantize\")\n        or os.path.exists(\"llama.cpp/build/bin/quantize\")\n    ):\n        raise RuntimeError(\n            \"Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\\n\"\n            \"We've also double checked the building directory under 'llama.cpp/build/bin/'.\\n\"\n            \"But we expect this file to exist! 
Check if the file exists under llama.cpp and investigate the building process of llama.cpp (make/cmake)!\"\n        )\n\n\ndef install_llama_cpp_blocking(use_cuda = False):\n    # https://github.com/ggerganov/llama.cpp/issues/7062\n    # Weirdly GPU conversion for GGUF breaks??\n    # use_cuda = \"LLAMA_CUDA=1\" if use_cuda else \"\"\n\n    commands = [\n        \"git clone --recursive https://github.com/ggerganov/llama.cpp\",\n        \"pip install gguf protobuf\",\n    ]\n    if os.path.exists(\"llama.cpp\"):\n        return\n    try_execute(commands)\n\n    commands = [\n        \"make clean -C llama.cpp\",\n        # https://github.com/ggerganov/llama.cpp/issues/7062\n        # Weirdly GPU conversion for GGUF breaks??\n        # f\"{use_cuda} make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp\",\n        f\"make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp\",\n    ]\n    if try_execute(commands) == \"CMAKE\":\n        # Instead use CMAKE\n        commands = [\n            f\"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}\",\n            f\"cmake --build llama.cpp/build --config Release -j{(psutil.cpu_count() or 1)*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}\",\n            \"cp llama.cpp/build/bin/llama-* llama.cpp\",\n            \"rm -rf llama.cpp/build\",\n        ]\n        try_execute(commands)\n\n\ndef get_executable(executables):\n    # Get system locations (System Path).split(system separator)\n    system_directories = os.environ.get(\"PATH\").split(os.pathsep)\n\n    for directory in system_directories:\n        for executable in executables:\n            path = os.path.join(directory, executable)\n            # Check if the executable exists and is executable\n            if os.path.exists(path) and os.access(path, os.X_OK):\n                return path\n    return None\n\n\ndef save_to_gguf(\n    model_name: str,\n    model_type: str,\n    model_dtype: str,\n    is_sentencepiece: bool = False,\n    model_directory: str = \"unsloth_finetuned_model\",\n    quantization_method = \"fast_quantized\",  # Can be a list of options! 
[\"q4_k_m\", \"q8_0\", \"q5_k_m\"]\n    first_conversion: str = None,\n    is_vlm: bool = False,\n    is_gpt_oss: bool = False,\n):\n    \"\"\"\n    Orchestrates the complete GGUF conversion process.\n    Handles installation, conversion, and quantization.\n    \"\"\"\n    # print_output True only if UNSLOTH_ENABLE_LOGGING=1\n    if os.environ.get(\"UNSLOTH_ENABLE_LOGGING\", \"0\") == \"1\":\n        print_output = True\n    else:\n        print_output = False\n\n    # Validate model dtype\n    assert model_dtype == \"float16\" or model_dtype == \"bfloat16\"\n    model_dtype = \"f16\" if model_dtype == \"float16\" else \"bf16\"\n\n    # Convert quantization_method to list\n    if isinstance(quantization_method, list):\n        pass\n    elif isinstance(quantization_method, str):\n        quantization_method = [\n            quantization_method,\n        ]\n    elif isinstance(quantization_method, tuple):\n        quantization_method = list(quantization_method)\n    else:\n        raise TypeError(\n            \"Unsloth: quantization_method can only be a string or a list of strings\"\n        )\n\n    # Check if bfloat16 is supported\n    if model_dtype == \"bf16\" and not torch.cuda.is_bf16_supported():\n        logger.warning(\n            \"Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\\n\"\n            \"We shall switch instead to f16.\"\n        )\n        model_dtype = \"f16\"\n\n    # Check first_conversion as well\n    if first_conversion is None:\n        first_conversion = model_dtype\n\n    # Check I quants\n    for quant_method in quantization_method:\n        if quant_method.startswith(\"iq2\"):\n            raise RuntimeError(\n                \"Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!\"\n            )\n\n    # Map quant methods\n    new_quantization_methods = []\n    for quant_method in quantization_method:\n        if quant_method == \"not_quantized\":\n            quant_method = model_dtype\n        elif quant_method == \"fast_quantized\":\n            quant_method = \"q8_0\"\n        elif quant_method == \"quantized\":\n            quant_method = \"q4_k_m\"\n        elif quant_method is None:\n            quant_method = \"q8_0\"\n\n        # Check if wrong method\n        if quant_method not in ALLOWED_QUANTS.keys():\n            error = f\"Unsloth: Quant method = [{quant_method}] not supported. 
Choose from below:\\n\"\n            for key, value in ALLOWED_QUANTS.items():\n                error += f\"[{key}] => {value}\\n\"\n            raise RuntimeError(error)\n\n        new_quantization_methods.append(quant_method)\n    quantization_method = new_quantization_methods\n\n    # Determine optimal first_conversion\n    if is_gpt_oss:\n        print(\"Unsloth: GPT-OSS model detected - using special conversion settings\")\n        first_conversion = \"None\"  # No quantization for GPT-OSS\n        # Only keep one conversion method since GPT-OSS doesn't quantize\n        quantization_method = [\"None\"]\n    else:\n        if first_conversion is None:\n            # Check if q8_0 is the ONLY quantization method requested\n            if len(quantization_method) == 1 and quantization_method[0] == \"q8_0\":\n                first_conversion = \"None\"  # Let llama-quantize do the direct conversion\n            else:\n                # For all other cases, choose the highest precision format\n                # that can be requantized to all requested formats\n                strength = 0\n                for quant_method in quantization_method:\n                    if quant_method == \"f32\":\n                        strength = max(strength, 3)\n                    elif quant_method == \"f16\":\n                        strength = max(strength, 2)\n                    elif quant_method == \"bf16\":\n                        strength = max(strength, 1)\n                    # Note: we don't set strength for q8_0 here since we handle it above\n\n                if strength >= 3:\n                    first_conversion = \"f32\"\n                elif strength >= 2:\n                    first_conversion = \"f16\"\n                elif strength >= 1:\n                    first_conversion = \"bf16\"\n                else:\n                    first_conversion = \"bf16\"  # requantizing from q8_0 disallowed in new llama.cpp default to bf16.\n\n    # Check bfloat16 support again for first_conversion\n    if first_conversion == \"bf16\" and not torch.cuda.is_bf16_supported():\n        logger.warning(\"Unsloth: Switching bf16 to f16 due to hardware limitations\")\n        first_conversion = \"f16\"\n\n    first_conversion_dtype = \"\" if first_conversion == \"None\" else first_conversion\n    # Print conversion info\n    print_info = (\n        f\"==((====))==  Unsloth: Conversion from HF to GGUF information\\n\"\n        f\"   {chr(92)}{chr(92)}   /|    [0] Installing llama.cpp might take 3 minutes.\\n\"\n        f\"O^O/ {chr(92)}_/ {chr(92)}    [1] Converting HF to GGUF {first_conversion_dtype} might take 3 minutes.\\n\"\n        f\"{chr(92)}        /    [2] Converting GGUF {first_conversion_dtype} to {quantization_method} might take 10 minutes each.\\n\"\n        f' \"-____-\"     In total, you will have to wait at least 16 minutes.\\n'\n    )\n    print(print_info)\n\n    # Step 1: Ensure llama.cpp is installed\n    try:\n        quantizer_location, converter_location = check_llama_cpp()\n        print(\"Unsloth: llama.cpp found in the system. Skipping installation.\")\n    except:\n        print(\"Unsloth: Installing llama.cpp. 
This might take 3 minutes...\")\n        if IS_KAGGLE_ENVIRONMENT:\n            # Kaggle: no CUDA support due to environment limitations\n            quantizer_location, converter_location = install_llama_cpp(\n                gpu_support = False, print_output = print_output\n            )\n        else:\n            quantizer_location, converter_location = install_llama_cpp(\n                gpu_support = False,  # GGUF conversion doesn't need CUDA\n                print_output = print_output,\n            )\n\n    # Step 2: Download and patch converter script\n    print(\"Unsloth: Preparing converter script...\")\n    with use_local_gguf():\n        converter_path, supported_text_archs, supported_vision_archs = (\n            _download_convert_hf_to_gguf()\n        )\n\n        # Step 3: Initial GGUF conversion\n        print(\n            f\"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format.\"\n        )\n        print(f\"This might take 3 minutes...\")\n\n        initial_files, is_vlm_update = convert_to_gguf(\n            model_name = model_name,\n            input_folder = model_directory,\n            model_dtype = model_dtype,\n            quantization_type = first_conversion,\n            converter_location = converter_path,\n            supported_text_archs = supported_text_archs,\n            supported_vision_archs = supported_vision_archs,\n            is_vlm = is_vlm,\n            is_gpt_oss = is_gpt_oss,\n            max_shard_size = \"50GB\",\n            print_output = print_output,\n        )\n    # update is_vlm switch\n    is_vlm = is_vlm_update\n    # Check conversion success\n    for file in initial_files:\n        if not os.path.exists(file):\n            if IS_KAGGLE_ENVIRONMENT:\n                raise RuntimeError(\n                    f\"Unsloth: Conversion failed for {file}\\n\"\n                    \"You are in a Kaggle environment with limited disk space (20GB).\\n\"\n                    \"Try saving to /tmp for more space or use a smaller model.\\n\"\n                    \"Alternatively, save the 16bit model first, then convert manually.\"\n                )\n            else:\n                raise RuntimeError(\n                    f\"Unsloth: Conversion failed for {file}\\n\"\n                    \"Please check disk space and try again.\"\n                )\n\n    # Move initial GGUF files into a dedicated _gguf directory\n    gguf_directory = f\"{model_directory}_gguf\"\n    os.makedirs(gguf_directory, exist_ok = True)\n    moved_files = []\n    for fpath in initial_files:\n        dst = os.path.join(gguf_directory, os.path.basename(fpath))\n        shutil.move(fpath, dst)\n        moved_files.append(dst)\n    initial_files = moved_files\n\n    print(f\"Unsloth: Initial conversion completed! Files: {initial_files}\")\n\n    # Step 4: Additional quantizations using llama-quantize\n    all_saved_locations = initial_files.copy()\n\n    # Get CPU count for quantization\n    n_cpus = psutil.cpu_count()\n    if n_cpus is None:\n        n_cpus = 1\n    n_cpus *= 2\n\n    if not is_gpt_oss:\n        base_gguf = initial_files[0]\n        quants_created = False\n        for quant_method in quantization_method:\n            if quant_method != first_conversion:\n                print(\n                    f\"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. 
This might take 10 minutes...\"\n                )\n                output_location = os.path.join(\n                    gguf_directory, f\"{model_name}.{quant_method.upper()}.gguf\"\n                )\n                try:\n                    # Use the quantize_gguf function we created\n                    quantized_file = quantize_gguf(\n                        input_gguf = base_gguf,\n                        output_gguf = output_location,\n                        quant_type = quant_method,\n                        quantizer_location = quantizer_location,\n                        print_output = print_output,\n                    )\n                    all_saved_locations.append(quantized_file)\n                    quants_created = True\n                except Exception as e:\n                    if IS_KAGGLE_ENVIRONMENT:\n                        raise RuntimeError(\n                            f\"Unsloth: Quantization failed for {output_location}\\n\"\n                            \"You are in a Kaggle environment, which might be the reason this is failing.\\n\"\n                            \"Kaggle only provides 20GB of disk space in the working directory.\\n\"\n                            \"Merging to 16bit for 7b models use 16GB of space.\\n\"\n                            \"This means using `model.{save_pretrained/push_to_hub}_merged` works, but\\n\"\n                            \"`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\\n\"\n                            \"You can try saving it to the `/tmp` directory for larger disk space.\\n\"\n                            \"I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\\n\"\n                            f\"Error: {e}\"\n                        )\n                    else:\n                        if IS_WINDOWS:\n                            build_instructions = (\n                                f'cd \"{LLAMA_CPP_DEFAULT_DIR}\"\\n'\n                                f\"cmake -S . -B build -DBUILD_SHARED_LIBS=OFF\\n\"\n                                f\"cmake --build build --config Release\"\n                            )\n                        else:\n                            build_instructions = f'cd \"{LLAMA_CPP_DEFAULT_DIR}\" && make clean && make all -j'\n\n                        raise RuntimeError(\n                            f\"Unsloth: Quantization failed for {output_location}\\n\"\n                            \"You might have to compile llama.cpp yourself, then run this again.\\n\"\n                            \"You do not need to close this Python program. Run the following commands in a new terminal:\\n\"\n                            f'git clone --recursive https://github.com/ggerganov/llama.cpp \"{LLAMA_CPP_DEFAULT_DIR}\"\\n'\n                            f\"{build_instructions}\\n\"\n                            \"Once that's done, redo the quantization.\\n\"\n                            f\"Error: {e}\"\n                        )\n        print(\"Unsloth: Model files cleanup...\")\n        if quants_created:\n            all_saved_locations.remove(base_gguf)\n            Path(base_gguf).unlink(missing_ok = True)\n\n            # flip the list to get [text_model, mmproj] order. 
for text models stays the same.\n            all_saved_locations.reverse()\n    else:\n        print(\"Unsloth: GPT-OSS model - skipping additional quantizations\")\n\n    if is_gpt_oss:\n        want_full_precision = True\n    else:\n        want_full_precision = first_conversion in frozenset(quantization_method)\n\n    print(f\"Unsloth: All GGUF conversions completed successfully!\")\n    print(f\"Generated files: {all_saved_locations}\")\n\n    return all_saved_locations, want_full_precision, is_vlm\n\n\ndef unsloth_save_pretrained_merged(\n    self,\n    save_directory: Union[str, os.PathLike],\n    tokenizer = None,\n    save_method: str = \"merged_16bit\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n    is_main_process: bool = True,\n    state_dict: Optional[dict] = None,\n    save_function: Callable = torch.save,\n    max_shard_size: Union[int, str] = \"5GB\",\n    safe_serialization: bool = True,\n    variant: Optional[str] = None,\n    save_peft_format: bool = True,\n    tags: List[str] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.75,\n    datasets: Optional[List[str]] = None,\n):\n    \"\"\"\n    Same as .save_pretrained(...) except 4bit weights are auto\n    converted to float16 with as few overhead as possible.\n\n    Choose for `save_method` to be either:\n    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.\n    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.\n    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.\n    \"\"\"\n    if tokenizer is None:\n        logger.warning_once(\n            \"Unsloth: You're not saving a tokenizer as well?\\n\"\n            \"You can do it separately via `tokenizer.save_pretrained(...)`\"\n        )\n\n    arguments = dict(locals())\n    arguments[\"model\"] = self\n    del arguments[\"self\"]\n    unsloth_save_model(**arguments)\n    for _ in range(3):\n        gc.collect()\n\n\ndef unsloth_push_to_hub_merged(\n    self,\n    repo_id: str,\n    tokenizer = None,\n    save_method: str = \"merged_16bit\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Trained with Unsloth\",\n    private: Optional[bool] = None,\n    token: Union[bool, str, None] = None,\n    max_shard_size: Union[int, str, None] = \"5GB\",\n    create_pr: bool = False,\n    safe_serialization: bool = True,\n    revision: str = None,\n    commit_description: str = \"Upload model trained with Unsloth 2x faster\",\n    tags: Optional[List[str]] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.75,\n    datasets: Optional[List[str]] = None,\n):\n    \"\"\"\n    Same as .push_to_hub(...) except 4bit weights are auto\n    converted to float16 with as few overhead as possible.\n\n    Choose for `save_method` to be either:\n    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.\n    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.\n    3.  `lora`: Save LoRA adapters with no merging. 
Useful for HF inference.\n    \"\"\"\n    if tokenizer is None:\n        logger.warning_once(\n            \"Unsloth: You're not saving a tokenizer as well?\\n\"\n            \"You can do it separately via `tokenizer.push_to_hub(...)`\"\n        )\n\n    arguments = dict(locals())\n    arguments[\"model\"] = self\n    arguments[\"save_directory\"] = repo_id\n    arguments[\"push_to_hub\"] = True\n    del arguments[\"self\"]\n    del arguments[\"repo_id\"]\n    unsloth_save_model(**arguments)\n    for _ in range(3):\n        gc.collect()\n\n\nMODEL_CARD = \"\"\"---\nbase_model: {base_model}\ntags:\n- text-generation-inference\n- transformers\n- unsloth\n- {model_type}\n- {extra}\nlicense: apache-2.0\nlanguage:\n- en\n---\n\n# Uploaded {method} model\n\n- **Developed by:** {username}\n- **License:** apache-2.0\n- **Finetuned from model :** {base_model}\n\nThis {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)\n\n[<img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png\" width=\"200\"/>](https://github.com/unslothai/unsloth)\n\"\"\"\n\n\ndef _determine_username(save_directory, old_username, token):\n    username = \"\"\n    save_directory = save_directory.lstrip(\"./\")\n    if \"/\" not in save_directory:\n        from huggingface_hub import whoami\n\n        try:\n            username = whoami(token = token)[\"name\"]\n            if type(old_username) is str and username != old_username:\n                username = old_username\n            save_directory = f\"{username}/{save_directory}\"\n        except:\n            raise RuntimeError(\n                f\"Unsloth: {save_directory} is not a Huggingface directory.\"\n            )\n    else:\n        username = save_directory.split(\"/\")[0]\n    return save_directory, username\n\n\ndef create_huggingface_repo(\n    model,\n    save_directory,\n    token = None,\n    private = False,\n    datasets = None,\n):\n    if token is None:\n        token = get_token()\n    save_directory, username = _determine_username(save_directory, None, token)\n\n    from huggingface_hub import create_repo\n\n    try:\n        create_repo(\n            repo_id = save_directory,\n            token = token,\n            repo_type = \"model\",\n            exist_ok = False,\n            private = private,\n        )\n\n        # Create model card\n        from huggingface_hub import ModelCard\n\n        content = MODEL_CARD.format(\n            username = username,\n            base_model = model.config._name_or_path,\n            model_type = model.config.model_type,\n            method = \"\",\n            extra = \"unsloth\",\n        )\n        card = ModelCard(content)\n        if datasets:\n            card.data.datasets = datasets\n        card.push_to_hub(save_directory, token = token)\n    except:\n        # Repo already exists — update datasets metadata separately\n        if datasets:\n            try:\n                from huggingface_hub import metadata_update\n\n                metadata_update(\n                    save_directory, {\"datasets\": datasets}, overwrite = True, token = token\n                )\n            except Exception as e:\n                logger.warning_once(\n                    f\"Unsloth: Could not update datasets metadata for {save_directory}: {e}\"\n                )\n    hf_api = HfApi(token = token)\n    return save_directory, hf_api\n\n\ndef upload_to_huggingface(\n    model,\n    save_directory,\n    token,\n    method,\n    
extra = \"\",\n    file_location = None,\n    old_username = None,\n    private = None,\n    create_config = True,\n    datasets = None,\n):\n    save_directory, username = _determine_username(save_directory, old_username, token)\n\n    from huggingface_hub import create_repo\n\n    try:\n        create_repo(\n            repo_id = save_directory,\n            token = token,\n            repo_type = \"model\",\n            exist_ok = False,\n            private = private,\n        )\n\n        # Create model card\n        from huggingface_hub import ModelCard\n\n        content = MODEL_CARD.format(\n            username = username,\n            base_model = model.config._name_or_path,\n            model_type = model.config.model_type,\n            method = \"\",\n            extra = extra,\n        )\n        card = ModelCard(content)\n        if datasets:\n            card.data.datasets = datasets\n        card.push_to_hub(save_directory, token = token)\n    except:\n        # Repo already exists — update datasets metadata separately\n        if datasets:\n            try:\n                from huggingface_hub import metadata_update\n\n                metadata_update(\n                    save_directory, {\"datasets\": datasets}, overwrite = True, token = token\n                )\n            except Exception as e:\n                logger.warning_once(\n                    f\"Unsloth: Could not update datasets metadata for {save_directory}: {e}\"\n                )\n\n    if file_location is not None:\n        # Now upload file\n        hf_api = HfApi(token = token)\n\n        if \"/\" in file_location:\n            uploaded_location = file_location[file_location.rfind(\"/\") + 1 :]\n        else:\n            uploaded_location = file_location\n\n        # find ftevent file from tensorboard and upload it\n        import glob\n\n        ftevent_files = glob.glob(\"*out.tfevents*\", recursive = True)\n        if len(ftevent_files) > 0:\n            print(\n                \"Unsloth: Uploading tensorboard files... 
Please wait...\",\n                file_location + \"*out.tfevents*\",\n            )\n            for ftevent_file in ftevent_files:\n                hf_api.upload_file(\n                    path_or_fileobj = ftevent_file,\n                    path_in_repo = ftevent_file.replace(file_location, \"\"),\n                    repo_id = save_directory,\n                    repo_type = \"model\",\n                    commit_message = \"(Trained with Unsloth)\",\n                )\n\n        hf_api.upload_file(\n            path_or_fileobj = file_location,\n            path_in_repo = uploaded_location,\n            repo_id = save_directory,\n            repo_type = \"model\",\n            commit_message = \"(Trained with Unsloth)\",\n        )\n\n        # We also upload a config.json file\n        if create_config:\n            import json\n\n            with open(\"_temporary_unsloth_config.json\", \"w\", encoding = \"utf-8\") as file:\n                json.dump({\"model_type\": model.config.model_type}, file, indent = 4)\n            hf_api.upload_file(\n                path_or_fileobj = \"_temporary_unsloth_config.json\",\n                path_in_repo = \"config.json\",\n                repo_id = save_directory,\n                repo_type = \"model\",\n                commit_message = \"(Trained with Unsloth)\",\n            )\n            os.remove(\"_temporary_unsloth_config.json\")\n    return username\n\n\ndef fix_tokenizer_bos_token(tokenizer):\n    # Check if BOS added already, then warn\n    fix_bos_token = False\n    chat_template = getattr(tokenizer, \"chat_template\", None)\n\n    if tokenizer(\"A\").input_ids[0] == getattr(tokenizer, \"bos_token_id\", None):\n        if chat_template is not None and (\n            tokenizer.bos_token in chat_template\n            or \"{bos_token}\" in chat_template.replace(\" \", \"\")\n            or \"{bos_token+\" in chat_template.replace(\" \", \"\")\n        ):\n            fix_bos_token = True\n            logger.warning(\n                \"Unsloth: ##### The current model auto adds a BOS token.\\n\"\n                \"Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily.\"\n            )\n\n            # Remove {{bos_token}}\n            new_chat_template = re.sub(\n                r\"\\{[\\s]{0,}\\{[\\s]{0,}bos\\_token[\\s]{0,}\\}[\\s]{0,}\\}\", \"\", chat_template\n            )\n            # Remove {{bos_token +\n            new_chat_template = re.sub(\n                r\"\\{[\\s]{0,}\\{[\\s]{0,}bos\\_token[\\s]{0,}\\+[\\s]{0,}\",\n                \"\",\n                new_chat_template,\n            )\n\n            tokenizer.chat_template = new_chat_template\n\n    return fix_bos_token, chat_template\n\n\ndef create_ollama_modelfile(tokenizer, base_model_name, model_location):\n    \"\"\"\n    Creates an Ollama Modelfile.\n    Use ollama.create(model = \"new_ollama_model\", modelfile = modelfile)\n    \"\"\"\n    ollama_template_name = MODEL_TO_OLLAMA_TEMPLATE_MAPPER.get(base_model_name)\n    if not ollama_template_name:\n        print(\n            f\"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile\"\n        )\n        return None\n    ollama_modelfile = OLLAMA_TEMPLATES.get(ollama_template_name)\n    if not ollama_modelfile:\n        print(\n            f\"Unsloth: No Ollama template mapping found for model '{base_model_name}'. 
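# Hedged sketch of how the double-BOS check above is typically used around a GGUF
# export: detect the problem, export with the temporarily patched template, then
# restore the original chat template. `tokenizer` is any HF tokenizer; nothing here
# is specific to a particular model.
def _example_check_double_bos(tokenizer):  # pragma: no cover - illustrative only
    fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
    try:
        # ... run the GGUF export here while the template has no BOS placeholder ...
        pass
    finally:
        if fix_bos_token:
            # Restore the original template so HF-side usage is unchanged.
            tokenizer.chat_template = old_chat_template
    return fix_bos_token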
Skipping Ollama Modelfile\"\n        )\n        return None\n    tokenizer._ollama_modelfile = (\n        ollama_modelfile  # This comes from the unpacking above\n    )\n    modelfile = ollama_modelfile\n\n    FILE_LOCATION_REPLACER = \"⚫@✅#🦥__FILE_LOCATION__⚡@🦥#⛵\"\n    EOS_TOKEN_REPLACER = \"⚫@✅#🦥__EOS_TOKEN__⚡@🦥#⛵\"\n    LEFT_BRACKET_REPLACER = \"⚫@✅#🦥\"\n    RIGHT_BRACKET_REPLACER = \"⚡@🦥#⛵\"\n\n    # Fixes https://github.com/unslothai/unsloth/issues/1087\n    # We must convert all {'s and }'s but keep {__FILE_LOCATION__} intact\n    modelfile = (\n        modelfile.replace(\"{__FILE_LOCATION__}\", FILE_LOCATION_REPLACER)\n        .replace(\"{__EOS_TOKEN__}\", EOS_TOKEN_REPLACER)\n        .replace(\"{\", LEFT_BRACKET_REPLACER)\n        .replace(\"}\", RIGHT_BRACKET_REPLACER)\n    )\n\n    # Revert {__FILE_LOCATION__} back\n    modelfile = modelfile.replace(\n        FILE_LOCATION_REPLACER, \"{__FILE_LOCATION__}\"\n    ).replace(EOS_TOKEN_REPLACER, \"{__EOS_TOKEN__}\")\n\n    if \"__EOS_TOKEN__\" in modelfile:\n        modelfile = modelfile.format(\n            __FILE_LOCATION__ = model_location,\n            __EOS_TOKEN__ = tokenizer.eos_token,\n        )\n    else:\n        modelfile = modelfile.format(\n            __FILE_LOCATION__ = model_location,\n        )\n\n    modelfile = modelfile.replace(\"⚫@✅#🦥\", \"{\").replace(\"⚡@🦥#⛵\", \"}\").rstrip()\n\n    return modelfile\n\n\ndef create_ollama_model(username: str, model_name: str, tag: str, modelfile_path: str):\n    try:\n        init_check = subprocess.run(\n            [\"curl\", \"http://localhost:11434\"],\n            capture_output = True,\n            text = True,\n            timeout = 3,\n        )\n        if init_check.returncode == 0:\n            print(init_check.stdout.strip())\n        else:\n            print(\"Ollama Server is not Running\")\n    except subprocess.TimeoutExpired:\n        return \"Ollama Request Timeout\"\n\n    process = subprocess.Popen(\n        [\n            \"ollama\",\n            \"create\",\n            f\"{username}/{model_name}:{tag}\",\n            \"-f\",\n            f\"{modelfile_path}\",\n        ],\n        stdout = subprocess.PIPE,\n        stderr = subprocess.STDOUT,\n        text = True,\n        bufsize = 1,\n        universal_newlines = True,\n    )\n\n    for line in iter(process.stdout.readline, \"\"):\n        print(line, end = \"\")\n        sys.stdout.flush()\n\n    return_code = process.wait()\n\n    if return_code != 0:\n        print(f\"\\nMODEL CREATED FAILED WITH RETURN CODE {return_code}\")\n    else:\n        print(\"\\nMODEL CREATED SUCCESSFULLY\")\n\n\ndef push_to_ollama_hub(username: str, model_name: str, tag: str):\n    try:\n        init_check = subprocess.run(\n            [\"curl\", \"http://localhost:11434\"],\n            capture_output = True,\n            text = True,\n            timeout = 3,\n        )\n        if init_check.returncode == 0:\n            print(init_check.stdout.strip())\n        else:\n            print(\"Ollama Server is not Running\")\n    except subprocess.TimeoutExpired:\n        return \"Ollama Request Timeout\"\n\n    process = subprocess.Popen(\n        [\"ollama\", \"push\", f\"{username}/{model_name}:{tag}\"],\n        stdout = subprocess.PIPE,\n        stderr = subprocess.STDOUT,\n        text = True,\n        bufsize = 1,\n        universal_newlines = True,\n    )\n\n    for line in iter(process.stdout.readline, \"\"):\n        print(line, end = \"\")\n        sys.stdout.flush()\n\n    return_code = process.wait()\n\n    if 
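# Hedged mini-demo of the brace-escaping trick used in `create_ollama_modelfile`:
# every literal "{" / "}" in the template is swapped for sentinel strings so that
# str.format only substitutes {__FILE_LOCATION__} (and {__EOS_TOKEN__}), then the
# sentinels are swapped back. The real function uses longer unicode sentinels; the
# template and sentinels here are made-up placeholders, not a real Ollama template.
def _example_brace_safe_format():  # pragma: no cover - illustrative only
    template = 'FROM {__FILE_LOCATION__}\nTEMPLATE """{{ .Prompt }}"""'
    L, R = "<LBRACE>", "<RBRACE>"
    safe = (
        template.replace("{__FILE_LOCATION__}", "<FILE>")
        .replace("{", L)
        .replace("}", R)
        .replace("<FILE>", "{__FILE_LOCATION__}")
    )
    out = safe.format(__FILE_LOCATION__ = "./model.Q8_0.gguf")
    return out.replace(L, "{").replace(R, "}")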
return_code != 0:\n        print(f\"\\nMODEL PUSH FAILED WITH RETURN CODE {return_code}\")\n    else:\n        print(\"\\nMODEL PUSHED SUCCESSFULLY\")\n\n\ndef push_to_ollama(tokenizer, gguf_location, username: str, model_name: str, tag: str):\n    # create_ollama_modelfile expects (tokenizer, base_model_name, model_location);\n    # derive the base model name from the tokenizer when available.\n    model_file = create_ollama_modelfile(\n        tokenizer = tokenizer,\n        base_model_name = getattr(tokenizer, \"name_or_path\", model_name),\n        model_location = gguf_location,\n    )\n    if model_file is None:\n        print(\"Unsloth: Could not create an Ollama Modelfile - skipping Ollama push\")\n        return\n\n    with open(f\"Modelfile_{model_name}\", \"w\", encoding = \"utf-8\") as f:\n        f.write(model_file)\n\n    create_ollama_model(\n        username = username,\n        model_name = model_name,\n        tag = tag,\n        modelfile_path = f\"Modelfile_{model_name}\",\n    )\n\n    push_to_ollama_hub(username = username, model_name = model_name, tag = tag)\n\n    print(\"Successfully pushed to Ollama\")\n\n\ndef unsloth_save_pretrained_gguf(\n    self,\n    save_directory: Union[str, os.PathLike],\n    tokenizer = None,\n    quantization_method = \"fast_quantized\",\n    first_conversion: str = None,\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n    private: Optional[bool] = None,\n    is_main_process: bool = True,\n    state_dict: Optional[dict] = None,\n    save_function: Callable = torch.save,\n    max_shard_size: Union[int, str] = \"5GB\",\n    safe_serialization: bool = True,\n    variant: Optional[str] = None,\n    save_peft_format: bool = True,\n    tags: List[str] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.85,\n):\n    \"\"\"\n    Same as .save_pretrained(...) except 4bit weights are auto\n    converted to float16 then converted to GGUF / llama.cpp format.\n\n    Choose for `quantization_method` to be:\n    \"not_quantized\"  : \"Recommended. Fast conversion. Slow inference, big files.\",\n    \"fast_quantized\" : \"Recommended. Fast conversion. OK inference, OK file size.\",\n    \"quantized\"      : \"Recommended. Slow conversion. Fast inference, small files.\",\n    \"f32\"     : \"Not recommended. Retains 100% accuracy, but super slow and memory hungry.\",\n    \"f16\"     : \"Fastest conversion + retains 100% accuracy. Slow and memory hungry.\",\n    \"q8_0\"    : \"Fast conversion. High resource use, but generally acceptable.\",\n    \"q4_k_m\"  : \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K\",\n    \"q5_k_m\"  : \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K\",\n    \"q2_k\"    : \"Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.\",\n    \"q3_k_l\"  : \"Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_m\"  : \"Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_s\"  : \"Uses Q3_K for all tensors\",\n    \"q4_0\"    : \"Original quant method, 4-bit.\",\n    \"q4_1\"    : \"Higher accuracy than q4_0 but not as high as q5_0. 
However has quicker inference than q5 models.\",\n    \"q4_k_s\"  : \"Uses Q4_K for all tensors\",\n    \"q4_k\"    : \"alias for q4_k_m\",\n    \"q5_k\"    : \"alias for q5_k_m\",\n    \"q5_0\"    : \"Higher accuracy, higher resource usage and slower inference.\",\n    \"q5_1\"    : \"Even higher accuracy, resource usage and slower inference.\",\n    \"q5_k_s\"  : \"Uses Q5_K for all tensors\",\n    \"q6_k\"    : \"Uses Q8_K for all tensors\",\n    \"iq2_xxs\" : \"2.06 bpw quantization\",\n    \"iq2_xs\"  : \"2.31 bpw quantization\",\n    \"iq3_xxs\" : \"3.06 bpw quantization\",\n    \"q3_k_xs\" : \"3-bit extra small quantization\",\n    \"\"\"\n    if tokenizer is None:\n        raise ValueError(\"Unsloth: Saving to GGUF must have a tokenizer.\")\n\n    try:\n        base_model_name = get_model_name(self.config._name_or_path, load_in_4bit = False)\n        model_name = base_model_name.split(\"/\")[-1]\n    except:\n        base_model_name = self.config._name_or_path\n        model_name = base_model_name.split(\"/\")[-1]\n\n    # Check if push_to_hub is requested\n    if push_to_hub:\n        raise ValueError(\n            \"Unsloth: Please use .push_to_hub_gguf() instead of .save_pretrained_gguf() with push_to_hub=True\"\n        )\n\n    # Step 1: Check if this is a VLM (Vision-Language Model) and check if gpt-oss\n    is_vlm = False\n    if hasattr(self, \"config\") and hasattr(self.config, \"architectures\"):\n        is_vlm = any(\n            x.endswith((\"ForConditionalGeneration\", \"ForVisionText2Text\"))\n            for x in self.config.architectures\n        )\n        is_vlm = is_vlm or hasattr(self.config, \"vision_config\")\n\n    is_processor = is_vlm and isinstance(tokenizer, ProcessorMixin)\n\n    is_gpt_oss = (\n        True\n        if (\n            hasattr(self.config, \"architectures\")\n            and self.config.architectures == \"GptOssForCausalLM\"\n        )\n        or (\n            hasattr(self.config, \"model_type\")\n            and self.config.model_type in [\"gpt-oss\", \"gpt_oss\"]\n        )\n        else False\n    )\n    # Step 2: Prepare arguments for model saving\n    arguments = dict(locals())\n    arguments[\"model\"] = self\n    arguments[\"tokenizer\"] = tokenizer\n    arguments[\"push_to_hub\"] = False  # We handle upload ourselves\n    # GPT-OSS needs mxfp4 save method\n    if is_gpt_oss:\n        if quantization_method is not None:\n            _qm = (\n                quantization_method\n                if isinstance(quantization_method, (list, tuple))\n                else [quantization_method]\n            )\n            _ignored = [q for q in _qm if str(q).lower() != \"mxfp4\"]\n            if _ignored:\n                logger.warning_once(\n                    f\"Unsloth: GPT-OSS does not support GGUF quantization \"\n                    f\"(requested: {', '.join(str(q) for q in _ignored)}). \"\n                    f\"Overriding to MXFP4 format. 
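# Hedged sketch of the detection logic above, written against `config.architectures`
# as the list of class names that transformers normally stores there. Purely
# illustrative; the real save path uses the checks in the function body.
def _example_detect_model_kind(config):  # pragma: no cover - illustrative only
    architectures = getattr(config, "architectures", None) or []
    is_vlm = any(
        name.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
        for name in architectures
    ) or hasattr(config, "vision_config")
    is_gpt_oss = (
        "GptOssForCausalLM" in architectures
        or getattr(config, "model_type", None) in ("gpt-oss", "gpt_oss")
    )
    return is_vlm, is_gpt_oss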
\"\n                    f\"Pass quantization_method=None to suppress this warning.\"\n                )\n        arguments[\"save_method\"] = \"mxfp4\"\n    else:\n        arguments[\"save_method\"] = \"merged_16bit\"\n    del arguments[\"self\"]\n    del arguments[\"quantization_method\"]\n    del arguments[\"first_conversion\"]\n    del arguments[\"is_vlm\"]\n    del arguments[\"is_gpt_oss\"]\n    del arguments[\"model_name\"]\n    del arguments[\"base_model_name\"]\n    del arguments[\"is_processor\"]\n\n    # Step 3: Fix tokenizer BOS token if needed\n    if is_processor:\n        fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer.tokenizer)\n    else:\n        fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)\n\n    # Step 4: Save/merge model to 16-bit format\n    print(\n        f'Unsloth: Merging model weights to {\"mxfp4\" if is_gpt_oss else \"16-bit\"} format...'\n    )\n    try:\n        # Call unsloth_generic_save directly (it's in the same file)\n        unsloth_generic_save(**arguments)\n\n    except Exception as e:\n        raise RuntimeError(f\"Failed to save/merge model: {e}\")\n\n    if is_processor:\n        tokenizer = tokenizer.tokenizer\n\n    # Use old chat template if the bos is removed\n    if fix_bos_token:\n        tokenizer.chat_template = old_chat_template\n\n    # Step 6: Clean up memory\n    for _ in range(3):\n        import gc\n\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    # Step 7: Get model dtype and type\n    try:\n        model_dtype = dtype_from_config(self.config)\n        model_type = self.config.model_type\n        if type(model_dtype) is str:\n            assert model_dtype == \"float16\" or model_dtype == \"bfloat16\"\n        elif model_dtype == torch.float16:\n            model_dtype = \"float16\"\n        elif model_dtype == torch.bfloat16:\n            model_dtype = \"bfloat16\"\n        else:\n            raise TypeError(\"Unsloth: Model dtype can only be float16 or bfloat16\")\n    except Exception as e:\n        # Fallback if dtype_from_config fails\n        print(f\"Unsloth: Could not determine dtype ({e}), defaulting to float16\")\n        model_dtype = \"float16\"\n\n    # Step 8: Convert to GGUF format\n    print(\"Unsloth: Converting to GGUF format...\")\n\n    # Convert quantization_method to list if string\n    # Use old style quantization_method\n    quantization_methods = []\n    if quantization_method is not None:\n        # Convert quantization_method to list\n        if isinstance(quantization_method, list):\n            pass\n        elif isinstance(quantization_method, str):\n            quantization_method = [\n                quantization_method,\n            ]\n        elif isinstance(quantization_method, tuple):\n            quantization_method = list(quantization_method)\n        else:\n            raise TypeError(\n                \"Unsloth: quantization_method can only be a string or a list of strings\"\n            )\n        for i, quant_method in enumerate(quantization_method):\n            quant_method = quant_method.lower()\n            if quant_method == \"not_quantized\":\n                quant_method = \"f16\"\n            elif quant_method == \"fast_quantized\":\n                quant_method = \"q8_0\"\n            elif quant_method == \"quantized\":\n                quant_method = \"q4_k_m\"\n            elif quant_method is None:\n                quant_method = \"q8_0\"\n            
quantization_methods.append(quant_method.lower())\n\n    try:\n        all_file_locations, want_full_precision, is_vlm_update = save_to_gguf(\n            model_name = model_name,\n            model_type = model_type,\n            model_dtype = model_dtype,\n            is_sentencepiece = False,\n            model_directory = save_directory,\n            quantization_method = quantization_methods,\n            first_conversion = first_conversion,\n            is_vlm = is_vlm,  # Pass VLM flag\n            is_gpt_oss = is_gpt_oss,  # Pass gpt_oss Flag\n        )\n    except Exception as e:\n        if IS_KAGGLE_ENVIRONMENT:\n            raise RuntimeError(\n                f\"Unsloth: GGUF conversion failed in Kaggle environment.\\n\"\n                f\"This is likely due to the 20GB disk space limit.\\n\"\n                f\"Try saving to /tmp directory or use a smaller model.\\n\"\n                f\"Error: {e}\"\n            )\n        else:\n            raise RuntimeError(f\"Unsloth: GGUF conversion failed: {e}\")\n\n    # Step 9: Create Ollama modelfile\n    gguf_directory = f\"{save_directory}_gguf\"\n    modelfile_location = None\n    ollama_success = False\n    if all_file_locations:\n        try:\n            if is_vlm_update:\n                modelfile = create_ollama_modelfile(tokenizer, base_model_name, \".\")\n            else:\n                modelfile = create_ollama_modelfile(\n                    tokenizer,\n                    base_model_name,\n                    os.path.basename(all_file_locations[0]),\n                )\n            if modelfile is not None:\n                modelfile_location = os.path.join(gguf_directory, \"Modelfile\")\n                with open(modelfile_location, \"w\", encoding = \"utf-8\") as file:\n                    file.write(modelfile)\n                ollama_success = True\n        except Exception as e:\n            print(f\"Warning: Could not create Ollama modelfile: {e}\")\n\n    # Step 10: Show BOS token warning if applicable\n    if fix_bos_token:\n        logger.warning(\n            \"Unsloth: ##### The current model auto adds a BOS token.\\n\"\n            \"Unsloth: ##### We removed it in GGUF's chat template for you.\"\n        )\n\n    _exe = \".exe\" if IS_WINDOWS else \"\"\n    if IS_WINDOWS:\n        _bin_dir = os.path.join(LLAMA_CPP_DEFAULT_DIR, \"build\", \"bin\", \"Release\")\n    else:\n        _bin_dir = LLAMA_CPP_DEFAULT_DIR\n\n    if is_vlm_update:\n        print(\"\\n\")\n        print(\n            f\"Unsloth: example usage for Multimodal LLMs: {os.path.join(_bin_dir, 'llama-mtmd-cli' + _exe)} -m {all_file_locations[0]} --mmproj {all_file_locations[-1]}\"\n        )\n        print(\"Unsloth: load image inside llama.cpp runner: /image test_image.jpg\")\n        print(\"Unsloth: Prompt model to describe the image\")\n    else:\n        print(\n            f'Unsloth: example usage for text only LLMs: {os.path.join(_bin_dir, \"llama-cli\" + _exe)} --model {all_file_locations[0]} -p \"why is the sky blue?\"'\n        )\n\n    if ollama_success:\n        print(f\"Unsloth: Saved Ollama Modelfile to {modelfile_location}\")\n        print(\n            f\"Unsloth: convert model to ollama format by running - ollama create model_name -f {modelfile_location}\"\n        )\n\n    # Return a dict with all needed info for push_to_hub\n    return {\n        \"save_directory\": save_directory,\n        \"gguf_directory\": gguf_directory,\n        \"gguf_files\": all_file_locations,\n        \"modelfile_location\": 
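# Hedged sketch of the alias handling above: the friendly names accepted for
# `quantization_method` map onto concrete llama.cpp quant types before conversion.
# Shown as a simplified standalone helper; the save path inlines this logic and also
# accepts tuples and None.
_EXAMPLE_QUANT_ALIASES = {
    "not_quantized": "f16",
    "fast_quantized": "q8_0",
    "quantized": "q4_k_m",
}

def _example_normalize_quants(quantization_method):  # pragma: no cover - illustrative
    if isinstance(quantization_method, str):
        quantization_method = [quantization_method]
    return [
        _EXAMPLE_QUANT_ALIASES.get(q.lower(), q.lower()) for q in quantization_method
    ]
# e.g. _example_normalize_quants("fast_quantized") -> ["q8_0"]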
modelfile_location,\n        \"want_full_precision\": want_full_precision,\n        \"is_vlm\": is_vlm_update,\n        \"fix_bos_token\": fix_bos_token,\n    }\n\n\ndef unsloth_push_to_hub_gguf(\n    self,\n    repo_id: str,\n    tokenizer = None,\n    quantization_method = \"fast_quantized\",\n    first_conversion: str = None,\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Trained with Unsloth\",\n    private: Optional[bool] = None,\n    token: Union[bool, str, None] = None,\n    max_shard_size: Union[int, str, None] = \"5GB\",\n    create_pr: bool = False,\n    safe_serialization: bool = True,\n    revision: str = None,\n    commit_description: str = \"Upload model trained with Unsloth 2x faster\",\n    tags: Optional[List[str]] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.85,\n    datasets: Optional[List[str]] = None,\n):\n    \"\"\"\n    Same as .push_to_hub(...) except 4bit weights are auto\n    converted to float16 then converted to GGUF / llama.cpp format.\n\n    Choose for `quantization_method` to be:\n    \"not_quantized\"  : \"Recommended. Fast conversion. Slow inference, big files.\",\n    \"fast_quantized\" : \"Recommended. Fast conversion. OK inference, OK file size.\",\n    \"quantized\"      : \"Recommended. Slow conversion. Fast inference, small files.\",\n    \"f32\"     : \"Not recommended. Retains 100% accuracy, but super slow and memory hungry.\",\n    \"f16\"     : \"Fastest conversion + retains 100% accuracy. Slow and memory hungry.\",\n    \"q8_0\"    : \"Fast conversion. High resource use, but generally acceptable.\",\n    \"q4_k_m\"  : \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K\",\n    \"q5_k_m\"  : \"Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K\",\n    \"q2_k\"    : \"Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.\",\n    \"q3_k_l\"  : \"Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_m\"  : \"Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K\",\n    \"q3_k_s\"  : \"Uses Q3_K for all tensors\",\n    \"q4_0\"    : \"Original quant method, 4-bit.\",\n    \"q4_1\"    : \"Higher accuracy than q4_0 but not as high as q5_0. 
However has quicker inference than q5 models.\",\n    \"q4_k_s\"  : \"Uses Q4_K for all tensors\",\n    \"q5_0\"    : \"Higher accuracy, higher resource usage and slower inference.\",\n    \"q5_1\"    : \"Even higher accuracy, resource usage and slower inference.\",\n    \"q5_k_s\"  : \"Uses Q5_K for all tensors\",\n    \"q6_k\"    : \"Uses Q8_K for all tensors\",\n    \"\"\"\n    if tokenizer is None:\n        raise ValueError(\"Unsloth: Saving to GGUF must have a tokenizer.\")\n\n    # Step 1: Determine save directory\n    model_name = repo_id.split(\"/\")[-1] if \"/\" in repo_id else repo_id\n\n    if use_temp_dir or use_temp_dir is None:\n        import tempfile\n\n        temp_dir = tempfile.mkdtemp(prefix = \"unsloth_gguf_\")\n        save_directory = temp_dir\n        cleanup_temp = True\n    else:\n        save_directory = model_name  # Use model name, not repo_id\n        cleanup_temp = False\n\n    # Step 2: Call save_pretrained_gguf to do the conversion\n    print(f\"Unsloth: Converting model to GGUF format...\")\n\n    try:\n        # Call save_pretrained_gguf - it returns all the info we need\n        result = unsloth_save_pretrained_gguf(\n            self = self,\n            save_directory = save_directory,\n            tokenizer = tokenizer,\n            quantization_method = quantization_method,\n            first_conversion = first_conversion,\n            push_to_hub = False,  # Never push from here\n            token = None,  # Don't need token for local save\n            max_shard_size = max_shard_size,\n            safe_serialization = safe_serialization,\n            temporary_location = temporary_location,\n            maximum_memory_usage = maximum_memory_usage,\n        )\n\n        # Extract results\n        all_file_locations = result[\"gguf_files\"]\n        modelfile_location = result[\"modelfile_location\"]\n        want_full_precision = result[\"want_full_precision\"]\n        is_vlm = result[\"is_vlm\"]\n        fix_bos_token = result[\"fix_bos_token\"]\n        actual_save_directory = result[\"save_directory\"]\n\n    except Exception as e:\n        if cleanup_temp:\n            import shutil\n\n            for d in [save_directory, f\"{save_directory}_gguf\"]:\n                try:\n                    shutil.rmtree(d)\n                except:\n                    pass\n        raise RuntimeError(f\"Failed to convert model to GGUF: {e}\")\n\n    # Step 3: Upload to HuggingFace Hub\n    print(\"Unsloth: Uploading GGUF to Huggingface Hub...\")\n\n    try:\n        from huggingface_hub import HfApi\n\n        api = HfApi(token = token)\n\n        # Get full repo id\n        if \"/\" not in repo_id:\n            username = api.whoami()[\"name\"]\n            full_repo_id = f\"{username}/{repo_id}\"\n        else:\n            full_repo_id = repo_id\n\n        # Create repo\n        api.create_repo(\n            repo_id = full_repo_id,\n            repo_type = \"model\",\n            private = private,\n            exist_ok = True,\n        )\n\n        # Upload GGUF files\n        for file_location in all_file_locations:\n            original_name = os.path.basename(file_location)\n            # Replace temp directory name with proper model name\n            if cleanup_temp and \"unsloth_gguf_\" in original_name:\n                # Extract the quantization part (e.g., \".Q8_0.gguf\" or \".Q8_0-mmproj.gguf\")\n                quant_suffix = (\n                    original_name.split(\".\", 1)[1]\n                    if \".\" in original_name\n                 
   else original_name\n                )\n                proper_name = f\"{model_name}.{quant_suffix}\"\n            else:\n                proper_name = original_name.replace(\n                    os.path.basename(save_directory), model_name\n                )\n\n            print(f\"Uploading {proper_name}...\")\n\n            api.upload_file(\n                path_or_fileobj = file_location,\n                path_in_repo = proper_name,\n                repo_id = full_repo_id,\n                repo_type = \"model\",\n                commit_message = commit_message,\n                commit_description = commit_description,\n                create_pr = create_pr,\n                revision = revision,\n            )\n\n        # Upload config.json if exists\n        config_path = os.path.join(actual_save_directory, \"config.json\")\n        if os.path.exists(config_path):\n            print(\"Uploading config.json...\")\n            api.upload_file(\n                path_or_fileobj = config_path,\n                path_in_repo = \"config.json\",\n                repo_id = full_repo_id,\n                repo_type = \"model\",\n                commit_message = f\"{commit_message} - config\",\n                create_pr = create_pr,\n                revision = revision,\n            )\n\n        # Upload Modelfile if exists\n        if modelfile_location and os.path.exists(modelfile_location):\n            print(\"Uploading Ollama Modelfile...\")\n            api.upload_file(\n                path_or_fileobj = modelfile_location,\n                path_in_repo = \"Modelfile\",\n                repo_id = full_repo_id,\n                repo_type = \"model\",\n                commit_message = f\"{commit_message} - Ollama Modelfile\",\n                create_pr = create_pr,\n                revision = revision,\n            )\n\n        # Create and upload README\n        readme_content = f\"\"\"---\ntags:\n- gguf\n- llama.cpp\n- unsloth\n{\"- vision-language-model\" if is_vlm else \"\"}\n---\n\n# {repo_id.split(\"/\")[-1]} : GGUF\n\nThis model was finetuned and converted to GGUF format using [Unsloth](https://github.com/unslothai/unsloth).\n\n**Example usage**:\n- For text only LLMs:    `llama-cli -hf {repo_id} --jinja`\n- For multimodal models: `llama-mtmd-cli -hf {repo_id} --jinja`\n\n## Available Model files:\n\"\"\"\n        for file in all_file_locations:\n            # Fix filename in README too\n            original_name = os.path.basename(file)\n            if cleanup_temp and \"unsloth_gguf_\" in original_name:\n                quant_suffix = (\n                    original_name.split(\".\", 1)[1]\n                    if \".\" in original_name\n                    else original_name\n                )\n                proper_name = f\"{model_name}.{quant_suffix}\"\n            else:\n                proper_name = original_name.replace(\n                    os.path.basename(save_directory), model_name\n                )\n            readme_content += f\"- `{proper_name}`\\n\"\n\n        # Special note for VLM with Modelfile\n        if is_vlm and modelfile_location:\n            readme_content += \"\\n## ⚠️ Ollama Note for Vision Models\\n\"\n            readme_content += \"**Important:** Ollama currently does not support separate mmproj files for vision models.\\n\\n\"\n            readme_content += \"To create an Ollama model from this vision model:\\n\"\n            readme_content += \"1. 
Place the `Modelfile` in the same directory as the finetuned bf16 merged model\\n\"\n            readme_content += \"3. Run: `ollama create model_name -f ./Modelfile`\\n\"\n            readme_content += \"   (Replace `model_name` with your desired name)\\n\\n\"\n            readme_content += (\n                \"This will create a unified bf16 model that Ollama can use.\\n\"\n            )\n        elif modelfile_location:\n            readme_content += \"\\n## Ollama\\n\"\n            readme_content += \"An Ollama Modelfile is included for easy deployment.\\n\"\n\n        if fix_bos_token:\n            readme_content += \"\\n## Note\\n\"\n            readme_content += (\n                \"The model's BOS token behavior was adjusted for GGUF compatibility.\\n\"\n            )\n\n        readme_content += (\n            \"This was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)\\n\"\n            '[<img src=\"https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png\" width=\"200\"/>](https://github.com/unslothai/unsloth)\\n'\n        )\n\n        readme_path = os.path.join(actual_save_directory, \"README.md\")\n        with open(readme_path, \"w\") as f:\n            f.write(readme_content)\n\n        api.upload_file(\n            path_or_fileobj = readme_path,\n            path_in_repo = \"README.md\",\n            repo_id = full_repo_id,\n            repo_type = \"model\",\n            commit_message = \"Add README\",\n            create_pr = create_pr,\n            revision = revision,\n        )\n\n        print(\n            f\"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}\"\n        )\n\n        # Add tags\n        if tags is None:\n            tags = []\n        tags.extend([\"gguf\", \"llama-cpp\", \"unsloth\"])\n        if is_vlm:\n            tags.append(\"vision-language-model\")\n\n        try:\n            api.add_tags(\n                repo_id = full_repo_id,\n                tags = tags,\n                repo_type = \"model\",\n            )\n        except:\n            pass\n\n        if datasets:\n            try:\n                from huggingface_hub import metadata_update\n\n                metadata_update(\n                    full_repo_id, {\"datasets\": datasets}, overwrite = True, token = token\n                )\n            except Exception as e:\n                logger.warning_once(\n                    f\"Unsloth: Could not update datasets metadata for {full_repo_id}: {e}\"\n                )\n\n    except Exception as e:\n        raise RuntimeError(f\"Failed to upload to Hugging Face Hub: {e}\")\n\n    finally:\n        # Clean up temporary directory\n        if cleanup_temp:\n            print(\"Unsloth: Cleaning up temporary files...\")\n            import shutil\n\n            for d in [save_directory, f\"{save_directory}_gguf\"]:\n                if os.path.exists(d):\n                    try:\n                        shutil.rmtree(d)\n                    except:\n                        pass\n\n    return full_repo_id\n\n\n# Corrected function to save LoRA to a custom directory\ndef save_lora_to_custom_dir(model, tokenizer, save_directory):\n    # Create the custom directory if it doesn't exist\n    os.makedirs(save_directory, exist_ok = True)\n\n    # Call the unsloth_save_model function with the custom directory\n    unsloth_save_model(\n        model,\n        tokenizer,\n        save_directory = save_directory,\n        save_method = \"lora\",\n        
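# Hedged usage sketch of the GGUF push entry point defined above. The repo id, token
# and quant list are placeholders (assumptions); after `patch_saving_functions` this
# is normally reached as `model.push_to_hub_gguf(...)`.
def _example_push_gguf(model, tokenizer):  # pragma: no cover - illustrative only
    return unsloth_push_to_hub_gguf(
        model,
        "your-username/finetuned-model-gguf",  # assumption: target repo id
        tokenizer = tokenizer,
        quantization_method = ["q8_0", "q4_k_m"],
        token = "hf_xxx",  # assumption: HF write token
        private = True,
    )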
push_to_hub = False,\n    )\n\n\n# Corrected method within the model class to convert LoRA to GGML and push to Hugging Face Hub\ndef unsloth_convert_lora_to_ggml_and_push_to_hub(\n    self,\n    tokenizer,\n    repo_id: str,\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Converted LoRA to GGML with Unsloth\",\n    private: Optional[bool] = None,\n    token: Union[bool, str, None] = None,\n    create_pr: bool = False,\n    revision: str = None,\n    commit_description: str = \"Convert LoRA to GGML format using Unsloth\",\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.85,\n):\n    if not os.path.exists(\"llama.cpp\"):\n        if IS_KAGGLE_ENVIRONMENT:\n            python_install = install_python_non_blocking([\"protobuf\"])\n            python_install.wait()\n            install_llama_cpp_blocking(use_cuda = False)\n            makefile = None\n        else:\n            git_clone = install_llama_cpp_clone_non_blocking()\n            python_install = install_python_non_blocking([\"protobuf\"])\n            git_clone.wait()\n            makefile = install_llama_cpp_make_non_blocking()\n            python_install.wait()\n    else:\n        makefile = None\n\n    for _ in range(3):\n        gc.collect()\n\n    lora_directory_push = \"lora-to-ggml-push\"\n    save_lora_to_custom_dir(self, tokenizer, lora_directory_push)\n\n    model_type = self.config.model_type\n    output_file = os.path.join(lora_directory_push, \"ggml-adapter-model.bin\")\n\n    print(\n        f\"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.\"\n    )\n    print(f\"The output file will be {output_file}\")\n\n    command = f\"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama\"\n\n    try:\n        with subprocess.Popen(\n            command,\n            shell = True,\n            stdout = subprocess.PIPE,\n            stderr = subprocess.PIPE,\n            bufsize = 1,\n            universal_newlines = True,\n        ) as sp:\n            for line in sp.stdout:\n                print(line, end = \"\", flush = True)\n            for line in sp.stderr:\n                print(line, end = \"\", flush = True)\n            sp.wait()\n            if sp.returncode != 0:\n                raise subprocess.CalledProcessError(sp.returncode, command)\n    except subprocess.CalledProcessError as e:\n        print(f\"Error: Conversion failed with return code {e.returncode}\")\n        return\n\n    print(f\"Unsloth: Conversion completed! Output file: {output_file}\")\n\n    print(\"Unsloth: Uploading GGML file to Hugging Face Hub...\")\n    username = upload_to_huggingface(\n        self,\n        repo_id,\n        token,\n        \"GGML converted LoRA\",\n        \"ggml\",\n        output_file,\n        None,\n        private,\n    )\n    link = f\"{repo_id.lstrip('/')}\"\n    print(\"Unsloth: Done.\")\n    print(f\"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}\")\n    print(\n        \"\\nThis GGML making function was made by Maheswar. 
Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!\"\n    )\n\n\ndef unsloth_convert_lora_to_ggml_and_save_locally(\n    self,\n    save_directory: str,  # Added parameter for the folder name\n    tokenizer,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.85,\n):\n    if not os.path.exists(\"llama.cpp\"):\n        if IS_KAGGLE_ENVIRONMENT:\n            python_install = install_python_non_blocking([\"protobuf\"])\n            python_install.wait()\n            install_llama_cpp_blocking(use_cuda = False)\n            makefile = None\n        else:\n            git_clone = install_llama_cpp_clone_non_blocking()\n            python_install = install_python_non_blocking([\"protobuf\"])\n            git_clone.wait()\n            makefile = install_llama_cpp_make_non_blocking()\n            python_install.wait()\n    else:\n        makefile = None\n\n    for _ in range(3):\n        gc.collect()\n\n    # Use the provided save_directory for local saving\n    save_lora_to_custom_dir(self, tokenizer, save_directory)\n\n    model_type = self.config.model_type\n    output_file = os.path.join(save_directory, \"ggml-adapter-model.bin\")\n\n    print(\n        f\"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.\"\n    )\n    print(f\"The output file will be {output_file}\")\n\n    command = f\"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama\"\n\n    try:\n        with subprocess.Popen(\n            command,\n            shell = True,\n            stdout = subprocess.PIPE,\n            stderr = subprocess.PIPE,\n            bufsize = 1,\n            universal_newlines = True,\n        ) as sp:\n            for line in sp.stdout:\n                print(line, end = \"\", flush = True)\n            for line in sp.stderr:\n                print(line, end = \"\", flush = True)\n            sp.wait()\n            if sp.returncode != 0:\n                raise subprocess.CalledProcessError(sp.returncode, command)\n    except subprocess.CalledProcessError as e:\n        print(f\"Error: Conversion failed with return code {e.returncode}\")\n        return\n    print(\"Unsloth: Done.\")\n    print(f\"Unsloth: Conversion completed! Output file: {output_file}\")\n    print(\n        \"\\nThis GGML making function was made by Maheswar. 
Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!\"\n    )\n\n\nfrom .models.loader_utils import get_model_name\nfrom unsloth_zoo.saving_utils import (\n    merge_and_overwrite_lora,\n    prepare_saving,\n)\nfrom unsloth_zoo.llama_cpp import (\n    install_llama_cpp,\n    convert_to_gguf as _convert_to_gguf,\n)\n\n\n@torch.inference_mode\ndef save_to_gguf_generic(\n    model,\n    save_directory,\n    tokenizer,\n    quantization_method = None,\n    quantization_type = \"Q8_0\",\n    repo_id = None,\n    token = None,\n):\n    if token is None and repo_id is not None:\n        token = get_token()\n    if repo_id is not None and token is None:\n        raise RuntimeError(\"Unsloth: Please specify a token for uploading!\")\n\n    if not os.path.exists(os.path.join(\"llama.cpp\", \"unsloth_convert_hf_to_gguf.py\")):\n        install_llama_cpp(just_clone_repo = True)\n\n    # Use old style quantization_method\n    new_quantization_methods = []\n    if quantization_method is not None:\n        # Convert quantization_method to list\n        if isinstance(quantization_method, list):\n            pass\n        elif isinstance(quantization_method, str):\n            quantization_method = [\n                quantization_method,\n            ]\n        elif isinstance(quantization_method, tuple):\n            quantization_method = list(quantization_method)\n        else:\n            raise TypeError(\n                \"Unsloth: quantization_method can only be a string or a list of strings\"\n            )\n        for i, quant_method in enumerate(quantization_method):\n            quant_method = quant_method.lower()\n            if quant_method == \"not_quantized\":\n                quant_method = \"f16\"\n            elif quant_method == \"fast_quantized\":\n                quant_method = \"q8_0\"\n            elif quant_method == \"quantized\":\n                quant_method = \"q4_k_m\"\n            elif quant_method is None:\n                quant_method = \"q8_0\"\n            new_quantization_methods.append(quant_method.lower())\n    else:\n        new_quantization_methods.append(quantization_type.lower())\n    # Check if wrong method\n    for quant_method in new_quantization_methods:\n        if quant_method not in ALLOWED_QUANTS.keys():\n            error = f\"Unsloth: Quant method = [{quant_method}] not supported. 
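# Hedged usage sketch for the two GGML helpers above. Note that newer llama.cpp
# checkouts may no longer ship `convert-lora-to-ggml.py`, so these helpers only work
# against revisions that still include that script; the directory name below is a
# placeholder.
def _example_convert_lora_to_ggml(model, tokenizer):  # pragma: no cover - illustrative
    # Saves LoRA adapters locally and converts them to ./ggml_lora/ggml-adapter-model.bin
    unsloth_convert_lora_to_ggml_and_save_locally(model, "ggml_lora", tokenizer)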
Choose from below:\\n\"\n            for key, value in ALLOWED_QUANTS.items():\n                error += f\"[{key}] => {value}\\n\"\n            raise RuntimeError(error)\n\n    # Go through all types and save individually - somewhat inefficient\n    # since we save F16 / BF16 multiple times\n    for quantization_type in new_quantization_methods:\n        metadata = _convert_to_gguf(\n            save_directory,\n            print_output = True,\n            quantization_type = quantization_type,\n        )\n        if repo_id is not None:\n            prepare_saving(\n                model,\n                repo_id,\n                push_to_hub = True,\n                max_shard_size = \"50GB\",\n                private = True,\n                token = token,\n            )\n\n            from huggingface_hub import HfApi\n\n            api = HfApi(token = token)\n            api.upload_folder(\n                folder_path = save_directory,\n                repo_id = repo_id,\n                repo_type = \"model\",\n                allow_patterns = [\"*.gguf\"],\n            )\n    return metadata\n\n\n@torch.inference_mode\ndef unsloth_generic_save(\n    model,\n    tokenizer,\n    save_directory: Union[str, os.PathLike] = \"unsloth_finetuned_merge\",\n    save_method: str = \"lora\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n    is_main_process: bool = True,\n    state_dict: Optional[dict] = None,\n    save_function: Callable = torch.save,\n    max_shard_size: Union[int, str] = \"5GB\",\n    safe_serialization: bool = True,\n    variant: Optional[str] = None,\n    save_peft_format: bool = True,\n    # Push to hub\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Trained with Unsloth\",\n    private: Optional[bool] = None,\n    create_pr: bool = False,\n    revision: str = None,\n    commit_description: str = \"Upload model trained with Unsloth 2x faster\",\n    tags: List[str] = None,\n    # Our functions\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.9,\n    datasets: Optional[List[str]] = None,\n):\n    if token is None and push_to_hub:\n        token = get_token()\n\n    if save_method == \"merged_4bit\":\n        raise RuntimeError(\n            \"Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\\n\"\n            \"to merge to GGUF or others later on. 
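# Hedged usage sketch for `save_to_gguf_generic`: convert an already-merged local
# checkpoint directory to one or more GGUF files and optionally upload them. The
# directory and quant list are assumptions for illustration only.
def _example_save_to_gguf_generic(model, tokenizer):  # pragma: no cover - illustrative
    return save_to_gguf_generic(
        model,
        "unsloth_finetuned_merge",  # directory containing merged HF weights
        tokenizer,
        quantization_method = ["q8_0"],
        repo_id = None,  # set e.g. "user/model-GGUF" (plus a token) to also upload
    )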
I suggest you to do this as a final step\\n\"\n            \"if you're planning to do multiple saves.\\n\"\n            \"If you are certain, change `save_method` to `merged_4bit_forced`.\"\n        )\n    elif save_method == \"merged_4bit_forced\":\n        save_method = \"merged_4bit\"\n\n    merge_and_overwrite_lora(\n        get_model_name,\n        model = model,\n        tokenizer = tokenizer,\n        save_directory = save_directory,\n        push_to_hub = push_to_hub,\n        private = private,\n        token = token,\n        save_method = save_method,\n        output_dtype = None,\n        low_disk_space_usage = True,\n        use_temp_file = False,\n    )\n\n    if push_to_hub and datasets:\n        try:\n            from huggingface_hub import metadata_update\n\n            save_dir, _ = _determine_username(save_directory, None, token)\n            metadata_update(\n                save_dir, {\"datasets\": datasets}, overwrite = True, token = token\n            )\n        except Exception as e:\n            logger.warning_once(\n                f\"Unsloth: Could not update datasets metadata for {save_directory}: {e}\"\n            )\n\n    return\n\n\ndef unsloth_generic_save_pretrained_merged(\n    self,\n    save_directory: Union[str, os.PathLike],\n    tokenizer = None,\n    save_method: str = \"merged_16bit\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n    is_main_process: bool = True,\n    state_dict: Optional[dict] = None,\n    save_function: Callable = torch.save,\n    max_shard_size: Union[int, str] = \"5GB\",\n    safe_serialization: bool = True,\n    variant: Optional[str] = None,\n    save_peft_format: bool = True,\n    tags: List[str] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.75,\n    datasets: Optional[List[str]] = None,\n):\n    \"\"\"\n    Same as .push_to_hub(...) except 4bit weights are auto\n    converted to float16 with as few overhead as possible.\n\n    Choose for `save_method` to be either:\n    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.\n    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.\n    3.  `lora`: Save LoRA adapters with no merging. 
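# Hedged sketch of the `save_method` contract enforced above: "merged_4bit" raises by
# design, and the caller has to opt in explicitly with "merged_4bit_forced". Directory
# names are placeholders.
def _example_save_methods(model, tokenizer):  # pragma: no cover - illustrative only
    unsloth_generic_save(model, tokenizer, "merged_out", save_method = "merged_16bit")
    try:
        unsloth_generic_save(model, tokenizer, "merged_out", save_method = "merged_4bit")
    except RuntimeError:
        # 4-bit merging loses accuracy for later GGUF exports, so it must be forced.
        unsloth_generic_save(
            model, tokenizer, "merged_out", save_method = "merged_4bit_forced"
        )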
Useful for HF inference.\n    \"\"\"\n    if tokenizer is None:\n        logger.warning_once(\n            \"Unsloth: You're not saving a tokenizer as well?\\n\"\n            \"You can do it separately via `tokenizer.save_pretrained(...)`\"\n        )\n\n    arguments = dict(locals())\n    arguments[\"model\"] = self\n    del arguments[\"self\"]\n    unsloth_generic_save(**arguments)\n    for _ in range(3):\n        gc.collect()\n\n\ndef unsloth_generic_push_to_hub_merged(\n    self,\n    repo_id: str,\n    tokenizer = None,\n    save_method: str = \"merged_16bit\",  # [\"lora\", \"merged_16bit\", \"merged_4bit\"]\n    use_temp_dir: Optional[bool] = None,\n    commit_message: Optional[str] = \"Trained with Unsloth\",\n    private: Optional[bool] = None,\n    token: Union[bool, str, None] = None,\n    max_shard_size: Union[int, str, None] = \"5GB\",\n    create_pr: bool = False,\n    safe_serialization: bool = True,\n    revision: str = None,\n    commit_description: str = \"Upload model trained with Unsloth 2x faster\",\n    tags: Optional[List[str]] = None,\n    temporary_location: str = \"_unsloth_temporary_saved_buffers\",\n    maximum_memory_usage: float = 0.75,\n    datasets: Optional[List[str]] = None,\n):\n    \"\"\"\n    Same as .push_to_hub(...) except 4bit weights are auto\n    converted to float16 with as few overhead as possible.\n\n    Choose for `save_method` to be either:\n    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.\n    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.\n    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.\n    \"\"\"\n    if tokenizer is None:\n        logger.warning_once(\n            \"Unsloth: You're not saving a tokenizer as well?\\n\"\n            \"You can do it separately via `tokenizer.push_to_hub(...)`\"\n        )\n\n    arguments = dict(locals())\n    arguments[\"model\"] = self\n    arguments[\"save_directory\"] = repo_id\n    arguments[\"push_to_hub\"] = True\n    del arguments[\"self\"]\n    del arguments[\"repo_id\"]\n    unsloth_generic_save(**arguments)\n    for _ in range(3):\n        gc.collect()\n\n\ndef _unsloth_save_torchao_with_attached_config(\n    model,\n    save_directory: Union[str, os.PathLike],\n    tokenizer,\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n):\n    \"\"\"Save a QAT-trained model by converting fake-quantized weights to real quantized weights.\"\"\"\n    # Convert QAT fake-quantized weights to real quantized weights\n    _convert_torchao_model(model)\n    # PEFT models also might come here, so parse it\n    if isinstance(model, PeftModelForCausalLM):\n        _unsloth_save_torchao_with_given_config(\n            model = model,\n            save_directory = save_directory,\n            tokenizer = tokenizer,\n            torchao_config = model.config.quantization_config,\n            push_to_hub = push_to_hub,\n            token = token,\n        )\n        return\n\n    # TorchAO does not support safe_serialization reliably\n    safe_serialization = False\n\n    if push_to_hub:\n        model.push_to_hub(\n            save_directory, safe_serialization = safe_serialization, token = token\n        )\n        tokenizer.push_to_hub(save_directory, token = token)\n    else:\n        model.save_pretrained(save_directory, safe_serialization = safe_serialization)\n        tokenizer.save_pretrained(save_directory)\n\n\ndef _unsloth_save_torchao_with_given_config(\n    model,\n    save_directory: Union[str, 
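# Hedged usage sketch of the two merged-save wrappers above, called directly here for
# clarity; after `patch_saving_functions` they are normally reached as bound methods
# (`model.save_pretrained_merged`, `model.push_to_hub_merged`). Repo id and token are
# placeholders.
def _example_merged_save(model, tokenizer):  # pragma: no cover - illustrative only
    # Local merged 16-bit save (useful as input for GGUF / llama.cpp conversion)
    unsloth_generic_save_pretrained_merged(
        model, "merged_model", tokenizer, save_method = "merged_16bit"
    )
    # Push LoRA adapters only, without merging
    unsloth_generic_push_to_hub_merged(
        model,
        "your-username/finetuned-lora",  # assumption: target repo id
        tokenizer,
        save_method = "lora",
        token = "hf_xxx",  # assumption: HF write token
    )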
os.PathLike],\n    tokenizer,\n    torchao_config,\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n):\n    \"\"\"Quantizes the model with torchao and saves a torchao quantized checkpoint\n\n    Args\n      `save_directory`: local folder path or huggingface hub ID when `push_to_hub` is set to True, e.g. `my_model`\n      `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization, full list: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize\n      `push_to_hub` (bool): whether to push the checkpoint to huggingface hub or save locally\n    \"\"\"\n\n    if push_to_hub:\n        assert token is not None, \"Unsloth: Please specify a token for uploading!\"\n\n    assert (\n        torchao_config is not None\n    ), \"Unsloth: Please specify a torchao_config for post-training quantization!\"\n\n    # first merge the lora weights\n    arguments = dict(locals())\n    arguments[\"push_to_hub\"] = False  # We save ourselves\n    arguments[\"save_method\"] = \"merged_16bit\"  # Must be 16bit\n    del arguments[\"torchao_config\"]\n\n    if not isinstance(model, PeftModelForCausalLM) and not isinstance(model, PeftModel):\n        model.save_pretrained(save_directory)\n        tokenizer.save_pretrained(save_directory)\n    else:\n        unsloth_generic_save(**arguments)\n\n    for _ in range(3):\n        gc.collect()\n\n    from transformers import (\n        AutoModelForCausalLM,\n        AutoTokenizer,\n        TorchAoConfig,\n        AutoModelForImageTextToText,\n        AutoProcessor,\n    )\n    from torchao import quantize_\n\n    if isinstance(torchao_config, TorchAoConfig):\n        quantization_config = torchao_config\n    else:\n        quantization_config = TorchAoConfig(quant_type = torchao_config)\n\n    # Determine if this is a VLM\n    is_vlm = False\n    if hasattr(model, \"config\") and hasattr(model.config, \"architectures\"):\n        is_vlm = any(\n            x.endswith((\"ForConditionalGeneration\", \"ForVisionText2Text\"))\n            for x in model.config.architectures\n        )\n        is_vlm = is_vlm or hasattr(model.config, \"vision_config\")\n    auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM\n    auto_processor = AutoProcessor if is_vlm else AutoTokenizer\n\n    tokenizer = auto_processor.from_pretrained(save_directory)\n\n    # TorchAO must only use bfloat16 for loading (float16 fails)\n    if HAS_TORCH_DTYPE:\n        kwargs = {\"torch_dtype\": torch.bfloat16}\n    else:\n        kwargs = {\"dtype\": torch.bfloat16}\n\n    # Reload with quantization applied\n    quantized_model = auto_model.from_pretrained(\n        save_directory,\n        device_map = \"auto\",\n        quantization_config = quantization_config,\n        **kwargs,\n    )\n\n    torchao_save_directory = save_directory + \"-torchao\"\n\n    # TorchAO does not support safe_serialization right now 0.14.0 seems broken!\n    safe_serialization = Version(importlib_version(\"torchao\")) > Version(\"0.14.0\")\n    safe_serialization = False\n\n    if push_to_hub:\n        quantized_model.push_to_hub(\n            torchao_save_directory, safe_serialization = safe_serialization, token = token\n        )\n        tokenizer.push_to_hub(torchao_save_directory, token = token)\n    else:\n        quantized_model.save_pretrained(\n            torchao_save_directory, safe_serialization = safe_serialization\n        )\n        tokenizer.save_pretrained(torchao_save_directory)\n\n    # Clean up the 
intermediate unquantized model\n    if os.path.exists(save_directory):\n        try:\n            shutil.rmtree(save_directory)\n        except:\n            pass\n\n\ndef unsloth_save_pretrained_torchao(\n    self,\n    save_directory: Union[str, os.PathLike],\n    tokenizer = None,\n    torchao_config = None,\n    push_to_hub: bool = False,\n    token: Optional[Union[str, bool]] = None,\n):\n    \"\"\"Saves a torchao quantized model checkpoint.\n\n    This function handles two mutually exclusive workflows:\n\n    1. **QAT (Quantization-Aware Training)**: If the model was trained with `qat_scheme`\n       parameter, do NOT pass `torchao_config`. The function will convert the QAT\n       fake-quantized weights to real quantized weights and save directly.\n\n    2. **PTQ (Post-Training Quantization)**: If you want to apply quantization to a\n       regular model, pass a `torchao_config`. The model must NOT have been trained\n       with `qat_scheme`.\n\n    Args:\n      `save_directory`: local folder path or huggingface hub ID when `push_to_hub` is True\n      `tokenizer`: the tokenizer to save alongside the model\n      `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization.\n          Required for PTQ, must be None for QAT models.\n          Options: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize\n      `push_to_hub` (bool): whether to push to huggingface hub or save locally\n      `token`: HuggingFace token for pushing to hub\n    \"\"\"\n    if token is None and push_to_hub:\n        token = get_token()\n\n    has_qat_config = (\n        hasattr(self, \"_torchao_config\") and self._torchao_config is not None\n    )\n\n    if torchao_config is not None:\n        # PTQ path: user provided a config, model must NOT have QAT config unless PEFT\n        assert not has_qat_config, (\n            \"Unsloth: You passed `torchao_config` but this model was trained with `qat_scheme`. \"\n            \"For QAT models, do not pass `torchao_config` - the quantization config is already \"\n            \"attached to the model from training.\"\n        )\n        _unsloth_save_torchao_with_given_config(\n            model = self,\n            save_directory = save_directory,\n            tokenizer = tokenizer,\n            torchao_config = torchao_config,\n            push_to_hub = push_to_hub,\n            token = token,\n        )\n    else:\n        # QAT path: no config provided, model must have QAT config\n        assert has_qat_config, (\n            \"Unsloth: No `torchao_config` provided and model was not trained with `qat_scheme`. 
\"\n            \"Either train with `qat_scheme` parameter, or provide a `torchao_config` for \"\n            \"post-training quantization.\"\n        )\n        _unsloth_save_torchao_with_attached_config(\n            model = self,\n            save_directory = save_directory,\n            tokenizer = tokenizer,\n            push_to_hub = push_to_hub,\n            token = token,\n        )\n\n    for _ in range(3):\n        gc.collect()\n\n\ndef not_implemented_save(*args, **kwargs):\n    raise NotImplementedError(\n        \"Unsloth: Sorry GGUF is currently not supported for vision models!\"\n    )\n\n\ndef patch_saving_functions(model, vision = False):\n    import inspect\n    import types\n    from typing import Callable, Optional, Union, List\n\n    # And now re add our saving methods!\n    if model.push_to_hub.__name__ == \"unsloth_push_to_hub\":\n        original_push_to_hub = model.original_push_to_hub\n    else:\n        original_push_to_hub = model.push_to_hub\n\n    signature = str(inspect.signature(original_push_to_hub)).replace(\"NoneType\", \"None\")\n    signature = signature[1:]\n    signature = re.sub(\"<function save at .+?>\", \"torch.save\", signature)\n    docs = original_push_to_hub.__doc__.encode(\"utf-8\").decode(\"utf-8\")\n\n    push_to_hub_text = f'''def unsloth_push_to_hub(self, {signature}:\n    \"\"\"\n    {docs}\n    \"\"\"\n    arguments = dict(locals())\n    del arguments[\"self\"]\n    if \"tags\" in arguments and arguments[\"tags\"] is not None:\n        assert(isinstance(arguments[\"tags\"], (list, tuple)))\n        arguments[\"tags\"] = list(arguments[\"tags\"]) + [\"unsloth\",]\n    elif \"tags\" in arguments:\n        arguments[\"tags\"] = [\"unsloth\",]\n    elif hasattr(self, \"add_model_tags\"):\n        self.add_model_tags([\"unsloth\",])\n\n    if \"commit_message\" in arguments:\n        commit_message = arguments[\"commit_message\"]\n        if commit_message is not None:\n            if not commit_message.endswith(\" \"): commit_message += \" \"\n            if \"Unsloth\" not in commit_message:\n                commit_message += \"(Trained with Unsloth)\"\n        else:\n            commit_message = \"Upload model trained with Unsloth\"\n        arguments[\"commit_message\"] = commit_message\n\n    if \"commit_description\" in arguments:\n        commit_description = arguments[\"commit_description\"]\n        if commit_description is not None:\n            if not commit_description.endswith(\" \"): commit_description += \" \"\n            if \"Unsloth\" not in commit_description:\n                commit_description += \"(Trained with Unsloth 2x faster)\"\n        else:\n            commit_description = \"Upload model trained with Unsloth 2x faster\"\n        arguments[\"commit_description\"] = commit_description\n\n    # Update model tag\n    if hasattr(self, \"config\"):\n        _ = upload_to_huggingface(\n            self, arguments[\"repo_id\"], arguments[\"token\"],\n            \"finetuned\", \"trl\", file_location = None,\n            old_username = None, private = arguments[\"private\"],\n        )\n    pass\n\n    try:\n        self.original_push_to_hub(**arguments)\n    except:\n        del arguments[\"tags\"]\n        self.original_push_to_hub(**arguments)\n    pass\n\n    if hasattr(self, \"config\"):\n        print(\"Saved model to https://huggingface.co/\" + arguments[\"repo_id\"])\n    pass\n    '''\n    exec(push_to_hub_text, globals())\n\n    original_model = model\n    while True:\n        # Check if push_to_hub exists 
before accessing its __name__\n        if (\n            hasattr(original_model, \"push_to_hub\")\n            and original_model.push_to_hub.__name__ != \"unsloth_push_to_hub\"\n        ):\n            original_model.original_push_to_hub = original_model.push_to_hub\n            original_model.push_to_hub = types.MethodType(\n                unsloth_push_to_hub, original_model\n            )\n            if hasattr(original_model, \"add_model_tags\"):\n                original_model.add_model_tags(\n                    [\n                        \"unsloth\",\n                    ]\n                )\n\n        if hasattr(original_model, \"model\"):\n            original_model = original_model.model\n        else:\n            break\n\n    # Add saving methods to top level model\n    if not vision:\n        if hasattr(model, \"config\"):\n            # Counteract tokenizers\n            model.push_to_hub_merged = types.MethodType(\n                unsloth_generic_push_to_hub_merged, model\n            )\n            model.save_pretrained_merged = types.MethodType(\n                unsloth_generic_save_pretrained_merged, model\n            )\n            model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)\n            model.save_pretrained_gguf = types.MethodType(\n                unsloth_save_pretrained_gguf, model\n            )\n            model.save_pretrained_torchao = types.MethodType(\n                unsloth_save_pretrained_torchao, model\n            )\n            model.push_to_hub_ggml = types.MethodType(\n                unsloth_convert_lora_to_ggml_and_push_to_hub, model\n            )\n            model.save_pretrained_ggml = types.MethodType(\n                unsloth_convert_lora_to_ggml_and_save_locally, model\n            )\n    else:\n        # Vision only 1 option\n        model.push_to_hub_merged = types.MethodType(\n            unsloth_generic_push_to_hub_merged, model\n        )\n        model.save_pretrained_merged = types.MethodType(\n            unsloth_generic_save_pretrained_merged, model\n        )\n        model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)\n        model.save_pretrained_gguf = types.MethodType(\n            unsloth_save_pretrained_gguf, model\n        )\n        model.save_pretrained_torchao = types.MethodType(\n            unsloth_save_pretrained_torchao, model\n        )\n    return model\n"
  },
  {
    "path": "unsloth/tokenizer_utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import AutoTokenizer\nfrom transformers.convert_slow_tokenizer import convert_slow_tokenizer\nfrom transformers import PreTrainedTokenizerFast\nimport re\nimport os\nfrom transformers.models.llama.modeling_llama import logger\nfrom peft import PeftModelForCausalLM\nimport torch\nimport itertools\nimport collections\nimport numpy as np\nimport gc\nimport subprocess\nimport psutil\n\nfrom unsloth_zoo.tokenizer_utils import (\n    mean_of_trained_tokens,\n    add_new_tokens,\n    fix_untrained_tokens,\n)\nfrom unsloth_zoo.training_utils import (\n    fix_zero_training_loss,\n)\n\n__all__ = [\n    \"load_correct_tokenizer\",\n    \"fix_sentencepiece_tokenizer\",\n    \"check_tokenizer\",\n    \"add_new_tokens\",\n    \"fix_sentencepiece_gguf\",\n]\n\n\nIGNORED_TOKENIZER_CHECKING = frozenset(\n    (\n        \"CodeLlamaTokenizerFast\",\n        \"CodeLlamaTokenizer\",\n    )\n)\n\n\nIGNORED_TOKENIZER_NAMES = [\n    # Qwen Coder did not train on tool calling. Math did!\n    \"unsloth/Qwen2.5-Coder-1.5B-Instruct\",\n    \"unsloth/Qwen2.5-Coder-7B-Instruct\",\n]\nIGNORED_TOKENIZER_NAMES = frozenset(\n    [x.lower() for x in IGNORED_TOKENIZER_NAMES]\n    + [x.lower() + \"-bnb-4bit\" for x in IGNORED_TOKENIZER_NAMES]\n)\nos.environ[\"UNSLOTH_IGNORED_TOKENIZER_NAMES\"] = \"\\n\".join(IGNORED_TOKENIZER_NAMES)\n\n# Check environments\nkeynames = \"\\n\" + \"\\n\".join(os.environ.keys())\nIS_COLAB_ENVIRONMENT = \"\\nCOLAB_\" in keynames\nIS_KAGGLE_ENVIRONMENT = \"\\nKAGGLE_\" in keynames\nKAGGLE_TMP = \"/tmp\"\ndel keynames\n\n\ndef try_fix_tokenizer(tokenizer, prepend = True):\n    if hasattr(tokenizer, \"_tokenizer\"):\n        converted_tokenizer = tokenizer._tokenizer\n    else:\n        converted_tokenizer = convert_slow_tokenizer(tokenizer)\n\n    tokenizer_string = converted_tokenizer.to_str()\n\n    # Llama does _apple. 
Sometimes this is wrong!!\n    prepend_text = '{\"type\":\"Prepend\",\"prepend\":\"▁\"},'\n    if not prepend and prepend_text in tokenizer_string:\n        tokenizer_string = tokenizer_string.replace(prepend_text, \"\", 1)\n\n    dir_names = dir(tokenizer)\n    # Get eos_token, bos_token etc\n    token_names = [x for x in dir_names if x.endswith(\"_token\") and x.count(\"_\") == 1]\n\n    for token_name in token_names:\n        token = getattr(tokenizer, token_name, None)\n        if token is None:\n            continue\n        token_id = getattr(tokenizer, token_name + \"_id\", None)\n\n        # Locate the token's id mapping in the string\n        find_text = f'\"id\":{token_id},\"content\":\"'\n        start = tokenizer_string.find(find_text) + len(find_text)\n        if start == -1:\n            continue\n        end = tokenizer_string.find('\",', start)\n\n        bad_token = tokenizer_string[start:end]\n        # Check if token is the actual same one - if not, edit it\n        if bad_token != token:\n            bad_text = f'{find_text}{bad_token}\",'\n            good_text = f'{find_text}{token}\",'\n            tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)\n\n            # And replace vocab section\n            bad_text = f'\"{bad_token}\":{token_id},'\n            good_text = f'\"{token}\":{token_id},'\n            tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)\n\n    fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)\n    return fixed_tokenizer\n\n\ndef get_sorted_dict(dictionary):\n    sorted_keys = sorted(dictionary.values())\n    inverted_dictionary = {value: key for key, value in dictionary.items()}\n\n    sorted_dictionary = {}\n    for key in sorted_keys:\n        value = inverted_dictionary[key]\n        sorted_dictionary[value] = key\n    return sorted_dictionary\n\n\ndef convert_to_fast_tokenizer(\n    slow_tokenizer,\n    temporary_location = \"_unsloth_sentencepiece_temp\",\n):\n    is_fast = getattr(slow_tokenizer, \"is_fast\", False)\n    if is_fast:\n        return slow_tokenizer\n\n    try:\n        tokenizer_name = slow_tokenizer.__class__.__name__\n        lowered_tokenizer_name = tokenizer_name.lower()\n        if lowered_tokenizer_name.endswith(\"tokenizer\"):\n            class_name = lowered_tokenizer_name[: -len(\"tokenizer\")]\n            FastTokenizer = eval(\n                f'__import__(f\"transformers.models.{class_name}\").{tokenizer_name}Fast'\n            )\n        else:\n            FastTokenizer = PreTrainedTokenizerFast\n    except:\n        FastTokenizer = PreTrainedTokenizerFast\n\n    # Get all arguments (bos_token, etc)\n    docs = FastTokenizer.__doc__\n    docs = docs[docs.find(\"Args:\") :]\n    args = re.findall(r\"\\n[\\s]+([^\\s]{1,}) \\(\", docs, flags = re.MULTILINE)\n    args = [x for x in args if not x.endswith(\"_file\")]\n\n    # Also some missing maybe!\n    docs = PreTrainedTokenizerFast.__doc__\n    docs = docs[docs.find(\"Args:\") :]\n    args2 = re.findall(r\"\\n[\\s]+([^\\s]{1,}) \\(\", docs, flags = re.MULTILINE)\n    args2 = [x for x in args2 if not x.endswith(\"_file\")]\n    args = list(set(args + args2))\n\n    kwargs = {}\n    for arg in args:\n        kwargs[arg] = getattr(slow_tokenizer, arg, None)\n    kwargs[\"tokenizer_object\"] = try_fix_tokenizer(slow_tokenizer, prepend = True)\n    fast_tokenizer = FastTokenizer(**kwargs)\n\n    # Check if they're similar!\n    sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())\n    
sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())\n\n    check_vocab = sorted_slow_tokenizer == sorted_fast_tokenizer\n    check_special = (\n        slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens\n    )\n\n    # Failure so return slow_tokenizer\n    if not check_vocab or not check_special:\n        return slow_tokenizer\n\n    # Now confirm if they match\n    if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):\n        # Maybe remove prepending of __apple?\n        kwargs[\"tokenizer_object\"] = try_fix_tokenizer(slow_tokenizer, prepend = False)\n        fast_tokenizer = FastTokenizer(**kwargs)\n        if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):\n            # Failure :(\n            return slow_tokenizer\n\n    # Also tokenizer.model is missing!\n    name = slow_tokenizer.name_or_path.replace(\"/\", \"_\")\n    if not os.path.exists(temporary_location):\n        os.makedirs(temporary_location)\n    new_location = f\"{temporary_location}/{name}\"\n    slow_tokenizer.save_pretrained(new_location)\n    fast_tokenizer.save_pretrained(new_location)\n\n    # Now load it!\n    fast_tokenizer = AutoTokenizer.from_pretrained(new_location)\n    if assert_same_tokenization(slow_tokenizer, fast_tokenizer):\n        return fast_tokenizer\n    return slow_tokenizer\n\n\n# Check Mistral chat template without BOS / EOS\nmistral_template = (\n    \"{% if messages[0]['role'] == 'system' %}\"\n    \"{% if messages[1]['role'] == 'user' %}\"\n    \"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}\"\n    \"{% set loop_messages = messages[2:] %}\"\n    \"{% else %}\"\n    \"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}\"\n    \"{% set loop_messages = messages[1:] %}\"\n    \"{% endif %}\"\n    \"{% else %}\"\n    \"{% set loop_messages = messages %}\"\n    \"{% endif %}\"\n    \"{% for message in loop_messages %}\"\n    \"{% if message['role'] == 'user' %}\"\n    \"{{ '[INST] ' + message['content'] + ' [/INST]' }}\"\n    \"{% elif message['role'] == 'assistant' %}\"\n    \"{{ message['content'] }}\"\n    \"{% else %}\"\n    \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\n    \"{% endif %}\"\n    \"{% endfor %}\"\n)\n\n# Check Llama chat template without BOS / EOS\nllama_template = (\n    \"{% if messages[0]['role'] == 'system' %}\"\n    \"{% if messages[1]['role'] == 'user' %}\"\n    \"{{ '[INST] <<SYS>>\\n' + messages[0]['content'] + '\\n<</SYS>>\\n\\n' + messages[1]['content'] + ' [/INST]' }}\"\n    \"{% set loop_messages = messages[2:] %}\"\n    \"{% else %}\"\n    \"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}\"\n    \"{% set loop_messages = messages[1:] %}\"\n    \"{% endif %}\"\n    \"{% else %}\"\n    \"{% set loop_messages = messages %}\"\n    \"{% endif %}\"\n    \"{% for message in loop_messages %}\"\n    \"{% if message['role'] == 'user' %}\"\n    \"{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}\"\n    \"{% elif message['role'] == 'assistant' %}\"\n    \"{{ ' ' + message['content'].strip() + ' ' }}\"\n    \"{% else %}\"\n    \"{{ raise_exception('Only user and assistant roles are supported!') }}\"\n    \"{% endif %}\"\n    \"{% endfor %}\"\n)\n\n\ndef assert_same_tokenization(slow_tokenizer, fast_tokenizer):\n    # Get eos_token, bos_token etc\n    if not hasattr(slow_tokenizer, \"all_special_tokens\"):\n        return True\n    dir_names = dir(slow_tokenizer)\n    special_tokens = list(\n        filter(\n            None,\n           
 (\n                getattr(slow_tokenizer, x)\n                for x in dir_names\n                if x.endswith(\"_token\") and x.count(\"_\") == 1\n            ),\n        )\n    )\n    all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))\n\n    # Remove replacement char for false positive\n    replacement_char = b\"\\xc3\\xaf\\xc2\\xbf\\xc2\\xbd\".decode(\"utf-8\")\n    all_special_tokens = [x for x in all_special_tokens if x != replacement_char]\n\n    # Check if chat template is enabled!\n    check_chat_template1 = True\n    check_chat_template2 = True\n    check_chat_template3 = True\n\n    \"\"\"\n    Weirdly Mistral tokenizers are actually correct??\n    Ie below will actually load mistral v1 and v3 incorrectly!\n\n    slow_chat_template = getattr(slow_tokenizer, \"chat_template\", None)\n    fast_chat_template = getattr(fast_tokenizer, \"chat_template\", None)\n    messages = [\n        {\"role\": \"user\", \"content\": \" What is 2+2? \"},\n        {\"role\": \"assistant\", \"content\": \" It's 4. \"},\n    ]\n    # Check the tokenizer's own chat template\n    if slow_chat_template is not None and fast_chat_template is not None:\n        check_chat_template1 = \\\n            slow_tokenizer.apply_chat_template(messages) == \\\n            fast_tokenizer.apply_chat_template(messages)\n    pass\n\n    # Check Mistral chat template without BOS / EOS\n    slow_tokenizer.chat_template = mistral_template\n    fast_tokenizer.chat_template = mistral_template\n    check_chat_template2 = \\\n        slow_tokenizer.apply_chat_template(messages) == \\\n        fast_tokenizer.apply_chat_template(messages)\n    pass\n\n    # Check Llama chat template without BOS / EOS\n    slow_tokenizer.chat_template = llama_template\n    fast_tokenizer.chat_template = llama_template\n    check_chat_template3 = \\\n        slow_tokenizer.apply_chat_template(messages) == \\\n        fast_tokenizer.apply_chat_template(messages)\n    pass\n\n    # Combine them all and revert chat templates\n    slow_tokenizer.chat_template = slow_chat_template\n    fast_tokenizer.chat_template = fast_chat_template\n    \"\"\"\n    check_chat_template = (\n        check_chat_template1 and check_chat_template2 and check_chat_template3\n    )\n\n    # Try special tokens\n    try:\n        string = (\n            \"\\n\".join(all_special_tokens)\n            + \"A quick brown fox jumps over the lazy dog!!\\n\\nHi</s>\\n\\n\"\n            + \"\".join(all_special_tokens)\n        )\n        check_special_tokens = (\n            slow_tokenizer(string).input_ids == fast_tokenizer(string).input_ids\n        )\n\n        return check_chat_template and check_special_tokens\n    except:\n        # For eg see https://github.com/unslothai/unsloth/issues/292\n        # Sometimes tokenizer has weird tokens, causing a combined tokenization to fail.\n        # [TODO] We temporarily disable this for CodeLlama tokenizers\n        if slow_tokenizer.__repr__().split(\"(\", 1)[0] in IGNORED_TOKENIZER_CHECKING:\n            return check_chat_template\n        else:\n            return False\n\n\ndef fix_sentencepiece_tokenizer(\n    old_tokenizer,\n    new_tokenizer,\n    token_mapping,\n    temporary_location = \"_unsloth_sentencepiece_temp\",\n):\n    # From https://github.com/google/sentencepiece/issues/121\n    # We need to manually edit the sentencepiece tokenizer!\n    try:\n        from transformers.convert_slow_tokenizer import import_protobuf\n\n        sentencepiece_model_pb2 = import_protobuf()\n    except 
Exception as e:\n        try:\n            import google.protobuf\n            from unsloth_zoo.utils import Version\n\n            protobuf_version = Version(google.protobuf.__version__)\n            if protobuf_version > Version(\"3.20.3\"):\n                raise RuntimeError(\n                    f\"Unsloth: Your protobuf version = {protobuf_version} is too new.\\n\"\n                    f\"Please downgrade via `pip install --force-reinstall protobuf==3.20.3`\"\n                )\n        except:\n            # This will only work for older SentencePiece versions <= 3.20.3\n            from transformers.utils import sentencepiece_model_pb2\n\n    if not os.path.exists(temporary_location):\n        os.makedirs(temporary_location)\n\n    # Check if tokenizer.model exists\n    if not os.path.isfile(f\"{temporary_location}/tokenizer.model\"):\n        return new_tokenizer\n\n    # First save the old tokenizer\n    old_tokenizer.save_pretrained(temporary_location)\n\n    tokenizer_file = sentencepiece_model_pb2.ModelProto()\n    tokenizer_file.ParseFromString(\n        open(f\"{temporary_location}/tokenizer.model\", \"rb\").read()\n    )\n\n    # Now save the new tokenizer\n    new_tokenizer.save_pretrained(temporary_location)\n\n    # Now correct the old tokenizer's .model file\n    for old_token, new_token in token_mapping.items():\n        ids = old_tokenizer([old_token], add_special_tokens = False).input_ids\n        ids = ids[0]\n        if len(ids) != 1:\n            # Skip this token!\n            print(\n                f\"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!\"\n            )\n            continue\n        ids = ids[0]\n        # [TODO] Hack for Starling - try except\n        try:\n            tokenizer_piece = tokenizer_file.pieces[ids]\n        except:\n            continue\n        assert tokenizer_piece.piece == old_token\n        tokenizer_piece.piece = new_token\n\n    # And now write it\n    with open(f\"{temporary_location}/tokenizer.model\", \"wb\") as file:\n        file.write(tokenizer_file.SerializeToString())\n\n    # And load it!\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(\n        temporary_location,\n        eos_token = new_tokenizer.eos_token,\n        pad_token = new_tokenizer.pad_token,\n    )\n    return tokenizer\n\n\ndef fix_sentencepiece_gguf(saved_location):\n    \"\"\"\n    Fixes sentencepiece tokenizers which did not extend the vocabulary with\n    user defined tokens.\n    Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py\n    \"\"\"\n    from copy import deepcopy\n    from transformers.utils import sentencepiece_model_pb2\n    import json\n    from enum import IntEnum\n\n    class SentencePieceTokenTypes(IntEnum):\n        NORMAL = 1\n        UNKNOWN = 2\n        CONTROL = 3\n        USER_DEFINED = 4\n        UNUSED = 5\n        BYTE = 6\n\n    # Load tokenizer.model\n    tokenizer_file = sentencepiece_model_pb2.ModelProto()\n    if not os.path.isfile(f\"{saved_location}/tokenizer.model\"):\n        return\n    tokenizer_file.ParseFromString(\n        open(f\"{saved_location}/tokenizer.model\", \"rb\").read()\n    )\n    sentence_piece_size = len(tokenizer_file.pieces)\n\n    # Load added_tokens_json\n    if not os.path.isfile(f\"{saved_location}/added_tokens.json\"):\n        return\n    with open(f\"{saved_location}/added_tokens.json\", \"r\", encoding = \"utf-8\") as file:\n        added_tokens_json = 
json.load(file)\n    if len(added_tokens_json) == 0:\n        return\n\n    added_tokens_json = dict(\n        sorted(added_tokens_json.items(), key = lambda item: item[1])\n    )\n    new_size = sentence_piece_size + len(added_tokens_json)\n\n    # Confirm added_tokens_json is correct\n    added_tokens_ids = np.array(list(added_tokens_json.values()))\n    diff = np.diff(added_tokens_ids)\n    if diff.min() != 1 or diff.max() != 1:\n        return\n    if added_tokens_ids.min() != sentence_piece_size:\n        return\n\n    # Edit sentence piece tokens with added_tokens_json\n    logger.warning(\n        f\"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\\n\"\n        f\"Originally tokenizer.model is of size ({sentence_piece_size}).\\n\"\n        f\"But we need to extend to sentencepiece vocab size ({new_size}).\"\n    )\n    new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids) :])\n    for new_token, added_token in zip(new_tokens, added_tokens_json.keys()):\n        new_token.piece = added_token.encode(\"utf-8\")\n        new_token.score = -1000.0\n        new_token.type = SentencePieceTokenTypes.USER_DEFINED\n\n    tokenizer_file.pieces.extend(new_tokens)\n\n    with open(f\"{saved_location}/tokenizer.model\", \"wb\") as file:\n        file.write(tokenizer_file.SerializeToString())\n\n    # Add padding tokens\n    # actual_vocab_size = model.config.vocab_size\n    # padding = actual_vocab_size - len(tokenizer_file.pieces)\n    return\n\n\ndef _load_correct_tokenizer(\n    tokenizer_name,\n    model_max_length = None,\n    padding_side = \"right\",\n    token = None,\n    trust_remote_code = False,\n    cache_dir = \"huggingface_tokenizers_cache\",\n    fix_tokenizer = True,\n):\n    if IS_COLAB_ENVIRONMENT:\n        cache_dir = cache_dir\n    elif IS_KAGGLE_ENVIRONMENT:\n        # /tmp of Kaggle seems has a 80GB limit!\n        # Let's utilize them\n        cache_dir = os.path.join(KAGGLE_TMP, cache_dir)\n    else:\n        cache_dir = None\n\n    # Try loading the slow tokenizer. 
If it fails, then try Fast only\n    # Mainly to solve Deepseek models with no tokenizer.model file\n    slow_tokenizer = None\n    try:\n        slow_tokenizer = AutoTokenizer.from_pretrained(\n            tokenizer_name,\n            model_max_length = model_max_length,\n            padding_side = padding_side,\n            token = token,\n            trust_remote_code = trust_remote_code,\n            # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373\n            use_fast = False,\n            legacy = False,\n            from_slow = True,\n            cache_dir = cache_dir,\n        )\n    except:\n        slow_tokenizer = None\n        # print(\n        #     f\"Unsloth: {tokenizer_name} has no tokenizer.model file.\\n\"\\\n        #     \"Just informing you about this - this is not a critical error.\"\n        # )\n    # Unsure why this occurs!\n    if type(slow_tokenizer) is bool:\n        slow_tokenizer = None\n\n    fast_tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_name,\n        model_max_length = model_max_length,\n        padding_side = padding_side,\n        token = token,\n        trust_remote_code = trust_remote_code,\n        cache_dir = cache_dir,\n    )\n\n    if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:\n        return fast_tokenizer\n    # Ignore Mistral ones - they're a bit weird to handle!\n    elif \"mistral\" in tokenizer_name.lower():\n        return fast_tokenizer\n    # Ignore Phi-4 ones as well\n    elif \"phi-4\" in tokenizer_name.lower():\n        return fast_tokenizer\n    elif slow_tokenizer is not None:\n        if hasattr(fast_tokenizer, \"add_bos_token\") and hasattr(\n            slow_tokenizer, \"add_bos_token\"\n        ):\n            fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token\n        if hasattr(fast_tokenizer, \"add_eos_token\") and hasattr(\n            slow_tokenizer, \"add_eos_token\"\n        ):\n            fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token\n\n        # Confirm if slow and fast are equivalent!\n        if assert_same_tokenization(slow_tokenizer, fast_tokenizer):\n            return fast_tokenizer\n        else:\n            logger.warning(\n                f\"Unsloth: Will load {tokenizer_name} as a legacy tokenizer.\"\n            )\n            return convert_to_fast_tokenizer(slow_tokenizer)\n        pass\n    else:\n        return fast_tokenizer\n\n\ndef load_correct_tokenizer(\n    tokenizer_name,\n    model_max_length = None,\n    padding_side = \"right\",\n    token = None,\n    trust_remote_code = False,\n    cache_dir = \"huggingface_tokenizers_cache\",\n    fix_tokenizer = True,\n):\n    tokenizer = _load_correct_tokenizer(\n        tokenizer_name = tokenizer_name,\n        model_max_length = model_max_length,\n        padding_side = padding_side,\n        token = token,\n        trust_remote_code = trust_remote_code,\n        cache_dir = cache_dir,\n        fix_tokenizer = fix_tokenizer,\n    )\n\n    ### 1. 
Fixup tokenizer's chat_template\n    old_chat_template = getattr(tokenizer, \"chat_template\", None)\n\n    # Ignore mistral type models since they don't have an add_generation_prompt\n    if any(\n        s in str(getattr(tokenizer, \"name_or_path\", \"\")).lower()\n        for s in [\"mistral\", \"qwen3guard\"]\n    ):\n        chat_template = old_chat_template\n\n    # Also check Llama-2 old style models\n    elif (\n        old_chat_template is not None\n        and \"[/INST]\" in old_chat_template\n        and \"[INST]\" in old_chat_template\n        and \"bos_token\" in old_chat_template\n        and \"eos_token\" in old_chat_template\n    ):\n        chat_template = old_chat_template\n\n    else:\n        chat_template = fix_chat_template(tokenizer)\n        if old_chat_template is not None and chat_template is None:\n            raise RuntimeError(\n                \"Unsloth: Fixing chat template failed - please file a report immediately!\"\n            )\n        pass\n\n    tokenizer.chat_template = chat_template\n    return tokenizer\n\n\ndef _find_end_position(template, endfor, endif):\n    where_endfor = template.find(endfor)\n    where_endif = template.find(endif)\n    if where_endfor == where_endif == -1:\n        return None\n    elif where_endfor > where_endif:\n        return endfor\n    else:\n        return endif\n\n\ndef _fix_chat_template(chat_template):\n    endfor = \"{% endfor %}\"\n    endif = \"{% endif %}\"\n    chosen_end = _find_end_position(chat_template, endfor, endif)\n    if chosen_end is None:\n        endfor = \"{%- endfor %}\"\n        endif = \"{%- endif %}\"\n        chosen_end = _find_end_position(chat_template, endfor, endif)\n    if chosen_end is None:\n        return chat_template\n\n    where = chat_template.find(chosen_end)\n\n    after_endfor = chat_template[where + len(chosen_end) :]\n\n    dash = \"-\" if chosen_end.startswith(\"{%-\") else \"\"\n\n    if (\n        \"{%\" + dash + \" if\" not in after_endfor\n        and \"{%\" + dash + \" set \" not in after_endfor\n        and after_endfor.startswith(\"{{\")\n        and after_endfor.endswith(\"}}\")\n        and after_endfor.count(\"{{\") == 1\n        and after_endfor.count(\"}}\") == 1\n    ):\n        after_endfor = (\n            \"{%\" + dash + \" if add_generation_prompt %}\" + after_endfor + endif\n        )\n\n        chat_template = chat_template[: where + len(chosen_end)] + after_endfor\n    return chat_template\n\n\ndef fix_chat_template(tokenizer):\n    chat_template = getattr(tokenizer, \"chat_template\", None)\n    if chat_template is None:\n        return None\n\n    ### 1. 
Check if add_generation_prompt works\n    # Check for ShareGPT style first\n    is_sharegpt = None\n    try:\n        messages = [\n            {\"role\": \"user\", \"content\": \"Who are you?\"},\n        ]\n        tokenizer.apply_chat_template(\n            messages, add_generation_prompt = False, tokenize = False\n        )\n        is_sharegpt = False\n    except:\n        try:\n            messages = [\n                {\"from\": \"human\", \"value\": \"Who are you?\"},\n            ]\n            tokenizer.apply_chat_template(\n                messages, add_generation_prompt = False, tokenize = False\n            )\n            is_sharegpt = True\n        except:\n            is_sharegpt = None\n\n    # Not ShareGPT or HF style - just return\n    if is_sharegpt is None:\n        return chat_template\n\n    # Tokenize\n    messages = [\n        {\"role\": \"user\", \"content\": \"Who are you?\"}\n        if not is_sharegpt\n        else {\"from\": \"human\", \"value\": \"Who are you?\"}\n    ]\n    no = tokenizer.apply_chat_template(\n        messages, add_generation_prompt = False, tokenize = False\n    )\n    yes = tokenizer.apply_chat_template(\n        messages, add_generation_prompt = True, tokenize = False\n    )\n\n    if no == yes:\n        # SAME?! That's not good! We check for add_generation_prompt\n        if (\n            \"{% if add_generation_prompt %}\" not in chat_template\n            and \"{%- if add_generation_prompt %}\" not in chat_template\n        ):\n            # Try fixing it by adding it\n            new_chat_template = _fix_chat_template(chat_template)\n            if (\n                \"{% if add_generation_prompt %}\" not in new_chat_template\n                and \"{%- if add_generation_prompt %}\" not in new_chat_template\n            ):\n                raise RuntimeError(\n                    f\"Unsloth: The tokenizer `{tokenizer.name_or_path}`\\n\"\n                    \"does not have a {% if add_generation_prompt %} for generation purposes.\\n\"\n                    f\"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!\"\n                )\n            else:\n                logger.warning_once(\n                    \"Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\\n\"\n                    f\"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!\"\n                )\n                chat_template = new_chat_template\n        else:\n            raise RuntimeError(\n                f\"Unsloth: The tokenizer `{tokenizer.name_or_path}`\\n\"\n                \"has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\\n\"\n                \"Please file a bug report immediately - thanks!\"\n            )\n    return chat_template\n\n\ndef check_tokenizer(\n    model,\n    tokenizer,\n    model_name = \"unsloth/llama-2-7b-bnb-4bit\",\n    model_max_length = 4096,\n    padding_side = \"right\",\n    token = None,\n    _reload = True,\n):\n    # Checks tokenizer for out of bounds ids.\n    # Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha\n    # where <sep> had token id=32002.\n    # See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25\n    # Seems like the Fast tokenizer in Rust breaks things!\n\n    # We ignore some of them!\n    if tokenizer.__repr__().split(\"(\", 1)[0] in IGNORED_TOKENIZER_CHECKING:\n        return tokenizer\n\n    
max_embedding_size = model.model.embed_tokens.weight.shape[0]\n    added_tokens_fast = tokenizer.added_tokens_decoder\n    added_tokens_fast = {\n        index: str(value) for index, value in added_tokens_fast.items()\n    }\n    sorted_keys = sorted(added_tokens_fast)\n    added_tokens_fast = {key: added_tokens_fast[key] for key in sorted_keys}\n\n    for j, index in enumerate(added_tokens_fast.keys()):\n        if index >= max_embedding_size:\n            bad_indices = list(added_tokens_fast.keys())[j:]\n            bad_tokens = list(added_tokens_fast.values())[j:]\n            if not _reload:\n                # Try removing the token\n                added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]\n                special_tokens = tokenizer.special_tokens_map\n                import itertools\n\n                special_tokens = frozenset(\n                    itertools.chain.from_iterable(\n                        [x] if type(x) is str else x for x in special_tokens.values()\n                    )\n                )\n                can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]\n                can_be_removed2 = [\n                    x\n                    for x in can_be_removed1\n                    if x in tokenizer._added_tokens_encoder.keys()\n                ]\n\n                # Check of extra tokens can in fact we removed!\n                can_be_removed = (len(can_be_removed1) == len(bad_tokens)) and (\n                    len(can_be_removed2) == len(bad_tokens)\n                )\n\n                # Check if sep_token or other generic types\n                remove_generic = False\n                try_mapper = []\n                if not can_be_removed:\n                    names = dir(tokenizer)\n                    names = (\n                        x for x in names if x.endswith(\"_token\") and x.count(\"_\") == 1\n                    )\n                    generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]\n\n                    try_removal = []\n                    for token in bad_tokens:\n                        for name_token, check_token in generic_tokens:\n                            if check_token == token:\n                                try_removal.append(token)\n                                try_mapper.append(name_token)\n\n                    # Recheck!\n                    can_be_removed = len(try_removal) == len(bad_tokens)\n                    if can_be_removed:\n                        remove_generic = True\n                    can_be_removed1 = bad_tokens\n\n                if can_be_removed:\n                    # Yes it can be fixed!\n                    for j, bad_token in enumerate(can_be_removed1):\n                        remove_id = tokenizer._added_tokens_encoder[bad_token]\n                        del tokenizer._added_tokens_decoder[remove_id]\n                        del tokenizer._added_tokens_encoder[bad_token]\n\n                        if remove_generic and (try_removal[j] == bad_token):\n                            # Remove sep token for example\n                            setattr(tokenizer, try_mapper[j], None)\n                            setattr(tokenizer, try_mapper[j] + \"_id\", None)\n                    # Confirm 1 more time!\n                    if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:\n                        logger.warning_once(\n                            f\"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair 
it!\\n\"\n                            f\"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\\n\"\n                            \"We removed these bad tokens. If you think this is incorrect, fix your tokenizer first.\"\n                        )\n                        return convert_to_fast_tokenizer(tokenizer)\n\n                # :( Failure\n                raise RuntimeError(\n                    f\"Unsloth tried to load `{model_name}`, but cannot succeed.\\n\"\n                    f\"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\\n\"\n                    f\"Fix your tokenizer since it'll perform out of bounds memory accesses.\"\n                )\n\n            if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:\n                cache_dir = \"huggingface_tokenizers_cache\"\n            else:\n                cache_dir = None\n\n            # Sometimes slow tokenizer does not work like Deepseek\n            try:\n                # Try slow tokenizer which can fix things!\n                tokenizer = AutoTokenizer.from_pretrained(\n                    model_name,\n                    model_max_length = model_max_length,\n                    padding_side = padding_side,\n                    token = token,\n                    # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373\n                    use_fast = False,\n                    legacy = False,\n                    from_slow = True,\n                    cache_dir = cache_dir,\n                )\n                return check_tokenizer(\n                    model = model,\n                    tokenizer = tokenizer,\n                    model_name = model_name,\n                    model_max_length = model_max_length,\n                    padding_side = padding_side,\n                    token = token,\n                    _reload = False,\n                )\n                break\n            except:\n                # Tokenizer has out of bounds issues and we can't\n                # load the slow tokenizer version :(\n                logger.warning_once(\n                    \"Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\\n\"\n                    \"It will still work, but beware of out of bounds memory accesses.\\n\"\n                    \"Please file an issue on the model owner's repo about this issue.\"\n                )\n                return tokenizer\n    return convert_to_fast_tokenizer(tokenizer)\n\n\nimport inspect\nfrom inspect import getsource\nimport trl\nimport trl.trainer.sft_trainer\nfrom trl.trainer.sft_trainer import *\nfrom transformers.trainer import *\n\ntry:\n    from trl.trainer.sft_trainer import neftune_post_forward_hook\nexcept:\n\n    def neftune_post_forward_hook(module, input, output):\n        \"\"\"\n        Implements the NEFTune forward pass for the model using forward hooks. Note this works only for\n        torch.nn.Embedding layers. This method is slightly adapted from the original source code\n        that can be found here: https://github.com/neelsjain/NEFTune\n\n        Simply add it to your model as follows:\n        ```python\n        model = ...\n        model.embed_tokens.neftune_noise_alpha = 0.1\n        model.embed_tokens.register_forward_hook(neftune_post_forward_hook)\n        ```\n\n        Args:\n            module (`torch.nn.Module`):\n                The embedding module where the hook is attached. 
Note that you need to set\n                `module.neftune_noise_alpha` to the desired noise alpha value.\n            input (`torch.Tensor`):\n                The input tensor to the model.\n            output (`torch.Tensor`):\n                The output tensor of the model (i.e. the embeddings).\n        \"\"\"\n        if module.training:\n            dims = torch.tensor(output.size(1) * output.size(2))\n            mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)\n            output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)\n        return output\n\n\ndef patch_sft_trainer_tokenizer():\n    \"\"\"\n    Patches the trainer with changes\n    \"\"\"\n    try:\n        sft_trainer = eval(f\"trl.trainer.sft_trainer.SFTTrainer\")\n    except:\n        return\n    all_imports = dir(trl.trainer.sft_trainer)\n\n    for (\n        function_name,\n        replacer,\n    ) in (\n        # (\"_prepare_non_packed_dataloader\", \"def tokenize(element):\",),\n        (\n            \"_prepare_non_packed_dataloader\",\n            None,\n        ),\n        (\n            \"_prepare_dataset\",\n            None,\n        ),\n        # (\"_prepare_packed_dataloader\", \"if dataset_text_field is not None\",),\n    ):\n        if not hasattr(sft_trainer, function_name):\n            continue\n\n        function = getsource(eval(f\"sft_trainer.{function_name}\"))\n        where = function.find(\"def\")\n        function = function.split(\"\\n\")\n        function = \"\\n\".join(x[where:] for x in function)\n\n        check_text = (\n            \"\\n\"\n            \"if 'tokenizer'          not in locals(): tokenizer = processing_class\\n\"\n            \"if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\\n\"\n            \"if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\\n\"\n            \"if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\\n\"\n            \"test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\\n\"\n            \"chat_template = getattr(tokenizer, 'chat_template', None)\\n\"\n            \"chat_template = '' if chat_template is None else chat_template\\n\"\n            \"has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) \"\n            \"if getattr(tokenizer, 'bos_token', None) is not None else False\\n\"\n            \"if 'add_special_tokens' not in locals() and has_bos_token_already:\\n\"\n            \"    from functools import partial\\n\"\n            \"    tokenizer = partial(tokenizer, add_special_tokens = False)\\n\"\n            \"    processing_class = tokenizer\\n\"\n            \"else:\\n\"\n            \"    add_special_tokens = False if has_bos_token_already else add_special_tokens\\n\\n\"\n        )\n\n        check_text = check_text.split(\"\\n\")\n        check_text = \"\\n\".join(\" \" * where + x for x in check_text)\n        check_text = check_text.rstrip() + \"\\n\"\n\n        if replacer is None:\n            # .*? matches first match. .+? 
matches final match.\n            replacer = re.findall(\n                f\"def {function_name}\" + r\"\\(.*?\\).*?\\:\\n\",\n                function,\n                flags = re.MULTILINE | re.DOTALL,\n            )\n            if len(replacer) == 0:\n                continue\n            replacer = replacer[0]\n            function = function.replace(replacer, replacer + check_text)\n        else:\n            function = function.replace(replacer, check_text + replacer)\n\n        x = [x for x in all_imports if x in function]\n        try:\n            exec(f\"from trl.trainer.sft_trainer import ({','.join(x)})\", locals())\n        except ImportError:\n            for _item in x:\n                try:\n                    exec(f\"from trl.trainer.sft_trainer import {_item}\", locals())\n                except ImportError:\n                    pass\n        exec(function, locals(), globals())\n        exec(\n            f\"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}\",\n            globals(),\n        )\n\n    # Patch train with fix_untrained_tokens\n    for path_to_trainer in (\n        \"sft_trainer.SFTTrainer\",\n        \"dpo_trainer.DPOTrainer\",\n        \"kto_trainer.KTOTrainer\",\n    ):\n        function_name, replacer = \"train\", \"if resume_from_checkpoint is False:\"\n        try:\n            function = getsource(eval(f\"trl.trainer.{path_to_trainer}.{function_name}\"))\n        except Exception:\n            continue\n        where = function.find(\"def\")\n        function = function.split(\"\\n\")\n        function = \"\\n\".join(x[where:] for x in function)\n\n        check_text = (\n            \"\\n\"\n            \"import subprocess, re, gc, numpy as np\\n\"\n            \"a = np.array([0,])\\n\"\n            \"try:\\n\"\n            \"    a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\\n\"\n            \"    a = re.findall(rb'([\\\\d]{1,})[\\\\s]{1,}M', a)\\n\"\n            \"    a = np.array([int(x.decode('utf-8'))/1024 for x in a])\\n\"\n            \"except:\\n\"\n            \"    if not torch.cuda.is_available():\\n\"\n            \"        raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\\n\"\n            \"if ((a - PRE_CHECK) >= 1).sum() > 1:\\n\"\n            \"    raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\\n\"\n            \"for _ in range(3):\\n\"\n            \"    gc.collect()\\n\"\n            \"    torch.cuda.empty_cache()\\n\"\n            \"pass\\n\"\n            \"\\n\"\n            \"tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\\n\"\n            \"fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\\n\\n\"\n            \"fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\\n\\n\"\n        )\n\n        # Warn on gradient accumulation steps if it's used\n        check_text += (\n            \"\\n\"\n            \"try:\\n\"\n            \"    gradient_accumulation_steps = self.args.gradient_accumulation_steps\\n\"\n            \"    if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\\n\"\n            \"        from transformers import __version__ as transformers_version\\n\"\n            \"        from packaging.version import Version\\n\"\n            \"        if Version(transformers_version) <= Version('4.45.2'):\\n\"\n         
   \"            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\\\n'\\\\\\n\"\n            \"                  '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\\n\"\n            \"except:\\n\"\n            \"    pass\\n\"\n            \"\\n\\n\"\n        )\n\n        # Add NEFTune since it doesn't seem to work?? We need to manually inject it\n        check_text += (\n            \"\\n\"\n            \"if hasattr(self, 'neftune_hook_handle'):\\n\"\n            \"    self.neftune_hook_handle.remove()\\n\"\n            \"    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\\n\"\n            \"\\n\"\n            \"if getattr(self, 'neftune_noise_alpha', None) is not None:\\n\"\n            \"    self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\\n\"\n            \"    self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\\n\"\n            \"pass\\n\"\n            \"\\n\"\n        )\n\n        # Also DPO weirdly tokenizes non numeric columns? Delete them!\n        check_text += (\n            \"\\n\"\n            \"if hasattr(self.train_dataset, 'column_names'):\\n\"\n            \"    column_names = set(self.train_dataset.column_names)\\n\"\n            \"    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\\n\"\n            \"        'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\\n\"\n            \"        'prompt_input_ids', 'prompt_attention_mask']\\n\"\n            \"    if all(x in column_names for x in check):\\n\"\n            \"        self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\\n\"\n            \"    del check, column_names\\n\"\n            \"\\n\"\n        )\n\n        check_text = check_text.split(\"\\n\")\n        check_text = \"\\n\".join(\" \" * where + x for x in check_text)\n\n        function = function.replace(replacer, check_text + replacer)\n        exec(function, globals())\n\n        exec(\n            f\"trl.trainer.{path_to_trainer}.{function_name} = {function_name}\",\n            globals(),\n        )\n\n\n# Finally patch TRL tokenizer things -> moved to RL\n# patch_sft_trainer_tokenizer()\n"
  },
  {
    "path": "unsloth/trainer.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport psutil\nimport warnings\nfrom dataclasses import dataclass, field\nfrom typing import Optional\nfrom functools import wraps\n\nimport trl\nimport inspect\nfrom trl import SFTTrainer\nfrom . import is_bfloat16_supported\nfrom unsloth.utils import (\n    configure_padding_free,\n    configure_sample_packing,\n    enable_padding_free_metadata,\n    enable_sample_packing,\n)\nfrom unsloth_zoo.training_utils import (\n    unsloth_train as _unsloth_train,\n)\nfrom unsloth_zoo.vision_utils import (\n    UnslothVisionDataCollator,\n)\nfrom unsloth_zoo.hf_utils import get_transformers_model_type\nfrom unsloth_zoo.utils import Version\nimport dataclasses\n\n__all__ = [\n    \"UnslothTrainingArguments\",\n    \"UnslothTrainer\",\n    \"unsloth_train\",\n    \"_patch_trl_trainer\",\n    \"UnslothVisionDataCollator\",\n]\n\nlogger = logging.getLogger(__name__)\n\n_AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get(\n    \"UNSLOTH_DISABLE_AUTO_PADDING_FREE\", \"\"\n).strip().lower() in {\"1\", \"true\", \"yes\", \"on\"}\n\nPADDING_FREE_BLOCKLIST = {\n    \"gemma2\",  # - gemma2:  Uses slow_attention_softcapping which has torch.compile issues\n    \"gpt_oss\",  # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly\n}\n\n\ndef _should_pack(config) -> bool:\n    if config is None or not getattr(config, \"packing\", False):\n        return False\n    return not getattr(config, \"_unsloth_disable_auto_packing\", False)\n\n\ndef _should_auto_padding_free(config) -> bool:\n    if (\n        config is None\n        or _AUTO_PADDING_FREE_ENV_DISABLED\n        or getattr(config, \"packing\", False)\n    ):\n        return False\n    return getattr(config, \"padding_free\", None) is None\n\n\ndef _disable_sample_packing(config):\n    if config is None:\n        return\n    for attr, value in ((\"packing\", False), (\"padding_free\", False)):\n        if hasattr(config, attr):\n            setattr(config, attr, value)\n    if hasattr(config, \"remove_unused_columns\"):\n        setattr(config, \"remove_unused_columns\", True)\n    setattr(config, \"_unsloth_disable_auto_packing\", True)\n\n\n_AUTO_PACK_SKIP_MESSAGES = (\n    \"packing is not supported\",\n    \"padding-free training\",\n    \"passing a custom data collator\",\n)\n\n\ndef _should_skip_auto_packing_error(exc: Exception) -> bool:\n    message = str(exc).lower()\n    return any(msg in message for msg in _AUTO_PACK_SKIP_MESSAGES)\n\n\n# Unsloth gradient accumulation fix:\nfrom transformers import __version__ as transformers_version, ProcessorMixin\n\nif Version(transformers_version) > Version(\"4.45.2\"):\n\n    def unsloth_train(trainer, *args, **kwargs):\n        return trainer.train(*args, **kwargs)\n\nelse:\n\n    def unsloth_train(trainer, *args, **kwargs):\n        if len(args) != 0 or len(kwargs) != 0:\n            raise RuntimeError(\n        
        \"Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\\n\"\n                \"If you want to use our fix inside of HF, please update `transformers` to the latest version via:\\n\"\n                \"`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`\"\n            )\n        print(\n            \"Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\\n\"\n            \"If you want to use our fix inside of HF, please update `transformers` to the latest version via:\\n\"\n            \"`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`\"\n        )\n        return _unsloth_train(trainer)\n\n\ntry:\n    from trl import SFTConfig as TrainingArguments\nexcept:\n    from transformers import TrainingArguments\n\n\nclass UnslothTrainingArguments(TrainingArguments):\n    def __init__(self, embedding_learning_rate: float = None, *args, **kwargs):\n        embedding_learning_rate = embedding_learning_rate\n        super().__init__(*args, **kwargs)\n\n\ndef _create_unsloth_optimizer(\n    model,\n    optimizer_cls,\n    optimizer_kwargs,\n    embedding_lr = 5e-5,\n):\n    lr = optimizer_kwargs[\"lr\"]\n    weight_decay = optimizer_kwargs.get(\"weight_decay\", 0.0)\n\n    param_groups = {\n        \"non_embeddings\": {},\n        \"embeddings\": {},\n    }\n\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n        if name.endswith(\"modules_to_save.default.weight\"):\n            partial_name = name[: -len(\".modules_to_save.default.weight\")]\n            partial_name = partial_name[partial_name.rfind(\".\") + 1 :]\n            print(\n                f\"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.\"\n            )\n            param_groups[\"embeddings\"][name] = param\n        else:\n            param_groups[\"non_embeddings\"][name] = param\n\n    optimizer_grouped_parameters = [\n        {\n            \"params\": list(param_groups[\"non_embeddings\"].values()),\n            \"weight_decay\": weight_decay,\n            \"lr\": lr,\n        },\n        {\n            \"params\": list(param_groups[\"embeddings\"].values()),\n            \"weight_decay\": weight_decay,\n            \"lr\": embedding_lr,\n        },\n    ]\n    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n    return optimizer\n\n\nclass UnslothTrainer(SFTTrainer):\n    def create_optimizer(self):\n        embedding_learning_rate = getattr(self.args, \"embedding_learning_rate\", None)\n        if embedding_learning_rate is None:\n            return super().create_optimizer()\n\n        if self.optimizer is None:\n            optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(\n                self.args\n            )\n            self.optimizer = _create_unsloth_optimizer(\n                self.model,\n                optimizer_cls,\n                optimizer_kwargs,\n                embedding_learning_rate,\n            )\n        return self.optimizer\n\n\n# From `trl>=0.13.0`, they changed how to pass several params to the trainer\n# We need to patch to make the transition smooth\ndef _resolve_trainer_params(trainer_class, init_fn):\n    \"\"\"Resolve the real named parameters for a trainer __init__.\n\n    Some TRL trainers (e.g., ORPOTrainer in TRL 0.27.1) are thin wrappers\n    with only ``def __init__(self, *args, **kwargs)``.  
For those, walk the\n    MRO and return the first parent class that has real named parameters.\n    \"\"\"\n    params = inspect.signature(init_fn).parameters\n    named = {\n        k\n        for k, v in params.items()\n        if v.kind\n        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)\n        and k != \"self\"\n    }\n    if named:\n        return set(params.keys())\n\n    # Thin wrapper detected - walk MRO for real signature\n    for cls in trainer_class.__mro__[1:]:\n        if cls is object:\n            continue\n        parent_init = cls.__dict__.get(\"__init__\")\n        if parent_init is None:\n            continue\n        try:\n            parent_params = inspect.signature(parent_init).parameters\n            parent_named = {\n                k\n                for k, v in parent_params.items()\n                if v.kind\n                in (\n                    inspect.Parameter.POSITIONAL_OR_KEYWORD,\n                    inspect.Parameter.KEYWORD_ONLY,\n                )\n                and k != \"self\"\n            }\n            if parent_named:\n                return set(parent_params.keys())\n        except (ValueError, TypeError):\n            continue\n    return set(params.keys())\n\n\ndef _backwards_compatible_trainer(trainer_class, config_class):\n    original_init = trainer_class.__init__\n\n    @wraps(original_init)\n    def new_init(self, *args, **kwargs):\n        # All Trainer tokenizer are now called processing_class\n        trainer_params = _resolve_trainer_params(trainer_class, original_init)\n\n        if \"processing_class\" in trainer_params and \"tokenizer\" in kwargs:\n            kwargs[\"processing_class\"] = kwargs.pop(\"tokenizer\")\n\n        if (\"args\" in kwargs) and (Version(trl) >= Version(\"0.13.0.dev0\")):\n            training_args = kwargs.pop(\"args\", None)\n\n            # Get parameters that Trainer.__init__ actually expects\n            trainer_params.remove(\"self\")\n            trainer_params.remove(\"args\")\n\n            # Get fields that should be passed to Config init\n            config_fields = {\n                field.name: field\n                for field in dataclasses.fields(config_class)\n                if field.init\n            }\n\n            # Create config dict with valid fields from training_args\n            config_dict = {\n                name: getattr(training_args, name)\n                for name in config_fields\n                if hasattr(training_args, name)\n            }\n\n            # Get parameters that exist in Config but not in TrainingArguments\n            from transformers import TrainingArguments\n\n            moved_params = set(inspect.signature(config_class).parameters.keys()) - set(\n                inspect.signature(TrainingArguments).parameters.keys()\n            )\n\n            # Separate kwargs into trainer kwargs and config kwargs\n            trainer_kwargs = {}\n            additional_config_kwargs = {}\n\n            for key, value in kwargs.items():\n                if key in trainer_params:\n                    trainer_kwargs[key] = value\n                elif key in moved_params or key in config_fields:\n                    additional_config_kwargs[key] = value\n                else:\n                    additional_config_kwargs[key] = value\n\n            # Update config_dict with additional kwargs\n            config_dict.update(additional_config_kwargs)\n\n            # Create Config with all the collected parameters\n            
# Reinitialising config class with parameters (that were none initially but populated on first init)\n            # causes the 2nd init to fail as there are mutual exclusive checks on pairs of parameters.\n            # Refer: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_config.py#L499-L502 for example\n            # So we only create config class if the previous init was not TrainingArguments\n            if not isinstance(training_args, TrainingArguments):\n                config = config_class(**config_dict)\n            else:\n                config = training_args\n\n            # Reconstruct kwargs for Trainer\n            kwargs = trainer_kwargs\n            kwargs[\"args\"] = config\n        original_init(self, *args, **kwargs)\n\n    return new_init\n\n\ndef _patch_sft_trainer_auto_packing(trl_module):\n    sft_trainer = getattr(trl_module, \"SFTTrainer\", None)\n    if sft_trainer is None:\n        return\n    if getattr(sft_trainer, \"_unsloth_auto_packing_wrapped\", False):\n        return\n\n    original_init = sft_trainer.__init__\n\n    @wraps(original_init)\n    def new_init(self, *args, **kwargs):\n        config_arg = None\n        if len(args) >= 2:\n            config_arg = args[1]\n        else:\n            config_arg = kwargs.get(\"args\")\n\n        # Check if model type is unsupported for padding_free\n        model = kwargs.get(\"model\")\n        is_unsupported_model = False\n        is_vlm = False\n        if model is not None:\n            model_config = getattr(model, \"config\", None)\n            if model_config is not None:\n                model_types = get_transformers_model_type(model_config)\n                # Blocklist: models that don't work correctly with padding_free\n                is_unsupported_model = any(\n                    x in PADDING_FREE_BLOCKLIST for x in model_types\n                )\n\n                # Check if VLM\n                architectures = getattr(model_config, \"architectures\", None)\n                if architectures is None:\n                    architectures = []\n                is_vlm = any(\n                    x.endswith(\"ForConditionalGeneration\") for x in architectures\n                )\n                is_vlm = is_vlm or hasattr(model_config, \"vision_config\")\n\n        processing_class = kwargs.get(\"processing_class\") or kwargs.get(\"tokenizer\")\n        data_collator = kwargs.get(\"data_collator\")\n\n        # We also disable vision language models for padding free collators\n        blocked = (\n            (data_collator is not None)\n            or isinstance(processing_class, ProcessorMixin)\n            or is_vlm\n            or is_unsupported_model\n            or (\n                os.environ.get(\"UNSLOTH_RETURN_LOGITS\", \"0\") == \"1\"\n            )  # Disable padding free on forced logits\n        )\n        requested_pack = bool(getattr(config_arg, \"packing\", False))\n        if blocked:\n            if hasattr(config_arg, \"packing\"):\n                setattr(config_arg, \"packing\", False)\n            if hasattr(config_arg, \"padding_free\"):\n                setattr(config_arg, \"padding_free\", False)\n\n        if blocked and requested_pack:\n            reason = \"custom data collator\"\n            if data_collator is None and isinstance(processing_class, ProcessorMixin):\n                reason = \"processor-based model\"\n            elif is_vlm:\n                reason = \"vision-language model\"\n            elif is_unsupported_model:\n                reason 
= f\"unsupported model type(s): {', '.join(model_types)}\"\n            message = \"Unsloth: Sample packing skipped \" f\"({reason} detected).\"\n            print(message)\n\n        packing_active = False\n        if _should_pack(config_arg) and not blocked:\n            configure_sample_packing(config_arg)\n            packing_active = True\n            logger.info(\"Unsloth: Sample packing enabled for SFTTrainer instance.\")\n\n        # Resolve padding_free: None (default) = auto-enable unless env-disabled or packing\n        auto_padding_free_active = False\n        padding_free_requested = getattr(config_arg, \"padding_free\", None) is True\n        if not blocked:\n            if padding_free_requested:\n                configure_padding_free(config_arg)\n            elif _should_auto_padding_free(config_arg):\n                configure_padding_free(config_arg)\n                auto_padding_free_active = True\n                logger.info(\n                    \"Unsloth: Padding-free batching auto-enabled for SFTTrainer instance.\"\n                )\n\n        try:\n            original_init(self, *args, **kwargs)\n        except ValueError as exc:\n            if packing_active and _should_skip_auto_packing_error(exc):\n                logger.info(\n                    \"Unsloth: Auto sample packing failed because trainer reported an incompatible setup (%s).\",\n                    exc,\n                )\n                _disable_sample_packing(config_arg)\n                packing_active = False\n                original_init(self, *args, **kwargs)\n            else:\n                raise\n\n        trainer_args = getattr(self, \"args\", None)\n        trainer_packing = bool(trainer_args and getattr(trainer_args, \"packing\", False))\n        trainer_padding_free = bool(\n            trainer_args and getattr(trainer_args, \"padding_free\", False)\n        )\n\n        if blocked and trainer_args is not None:\n            # Mirror the block on the trainer args to avoid re-enabling later\n            setattr(trainer_args, \"packing\", False)\n            setattr(trainer_args, \"padding_free\", False)\n\n        if (\n            not blocked\n            and trainer_packing\n            and (packing_active or _should_pack(trainer_args))\n        ):\n            enable_sample_packing(self.model, self)\n            print(\n                \"🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!\"\n            )\n        elif not blocked and trainer_padding_free:\n            enable_padding_free_metadata(self.model, self)\n            message = (\n                \"🦥 Unsloth: Padding-free auto-enabled, enabling faster training.\"\n                if auto_padding_free_active\n                else \"🦥 Unsloth: Padding-free enabled, enabling faster training.\"\n            )\n            print(message)\n\n    sft_trainer.__init__ = new_init\n    sft_trainer._unsloth_auto_packing_wrapped = True\n\n\ndef _patch_trl_trainer():\n    import trl\n\n    if hasattr(trl, \"__UNSLOTH_BACKWARDS_COMPATIBLE__\"):\n        return\n    if Version(trl) <= Version(\"0.11.0\"):\n        return\n\n    import trl.trainer\n\n    trl_classes = dir(trl.trainer)\n    trl_trainers = set(\n        x[: -len(\"Trainer\")] for x in trl_classes if x.endswith(\"Trainer\")\n    )\n    trl_configs = set(x[: -len(\"Config\")] for x in trl_classes if x.endswith(\"Config\"))\n    trl_classes = list(trl_trainers & trl_configs)\n\n    for x in trl_classes:\n        try:\n            exec(\n                
f\"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)\",\n                globals(),\n            )\n        except:\n            continue\n\n    _patch_sft_trainer_auto_packing(trl)\n\n    trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True\n"
  },
  {
    "path": "unsloth/utils/__init__.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\nfrom .packing import (\n    configure_padding_free,\n    configure_sample_packing,\n    enable_padding_free_metadata,\n    enable_sample_packing,\n    mark_allow_overlength,\n)\nfrom .attention_dispatch import (\n    AttentionConfig,\n    AttentionContext,\n    FLASH_DENSE,\n    FLASH_VARLEN,\n    SDPA,\n    XFORMERS,\n    run_attention,\n    select_attention_backend,\n)\n\n__all__ = [\n    \"configure_sample_packing\",\n    \"configure_padding_free\",\n    \"enable_sample_packing\",\n    \"enable_padding_free_metadata\",\n    \"mark_allow_overlength\",\n    \"AttentionConfig\",\n    \"AttentionContext\",\n    \"FLASH_VARLEN\",\n    \"FLASH_DENSE\",\n    \"XFORMERS\",\n    \"SDPA\",\n    \"run_attention\",\n    \"select_attention_backend\",\n]\n"
  },
  {
    "path": "unsloth/utils/attention_dispatch.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"Shared helpers for attention backend selection and execution.\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.functional import scaled_dot_product_attention\n\nfrom ..models._utils import *\nfrom ..utils.packing import (\n    build_sdpa_packed_attention_mask,\n    build_xformers_block_causal_mask,\n)\n\nif HAS_FLASH_ATTENTION:\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\nHAS_XFORMERS = xformers is not None\n\n# xformers kernels (FA3, FA2, cutlass) only support compute capability <= 9.0.\n# Disable xformers on newer GPUs (e.g. RTX 5070 Ti / sm_120) and fall back to SDPA.\nif HAS_XFORMERS and torch.cuda.is_available():\n    _cc = torch.cuda.get_device_capability()\n    if _cc[0] >= 12:\n        HAS_XFORMERS = False\nSDPA_HAS_GQA = \"enable_gqa\" in (scaled_dot_product_attention.__doc__ or \"\")\n\nFLASH_VARLEN = \"flash_varlen\"\nFLASH_DENSE = \"flash_dense\"\nXFORMERS = \"xformers\"\nSDPA = \"sdpa\"\n\n\nXFORMERS_BLOCK_DIAG_CLS = (\n    xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None\n)\n\n\n@dataclass\nclass AttentionConfig:\n    \"\"\"\n    Per-layer attention metadata.\n\n    NOTE(djsaunde): I had originally intended this to be populated once per layer, but\n        we're currently constructing it on every forward pass since it can possibly be\n        invalid from one forward pass to the next (e.g., switching from training to\n        inference). 
For now, I'm keeping separate from AttentionContext for the sake of\n        better grouping of params.\n    \"\"\"\n\n    backend: str\n    n_kv_heads: int\n    n_groups: int\n    flash_dense_kwargs: Optional[dict[str, Any]] = None\n    flash_varlen_kwargs: Optional[dict[str, Any]] = None\n    sdpa_kwargs: Optional[dict[str, Any]] = None\n    xformers_kwargs: Optional[dict[str, Any]] = None\n\n\n@dataclass\nclass AttentionContext:\n    \"\"\"Per-call info required to run attention.\"\"\"\n\n    bsz: int\n    q_len: int\n    kv_seq_len: int\n    n_heads: int\n    head_dim: int\n    requires_grad: bool\n    seq_info: Optional[Tuple[Tensor, Tensor, int]]\n    attention_mask: Optional[Tensor]\n    causal_mask: Optional[Any]\n    sliding_window: Optional[int] = None\n\n\ndef select_attention_backend(use_varlen: bool = False) -> str:\n    \"\"\"Return attention backend based on availability / priority order.\"\"\"\n\n    if HAS_FLASH_ATTENTION:\n        if use_varlen:\n            return FLASH_VARLEN\n        else:\n            return FLASH_DENSE\n    if HAS_XFORMERS:\n        return XFORMERS\n    return SDPA\n\n\ndef run_attention(\n    *,\n    config: AttentionConfig,\n    context: AttentionContext,\n    Q: Tensor,\n    K: Tensor,\n    V: Tensor,\n) -> Tensor:\n    \"\"\"\n    Run attention using config / context info.\n\n    Backend choice is prioritized for speed: FlashAttention when installed\n    (`flash_varlen` for packed/variable-length inputs with `seq_info`, otherwise dense\n    flash), then xFormers if flash is unavailable, with PyTorch SDPA as the final\n    fallback (e.g., CPU or no fused kernels).\n\n    Varlen flash is preferred when packing metadata is present because it avoids padding\n    and keeps peak memory low. xFormers and SDPA can also handle packed batches (we\n    pass a block-diagonal mask into each).\n    \"\"\"\n\n    backend = config.backend\n    if backend == FLASH_VARLEN and context.seq_info is None:\n        backend = FLASH_DENSE if HAS_FLASH_ATTENTION else SDPA\n\n    # [TODO] Flash attention does not support arbitrary attention masks (only\n    # causal via flag). When a padding mask is present (e.g. 
left-padded\n    # batched generation), fall back to SDPA which consumes attn_mask.\n    # xFormers also does not thread context.attention_mask through, so the\n    # same fallback applies.\n    if context.attention_mask is not None and backend in (\n        FLASH_DENSE,\n        FLASH_VARLEN,\n        XFORMERS,\n    ):\n        backend = SDPA\n\n    flash_dense_kwargs = config.flash_dense_kwargs or {}\n    flash_varlen_kwargs = config.flash_varlen_kwargs or {}\n    sdpa_kwargs = config.sdpa_kwargs or {}\n    xformers_kwargs = config.xformers_kwargs or {}\n\n    bsz = context.bsz\n    n_heads = context.n_heads\n    q_len = context.q_len\n    head_dim = context.head_dim\n    kv_seq_len = context.kv_seq_len\n    requires_grad = context.requires_grad\n    sliding_window = context.sliding_window\n\n    if backend == FLASH_VARLEN:\n        Q_f = Q.transpose(1, 2).reshape(bsz * q_len, n_heads, head_dim)\n        K_f = K.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)\n        V_f = V.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)\n        _, cu_seqlens, max_seqlen = context.seq_info\n        return flash_attn_varlen_func(\n            Q_f,\n            K_f,\n            V_f,\n            cu_seqlens,\n            cu_seqlens,\n            max_seqlen,\n            max_seqlen,\n            **flash_varlen_kwargs,\n        ).view(bsz, q_len, n_heads, head_dim)\n    elif backend == FLASH_DENSE:\n        Q_t = Q.transpose(1, 2)\n        K_t = K.transpose(1, 2)\n        V_t = V.transpose(1, 2)\n        return flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape(\n            bsz, q_len, n_heads, head_dim\n        )\n    elif backend == XFORMERS:\n        attn_bias = build_xformers_block_causal_mask(\n            context.seq_info,\n            sliding_window = sliding_window,\n            base_mask = context.causal_mask,\n        )\n\n        Q_t = Q.transpose(1, 2)\n        K_t = K.transpose(1, 2)\n        V_t = V.transpose(1, 2)\n\n        K_mod = K_t\n        V_mod = V_t\n        Q_mod = Q_t\n\n        if config.n_groups != 1:\n            K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)\n            V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)\n            K_mod = K_mod.expand(\n                bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim\n            )\n            V_mod = V_mod.expand(\n                bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim\n            )\n\n            if requires_grad:\n                K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)\n                V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)\n            else:\n                Q_mod = Q_t.view(\n                    bsz, q_len, config.n_kv_heads, config.n_groups, head_dim\n                )\n\n        has_block = XFORMERS_BLOCK_DIAG_CLS is not None and isinstance(\n            attn_bias, XFORMERS_BLOCK_DIAG_CLS\n        )\n\n        if config.n_groups != 1 and has_block:\n            if not requires_grad:\n                Q_mod = Q_mod.view(\n                    1, bsz * q_len, config.n_kv_heads, config.n_groups, head_dim\n                )\n                K_mod = K_mod.view(\n                    1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim\n                )\n                V_mod = V_mod.view(\n                    1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim\n                )\n            else:\n                Q_mod = Q_mod.view(1, bsz * 
q_len, n_heads, head_dim)\n                K_mod = K_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)\n                V_mod = V_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)\n\n        out = xformers_attention(\n            Q_mod,\n            K_mod,\n            V_mod,\n            attn_bias = attn_bias,\n            **xformers_kwargs,\n        )\n\n        if config.n_groups != 1 and not requires_grad:\n            out = out.view(bsz, q_len, config.n_kv_heads, config.n_groups, head_dim)\n            out = out.reshape(bsz, q_len, n_heads, head_dim)\n        else:\n            out = out.view(bsz, q_len, n_heads, head_dim)\n        return out\n    else:\n        local_mask = context.attention_mask\n        is_causal_local = False\n        if context.seq_info is not None and local_mask is None:\n            local_mask = build_sdpa_packed_attention_mask(\n                context.seq_info,\n                dtype = Q.dtype,\n                device = Q.device,\n                sliding_window = sliding_window,\n            )\n        else:\n            q_len_local = Q.shape[-2]\n            k_len_local = K.shape[-2]\n            # ---- SDPA mask normalization for left padding / 2D masks ----\n            if local_mask is not None and isinstance(local_mask, torch.Tensor):\n                local_mask = local_mask.to(device = Q.device)\n\n                if local_mask.dim() == 2:\n                    # key padding keep mask: (bsz, k_len), 1/True = real token\n                    if local_mask.dtype == torch.bool:\n                        key_keep = local_mask\n                    else:\n                        # tokenizer attention_mask is typically int 0/1\n                        key_keep = local_mask != 0\n\n                    past_len = (\n                        k_len_local - q_len_local\n                    )  # works for prefill (0) and decode\n                    q_pos = torch.arange(\n                        past_len, past_len + q_len_local, device = Q.device\n                    )\n                    k_pos = torch.arange(k_len_local, device = Q.device)\n\n                    causal_keep = (\n                        k_pos[None, :] <= q_pos[:, None]\n                    )  # True = allowed (SDPA)\n                    if sliding_window is not None:\n                        causal_keep &= k_pos[None, :] >= (\n                            q_pos[:, None] - (sliding_window - 1)\n                        )\n\n                    # (bsz, 1, q_len, k_len) boolean keep mask\n                    local_mask = (\n                        causal_keep[None, None, :, :] & key_keep[:, None, None, :]\n                    )\n\n                elif local_mask.dim() == 3:\n                    # (bsz, q_len, k_len) -> (bsz, 1, q_len, k_len)\n                    local_mask = local_mask[:, None, :, :]\n\n                elif local_mask.dim() == 4:\n                    if local_mask.dtype != torch.bool:\n                        # Use boolean keep masks for better SDPA stability.\n                        local_mask = local_mask.eq(0)\n                else:\n                    raise ValueError(\n                        f\"Unsupported SDPA attention_mask rank: {local_mask.dim()}\"\n                    )\n\n                # Avoid NaNs from fully-masked rows (common with left padding).\n                if local_mask.dtype == torch.bool:\n                    no_allowed = ~local_mask.any(\n                        dim = -1, keepdim = True\n                    )  # (bsz,1,q_len,1)\n                    local_mask = 
local_mask | no_allowed\n\n            is_causal_local = local_mask is None and q_len_local == k_len_local\n\n        kwargs = dict(sdpa_kwargs)\n        kwargs.setdefault(\"attn_mask\", local_mask)\n        kwargs.setdefault(\"is_causal\", is_causal_local)\n\n        use_sdpa_gqa = SDPA_HAS_GQA and config.n_groups != 1\n        if (\n            use_sdpa_gqa\n            and (not requires_grad)\n            and isinstance(local_mask, torch.Tensor)\n            and local_mask.dim() >= 3\n            and local_mask.shape[0] > 1\n        ):\n            # Batched masked inference has shown row-coupled drift with SDPA GQA.\n            # Fall back to explicit KV expansion for deterministic row-wise behavior.\n            use_sdpa_gqa = False\n\n        if use_sdpa_gqa:\n            kwargs.setdefault(\"enable_gqa\", True)\n            out = scaled_dot_product_attention(Q, K, V, **kwargs)\n            return out.transpose(1, 2)\n\n        K_mod = K\n        V_mod = V\n        if config.n_groups != 1:\n            K_mod = K[:, :, None, :, :].expand(\n                bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim\n            )\n            V_mod = V[:, :, None, :, :].expand(\n                bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim\n            )\n            K_mod = K_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)\n            V_mod = V_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)\n\n        out = scaled_dot_product_attention(\n            Q.contiguous(),\n            K_mod.contiguous(),\n            V_mod.contiguous(),\n            **kwargs,\n        )\n        return out.transpose(1, 2).contiguous()\n\n\n__all__ = [\n    \"AttentionConfig\",\n    \"AttentionContext\",\n    \"select_attention_backend\",\n    \"run_attention\",\n]\n"
  },
  {
    "path": "unsloth/utils/hf_hub.py",
    "content": "from huggingface_hub import HfApi, ModelInfo\n\n_HFAPI: HfApi = None\n\nPOPULARITY_PROPERTIES = [\n    \"downloads\",\n    \"downloadsAllTime\",\n    \"trendingScore\",\n    \"likes\",\n]\nTHOUSAND = 1000\nMILLION = 1000000\nBILLION = 1000000000\n\n\ndef formatted_int(value: int) -> str:\n    if value < THOUSAND:\n        return str(value)\n    elif value < MILLION:\n        return f\"{float(value) / 1000:,.1f}K\"\n    elif value < BILLION:\n        return f\"{float(value) / 1000000:,.1f}M\"\n    else:\n        return f\"{float(value) / 1000000000:,.1f}B\"\n\n\ndef get_model_info(\n    model_id: str, properties: list[str] = [\"safetensors\", \"lastModified\"]\n) -> ModelInfo:\n    \"\"\"\n    Get the model info for a specific model.\n\n    properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/model_info\n    Default properties: [\"safetensors\", \"lastModified\"], only retrieves minimal information.\n    Set to None to retrieve the full model information.\n    \"\"\"\n    global _HFAPI\n    if _HFAPI is None:\n        _HFAPI = HfApi()\n    try:\n        model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)\n    except Exception as e:\n        print(f\"Error getting model info for {model_id}: {e}\")\n        model_info = None\n    return model_info\n\n\ndef list_models(\n    properties: list[str] = None,\n    full: bool = False,\n    sort: str = \"downloads\",\n    author: str = \"unsloth\",\n    search: str = None,\n    limit: int = 10,\n) -> list[ModelInfo]:\n    \"\"\"\n    Retrieve model information from the Hugging Face Hub.\n\n    properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/list_models\n    full: bool = Whether to retrieve the full model information, if True properties will be ignored.\n    sort: str = The sort order.\n    author: str = The author of the model.\n    search: str = The search query for filtering models.\n\n    \"\"\"\n    global _HFAPI\n    if _HFAPI is None:\n        _HFAPI = HfApi()\n    if full:\n        properties = None\n\n    models: list[ModelInfo] = _HFAPI.list_models(\n        author = author,\n        search = search,\n        sort = sort,\n        limit = limit,\n        expand = properties,\n        full = full,\n    )\n    return models\n"
  },
  {
    "path": "unsloth/utils/packing.py",
    "content": "# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n\"\"\"Utilities for enabling packed (padding-free) batches across Unsloth.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom collections import OrderedDict\nfrom typing import Any, Iterable, Optional, Sequence, Tuple\n\nimport torch\n\ntry:\n    from xformers.ops.fmha.attn_bias import (\n        BlockDiagonalCausalMask as _XFormersBlockMask,\n    )\nexcept Exception:\n    try:\n        from xformers.attn_bias import BlockDiagonalCausalMask as _XFormersBlockMask\n    except Exception:\n        _XFormersBlockMask = None\n\n_XFORMERS_MASK_CACHE_MAXSIZE = 32\n_XFORMERS_MASK_CACHE: OrderedDict[Tuple[Tuple[int, ...], int], Any] = OrderedDict()\n\n# Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers\n_PACKED_INFO_CACHE: dict = {}\n\n# Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers\n_SDPA_MASK_CACHE: dict = {}\n\n# Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers\n_XFORMERS_BLOCK_MASK_CACHE: dict = {}\n\n\ndef _window_cache_key(sliding_window: Optional[int]) -> int:\n    if sliding_window is None or sliding_window <= 0:\n        return 0\n    return int(sliding_window)\n\n\ndef _get_cached_block_mask(\n    lengths: Tuple[int, ...],\n    sliding_window: Optional[int],\n):\n    if _XFormersBlockMask is None:\n        return None\n\n    window_key = _window_cache_key(sliding_window)\n    cache_key = (lengths, window_key)\n    cached = _XFORMERS_MASK_CACHE.get(cache_key)\n    if cached is not None:\n        _XFORMERS_MASK_CACHE.move_to_end(cache_key)\n        return cached\n\n    mask = _XFormersBlockMask.from_seqlens(list(lengths))\n    if window_key and mask is not None and hasattr(mask, \"make_local_attention\"):\n        mask = mask.make_local_attention(window_size = window_key)\n\n    _XFORMERS_MASK_CACHE[cache_key] = mask\n    if len(_XFORMERS_MASK_CACHE) > _XFORMERS_MASK_CACHE_MAXSIZE:\n        _XFORMERS_MASK_CACHE.popitem(last = False)\n    return mask\n\n\nclass _TrlPackingWarningFilter(logging.Filter):\n    to_filter = (\n        \"attention implementation is not\",\n        \"kernels-community\",\n    )\n\n    def filter(self, record: logging.LogRecord) -> bool:\n        message = record.getMessage()\n        return not any(substring in message for substring in self.to_filter)\n\n\n_TRL_FILTER_INSTALLED = False\n\n\ndef _ensure_trl_warning_filter():\n    global _TRL_FILTER_INSTALLED\n    if _TRL_FILTER_INSTALLED:\n        return\n    logging.getLogger(\"trl.trainer.sft_trainer\").addFilter(_TrlPackingWarningFilter())\n    _TRL_FILTER_INSTALLED = True\n\n\ndef mark_allow_overlength(module):\n    \"\"\"Mark a module hierarchy so padding-free batches can exceed 
max_seq_length.\"\"\"\n    if module is None:\n        return\n    if hasattr(module, \"max_seq_length\"):\n        setattr(module, \"_unsloth_allow_packed_overlength\", True)\n    children = getattr(module, \"children\", None)\n    if children is None:\n        return\n    for child in children():\n        mark_allow_overlength(child)\n\n\ndef configure_sample_packing(config):\n    \"\"\"Mutate an ``SFTConfig`` so TRL prepares packed batches.\"\"\"\n    _ensure_trl_warning_filter()\n    setattr(config, \"packing\", True)\n    setattr(config, \"padding_free\", True)\n    setattr(config, \"remove_unused_columns\", False)\n\n\ndef configure_padding_free(config):\n    \"\"\"Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing.\"\"\"\n    _ensure_trl_warning_filter()\n    setattr(config, \"padding_free\", True)\n    setattr(config, \"remove_unused_columns\", False)\n\n\ndef enable_sample_packing(\n    model,\n    trainer,\n    *,\n    sequence_lengths_key: str = \"seq_lengths\",\n) -> None:\n    \"\"\"Enable runtime support for packed batches on an existing trainer.\"\"\"\n    if model is None or trainer is None:\n        raise ValueError(\"model and trainer must not be None\")\n\n    mark_allow_overlength(model)\n\n    if hasattr(trainer, \"args\") and hasattr(trainer.args, \"remove_unused_columns\"):\n        trainer.args.remove_unused_columns = False\n\n    collator = getattr(trainer, \"data_collator\", None)\n    if collator is None or not hasattr(collator, \"torch_call\"):\n        return\n    if getattr(collator, \"_unsloth_packing_wrapped\", False):\n        return\n\n    if hasattr(collator, \"padding_free\"):\n        collator.padding_free = True\n    if hasattr(collator, \"return_position_ids\"):\n        collator.return_position_ids = True\n\n    original_torch_call = collator.torch_call\n\n    def torch_call_with_lengths(examples: Sequence[dict]):\n        batch = original_torch_call(examples)\n        if examples and isinstance(examples[0], dict):\n            seq_lengths: list[int] = []\n            for example in examples:\n                lengths = example.get(sequence_lengths_key)\n                if isinstance(lengths, Iterable):\n                    seq_lengths.extend(int(length) for length in lengths)\n            # Fallback: infer lengths from tokenized inputs when metadata is absent\n            if not seq_lengths:\n                for example in examples:\n                    ids = example.get(\"input_ids\")\n                    if isinstance(ids, Iterable):\n                        seq_lengths.append(len(ids))\n            if seq_lengths:\n                batch[\"packed_seq_lengths\"] = torch.tensor(\n                    seq_lengths, dtype = torch.int32\n                )\n                if \"attention_mask\" in batch:\n                    batch.pop(\"attention_mask\")\n        return batch\n\n    collator.torch_call = torch_call_with_lengths\n    collator._unsloth_packing_wrapped = True\n\n\ndef enable_padding_free_metadata(model, trainer):\n    \"\"\"Inject seq-length metadata when padding-free batching is enabled without packing.\"\"\"\n    collator = getattr(trainer, \"data_collator\", None)\n    if (\n        collator is None\n        or getattr(collator, \"_unsloth_padding_free_lengths_wrapped\", False)\n        or not getattr(collator, \"padding_free\", False)\n    ):\n        return\n\n    mark_allow_overlength(model)\n    if hasattr(collator, \"return_position_ids\"):\n        collator.return_position_ids = True\n    if hasattr(trainer, 
\"args\") and hasattr(trainer.args, \"remove_unused_columns\"):\n        trainer.args.remove_unused_columns = False\n\n    original_torch_call = collator.torch_call\n\n    def torch_call_with_padding_free_metadata(examples: Sequence[dict]):\n        seq_lengths: list[int] = []\n        if examples and isinstance(examples[0], dict):\n            for example in examples:\n                lengths = example.get(\"seq_lengths\")\n                if lengths is None:\n                    ids = example.get(\"input_ids\")\n                    if ids is None:\n                        continue\n                    lengths = [len(ids)]\n                    example[\"seq_lengths\"] = lengths\n                seq_lengths.extend(lengths)\n\n        batch = original_torch_call(examples)\n        if seq_lengths:\n            batch[\"packed_seq_lengths\"] = torch.tensor(\n                seq_lengths,\n                dtype = torch.int32,\n            )\n        return batch\n\n    collator.torch_call = torch_call_with_padding_free_metadata\n    collator._unsloth_padding_free_lengths_wrapped = True\n\n\ndef get_packed_info_from_kwargs(\n    kwargs: dict,\n    device: torch.device,\n) -> Optional[Tuple[torch.Tensor, torch.Tensor, int]]:\n    \"\"\"Return packed sequence metadata expected by the attention kernels.\"\"\"\n\n    seq_lengths = kwargs.get(\"packed_seq_lengths\")\n    if seq_lengths is None:\n        return None\n\n    entry = _PACKED_INFO_CACHE.get(device)\n    if entry is not None and entry[\"seq_lengths\"] is seq_lengths:\n        return entry[\"result\"]\n\n    lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True)\n    cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device)\n    torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:])\n\n    max_seqlen = int(lengths.max().item())\n    result = (lengths, cu_seqlens, max_seqlen)\n    _PACKED_INFO_CACHE[device] = {\"seq_lengths\": seq_lengths, \"result\": result}\n    return result\n\n\ndef build_xformers_block_causal_mask(\n    seq_info: Optional[Tuple[torch.Tensor, torch.Tensor, int]],\n    *,\n    sliding_window: Optional[int] = None,\n    base_mask: Optional[Any] = None,\n):\n    if _XFormersBlockMask is None:\n        return None\n    if seq_info is not None:\n        seq_lengths, _, _ = seq_info\n        # Cache the mask to avoid repeated D2H sync across layers\n        device = seq_lengths.device\n        params = (sliding_window,)\n        entry = _XFORMERS_BLOCK_MASK_CACHE.get(device)\n        if (\n            entry is not None\n            and entry[\"seq_lengths\"] is seq_lengths\n            and entry[\"params\"] == params\n        ):\n            return entry[\"mask\"]\n\n        lengths_tensor = seq_lengths.to(\"cpu\", torch.int32)\n        if lengths_tensor.numel() == 0:\n            return None\n        lengths = tuple(int(x) for x in lengths_tensor.tolist())\n        mask = _get_cached_block_mask(lengths, sliding_window)\n\n        _XFORMERS_BLOCK_MASK_CACHE[device] = {\n            \"seq_lengths\": seq_lengths,\n            \"params\": params,\n            \"mask\": mask,\n        }\n    else:\n        mask = base_mask\n\n        if (\n            sliding_window is not None\n            and sliding_window > 0\n            and mask is not None\n            and hasattr(mask, \"make_local_attention\")\n        ):\n            mask = mask.make_local_attention(window_size = sliding_window)\n    return mask\n\n\ndef build_sdpa_packed_attention_mask(\n    
seq_info: Tuple[torch.Tensor, torch.Tensor, int],\n    *,\n    dtype: torch.dtype,\n    device: torch.device,\n    sliding_window: Optional[int] = None,\n) -> torch.Tensor:\n    seq_lengths, _, _ = seq_info\n\n    params = (dtype, sliding_window)\n    entry = _SDPA_MASK_CACHE.get(device)\n    if (\n        entry is not None\n        and entry[\"seq_lengths\"] is seq_lengths\n        and entry[\"params\"] == params\n    ):\n        return entry[\"mask\"]\n\n    total_tokens = int(seq_lengths.sum().item())\n    mask = torch.full(\n        (total_tokens, total_tokens),\n        float(\"-inf\"),\n        dtype = dtype,\n        device = device,\n    )\n    offset = 0\n    for length in seq_lengths.tolist():\n        length = int(length)\n        if length <= 0:\n            continue\n        block = torch.zeros((length, length), dtype = dtype, device = device)\n        upper = torch.triu(\n            torch.ones((length, length), device = device), diagonal = 1\n        ).bool()\n        block = block.masked_fill(upper, float(\"-inf\"))\n        if (\n            sliding_window is not None\n            and sliding_window > 0\n            and length > sliding_window\n        ):\n            idx = torch.arange(length, device = device)\n            dist = idx.unsqueeze(1) - idx.unsqueeze(0)\n            window_mask = dist >= sliding_window\n            block = block.masked_fill(window_mask, float(\"-inf\"))\n        mask[offset : offset + length, offset : offset + length] = block\n        offset += length\n\n    result = mask.unsqueeze(0).unsqueeze(0)\n    _SDPA_MASK_CACHE[device] = {\n        \"seq_lengths\": seq_lengths,\n        \"params\": params,\n        \"mask\": result,\n    }\n    return result\n\n\ndef _normalize_packed_lengths(\n    seq_lengths: Any,\n    *,\n    device: torch.device,\n) -> Optional[torch.Tensor]:\n    if seq_lengths is None:\n        return None\n    if isinstance(seq_lengths, torch.Tensor):\n        lengths = seq_lengths.to(device = device, dtype = torch.int64)\n    else:\n        lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int64)\n    if lengths.ndim != 1:\n        lengths = lengths.reshape(-1)\n    if lengths.numel() == 0:\n        return None\n    return lengths\n\n\ndef mask_packed_sequence_boundaries(\n    shift_labels: torch.Tensor,\n    seq_lengths: Any,\n    *,\n    ignore_index: int = -100,\n) -> bool:\n    \"\"\"Mark final token of every packed sample so CE ignores boundary predictions.\"\"\"\n    lengths = _normalize_packed_lengths(seq_lengths, device = shift_labels.device)\n    if lengths is None:\n        return False\n\n    flat = shift_labels.reshape(-1)\n    total_tokens = flat.shape[0]\n    boundary_positions = torch.cumsum(lengths, dim = 0) - 1\n    valid = boundary_positions < total_tokens\n    if not torch.all(valid):\n        boundary_positions = boundary_positions[valid]\n    if boundary_positions.numel() == 0:\n        return False\n    flat[boundary_positions] = ignore_index\n    return True\n\n\ndef clear_packed_caches():\n    \"\"\"Release cached masks/metadata to free device memory.\"\"\"\n    _PACKED_INFO_CACHE.clear()\n    _SDPA_MASK_CACHE.clear()\n    _XFORMERS_BLOCK_MASK_CACHE.clear()\n\n\n__all__ = [\n    \"configure_sample_packing\",\n    \"configure_padding_free\",\n    \"enable_sample_packing\",\n    \"enable_padding_free_metadata\",\n    \"mark_allow_overlength\",\n    \"get_packed_info_from_kwargs\",\n    \"build_xformers_block_causal_mask\",\n    \"build_sdpa_packed_attention_mask\",\n    
\"mask_packed_sequence_boundaries\",\n    \"clear_packed_caches\",\n]\n"
  },
  {
    "path": "unsloth-cli.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\n🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth\n\nThis script is designed as a starting point for fine-tuning your models using unsloth.\nIt includes configurable options for model loading, PEFT parameters, training arguments, \nand model saving/pushing functionalities.\n\nYou will likely want to customize this script to suit your specific use case \nand requirements.\n\nHere are a few suggestions for customization:\n    - Modify the dataset loading and preprocessing steps to match your data.\n    - Customize the model saving and pushing configurations.\n\nUsage: (most of the options have valid default values this is an extended example for demonstration purposes)\n    python unsloth-cli.py --model_name \"unsloth/llama-3-8b\" --max_seq_length 8192 --dtype None --load_in_4bit \\\n    --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias \"none\" --use_gradient_checkpointing \"unsloth\" \\\n    --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \\\n    --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim \"adamw_8bit\" \\\n    --weight_decay 0.005 --lr_scheduler_type \"linear\" --seed 3407 --output_dir \"outputs\" \\\n    --report_to \"tensorboard\" --save_model --save_path \"model\" --quantization_method \"f16\" \\\n    --push_model --hub_path \"hf/model\" --hub_token \"your_hf_token\"\n\nTo see a full list of configurable options, use:\n    python unsloth-cli.py --help\n\nHappy fine-tuning!\n\"\"\"\n\nimport argparse\nimport os\n\n\ndef run(args):\n    from unsloth import FastLanguageModel\n    from datasets import load_dataset\n    from transformers.utils import strtobool\n    from trl import SFTTrainer, SFTConfig\n    from unsloth import is_bfloat16_supported\n    from unsloth.models.loader_utils import prepare_device_map\n    import logging\n    from unsloth import RawTextDataLoader\n\n    logging.getLogger(\"hf-to-gguf\").setLevel(logging.WARNING)\n\n    # Load model and tokenizer\n    device_map, distributed = prepare_device_map()\n    model, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = args.model_name,\n        max_seq_length = args.max_seq_length,\n        dtype = args.dtype,\n        load_in_4bit = args.load_in_4bit,\n        device_map = device_map,\n    )\n\n    # Configure PEFT model\n    model = FastLanguageModel.get_peft_model(\n        model,\n        r = args.r,\n        target_modules = [\n            \"q_proj\",\n            \"k_proj\",\n            \"v_proj\",\n            \"o_proj\",\n            \"gate_proj\",\n            \"up_proj\",\n            \"down_proj\",\n        ],\n        lora_alpha = args.lora_alpha,\n        lora_dropout = args.lora_dropout,\n        bias = args.bias,\n        use_gradient_checkpointing = args.use_gradient_checkpointing,\n        random_state = args.random_state,\n        use_rslora = args.use_rslora,\n        loftq_config = args.loftq_config,\n    )\n\n    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n\n    ### Instruction:\n    {}\n\n    ### Input:\n    {}\n\n    ### Response:\n    {}\"\"\"\n\n    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN\n\n    def formatting_prompts_func(examples):\n        instructions = examples[\"instruction\"]\n        inputs = examples[\"input\"]\n        outputs = examples[\"output\"]\n        texts = []\n        for instruction, input, output in zip(instructions, inputs, outputs):\n            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN\n            texts.append(text)\n        return {\"text\": texts}\n\n    def load_dataset_smart(args):\n        from transformers.utils import strtobool\n\n        if args.raw_text_file:\n            # Use raw text loader\n            loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)\n            dataset = loader.load_from_file(args.raw_text_file)\n        elif args.dataset.endswith((\".txt\", \".md\", \".json\", \".jsonl\")):\n            # Auto-detect local raw text files\n            loader = RawTextDataLoader(tokenizer)\n            dataset = loader.load_from_file(args.dataset)\n        else:\n            # Check for modelscope usage\n            use_modelscope = strtobool(\n                os.environ.get(\"UNSLOTH_USE_MODELSCOPE\", \"False\")\n            )\n            if use_modelscope:\n                from modelscope import MsDataset\n\n                dataset = MsDataset.load(args.dataset, split = \"train\")\n            else:\n                # Existing HuggingFace dataset logic\n                dataset = load_dataset(args.dataset, split = \"train\")\n\n            # Apply formatting for structured datasets\n            dataset = dataset.map(formatting_prompts_func, batched = True)\n        return dataset\n\n    # Load dataset using smart loader\n    dataset = load_dataset_smart(args)\n    print(\"Data is formatted and ready!\")\n\n    # Configure training arguments\n    training_args = SFTConfig(\n        per_device_train_batch_size = args.per_device_train_batch_size,\n        per_device_eval_batch_size = args.per_device_eval_batch_size,\n        gradient_accumulation_steps = args.gradient_accumulation_steps,\n        warmup_steps = args.warmup_steps,\n        max_steps = args.max_steps,\n        learning_rate = args.learning_rate,\n        fp16 = not is_bfloat16_supported(),\n        bf16 = is_bfloat16_supported(),\n        logging_steps = args.logging_steps,\n        optim = args.optim,\n        weight_decay = args.weight_decay,\n        lr_scheduler_type = args.lr_scheduler_type,\n        seed = args.seed,\n        output_dir = args.output_dir,\n        report_to = args.report_to,\n        max_length = args.max_seq_length,\n        dataset_num_proc = 2,\n        ddp_find_unused_parameters = False if distributed else None,\n        packing = args.packing,\n    )\n\n    # Initialize trainer\n    trainer = SFTTrainer(\n        model = model,\n        processing_class = tokenizer,\n        train_dataset = dataset,\n        args = training_args,\n    )\n\n    trainer.train()\n\n    # Save model\n    if args.save_model:\n        # if args.quantization_method is a list, we will save the model for each quantization method\n        if args.save_gguf:\n            if isinstance(args.quantization, list):\n                for quantization_method in args.quantization:\n                    print(\n                        f\"Saving model with quantization method: {quantization_method}\"\n                    )\n      
              model.save_pretrained_gguf(\n                        args.save_path,\n                        tokenizer,\n                        quantization_method = quantization_method,\n                    )\n                    if args.push_model:\n                        model.push_to_hub_gguf(\n                            hub_path = args.hub_path,\n                            hub_token = args.hub_token,\n                            quantization_method = quantization_method,\n                        )\n            else:\n                print(f\"Saving model with quantization method: {args.quantization}\")\n                model.save_pretrained_gguf(\n                    args.save_path,\n                    tokenizer,\n                    quantization_method = args.quantization,\n                )\n                if args.push_model:\n                    model.push_to_hub_gguf(\n                        hub_path = args.hub_path,\n                        hub_token = args.hub_token,\n                        quantization_method = args.quantization,\n                    )\n        else:\n            model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)\n            if args.push_model:\n                model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)\n    else:\n        print(\"Warning: The model is not saved!\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description = \"🦥 Fine-tune your llm faster using unsloth!\"\n    )\n\n    model_group = parser.add_argument_group(\"🤖 Model Options\")\n    model_group.add_argument(\n        \"--model_name\",\n        type = str,\n        default = \"unsloth/llama-3-8b\",\n        help = \"Model name to load\",\n    )\n    model_group.add_argument(\n        \"--max_seq_length\",\n        type = int,\n        default = 2048,\n        help = \"Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!\",\n    )\n    model_group.add_argument(\n        \"--dtype\",\n        type = str,\n        default = None,\n        help = \"Data type for model (None for auto detection)\",\n    )\n    model_group.add_argument(\n        \"--load_in_4bit\",\n        action = \"store_true\",\n        help = \"Use 4bit quantization to reduce memory usage\",\n    )\n    model_group.add_argument(\n        \"--dataset\",\n        type = str,\n        default = \"yahma/alpaca-cleaned\",\n        help = \"Huggingface dataset to use for training\",\n    )\n\n    lora_group = parser.add_argument_group(\n        \"🧠 LoRA Options\",\n        \"These options are used to configure the LoRA model.\",\n    )\n    lora_group.add_argument(\n        \"--r\",\n        type = int,\n        default = 16,\n        help = \"Rank for Lora model, default is 16.  (common values: 8, 16, 32, 64, 128)\",\n    )\n    lora_group.add_argument(\n        \"--lora_alpha\",\n        type = int,\n        default = 16,\n        help = \"LoRA alpha parameter, default is 16. 
(common values: 8, 16, 32, 64, 128)\",\n    )\n    lora_group.add_argument(\n        \"--lora_dropout\",\n        type = float,\n        default = 0.0,\n        help = \"LoRA dropout rate, default is 0.0 which is optimized.\",\n    )\n    lora_group.add_argument(\n        \"--bias\",\n        type = str,\n        default = \"none\",\n        help = \"Bias setting for LoRA\",\n    )\n    lora_group.add_argument(\n        \"--use_gradient_checkpointing\",\n        type = str,\n        default = \"unsloth\",\n        help = \"Use gradient checkpointing\",\n    )\n    lora_group.add_argument(\n        \"--random_state\",\n        type = int,\n        default = 3407,\n        help = \"Random state for reproducibility, default is 3407.\",\n    )\n    lora_group.add_argument(\n        \"--use_rslora\",\n        action = \"store_true\",\n        help = \"Use rank stabilized LoRA\",\n    )\n    lora_group.add_argument(\n        \"--loftq_config\",\n        type = str,\n        default = None,\n        help = \"Configuration for LoftQ\",\n    )\n\n    training_group = parser.add_argument_group(\"🎓 Training Options\")\n    training_group.add_argument(\n        \"--per_device_train_batch_size\",\n        type = int,\n        default = 2,\n        help = \"Batch size per device during training, default is 2.\",\n    )\n    training_group.add_argument(\n        \"--per_device_eval_batch_size\",\n        type = int,\n        default = 4,\n        help = \"Batch size per device during evaluation, default is 4.\",\n    )\n    training_group.add_argument(\n        \"--gradient_accumulation_steps\",\n        type = int,\n        default = 4,\n        help = \"Number of gradient accumulation steps, default is 4.\",\n    )\n    training_group.add_argument(\n        \"--warmup_steps\",\n        type = int,\n        default = 5,\n        help = \"Number of warmup steps, default is 5.\",\n    )\n    training_group.add_argument(\n        \"--max_steps\",\n        type = int,\n        default = 400,\n        help = \"Maximum number of training steps.\",\n    )\n    training_group.add_argument(\n        \"--learning_rate\",\n        type = float,\n        default = 2e-4,\n        help = \"Learning rate, default is 2e-4.\",\n    )\n    training_group.add_argument(\n        \"--optim\",\n        type = str,\n        default = \"adamw_8bit\",\n        help = \"Optimizer type.\",\n    )\n    training_group.add_argument(\n        \"--weight_decay\",\n        type = float,\n        default = 0.01,\n        help = \"Weight decay, default is 0.01.\",\n    )\n    training_group.add_argument(\n        \"--lr_scheduler_type\",\n        type = str,\n        default = \"linear\",\n        help = \"Learning rate scheduler type, default is 'linear'.\",\n    )\n    training_group.add_argument(\n        \"--seed\",\n        type = int,\n        default = 3407,\n        help = \"Seed for reproducibility, default is 3407.\",\n    )\n    training_group.add_argument(\n        \"--packing\",\n        action = \"store_true\",\n        help = \"Enable padding-free sample packing via TRL's bin packer.\",\n    )\n\n    report_group = parser.add_argument_group(\"📊 Report Options\")\n    report_group.add_argument(\n        \"--report_to\",\n        type = str,\n        default = \"tensorboard\",\n        choices = [\n            \"azure_ml\",\n            \"clearml\",\n            \"codecarbon\",\n            \"comet_ml\",\n            \"dagshub\",\n            \"dvclive\",\n            \"flyte\",\n            \"mlflow\",\n            
\"neptune\",\n            \"tensorboard\",\n            \"wandb\",\n            \"all\",\n            \"none\",\n        ],\n        help = (\n            \"The list of integrations to report the results and logs to. Supported platforms are:\\n\\t\\t \"\n            \"'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', \"\n            \"'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations \"\n            \"installed, 'none' for no integrations.\"\n        ),\n    )\n    report_group.add_argument(\n        \"--logging_steps\",\n        type = int,\n        default = 1,\n        help = \"Logging steps, default is 1\",\n    )\n\n    save_group = parser.add_argument_group(\"💾 Save Model Options\")\n    save_group.add_argument(\n        \"--output_dir\",\n        type = str,\n        default = \"outputs\",\n        help = \"Output directory\",\n    )\n    save_group.add_argument(\n        \"--save_model\",\n        action = \"store_true\",\n        help = \"Save the model after training\",\n    )\n    save_group.add_argument(\n        \"--save_method\",\n        type = str,\n        default = \"merged_16bit\",\n        choices = [\"merged_16bit\", \"merged_4bit\", \"lora\"],\n        help = \"Save method for the model, default is 'merged_16bit'\",\n    )\n    save_group.add_argument(\n        \"--save_gguf\",\n        action = \"store_true\",\n        help = \"Convert the model to GGUF after training\",\n    )\n    save_group.add_argument(\n        \"--save_path\",\n        type = str,\n        default = \"model\",\n        help = \"Path to save the model\",\n    )\n    save_group.add_argument(\n        \"--quantization\",\n        type = str,\n        default = \"q8_0\",\n        nargs = \"+\",\n        help = (\n            \"Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), \"\n            \"Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf\"\n        ),\n    )\n\n    push_group = parser.add_argument_group(\"🚀 Push Model Options\")\n    push_group.add_argument(\n        \"--push_model\",\n        action = \"store_true\",\n        help = \"Push the model to Hugging Face hub after training\",\n    )\n    push_group.add_argument(\n        \"--push_gguf\",\n        action = \"store_true\",\n        help = \"Push the model as GGUF to Hugging Face hub after training\",\n    )\n    push_group.add_argument(\n        \"--hub_path\",\n        type = str,\n        default = \"hf/model\",\n        help = \"Path on Hugging Face hub to push the model\",\n    )\n    push_group.add_argument(\n        \"--hub_token\",\n        type = str,\n        help = \"Token for pushing the model to Hugging Face hub\",\n    )\n\n    parser.add_argument(\n        \"--raw_text_file\", type = str, help = \"Path to raw text file for training\"\n    )\n    parser.add_argument(\n        \"--chunk_size\", type = int, default = 2048, help = \"Size of text chunks for training\"\n    )\n    parser.add_argument(\n        \"--stride\", type = int, default = 512, help = \"Overlap between chunks\"\n    )\n\n    args = parser.parse_args()\n    run(args)\n"
  },
  {
    "path": "unsloth_cli/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport typer\n\nfrom unsloth_cli.commands.train import train\nfrom unsloth_cli.commands.inference import inference\nfrom unsloth_cli.commands.export import export, list_checkpoints\nfrom unsloth_cli.commands.ui import ui\nfrom unsloth_cli.commands.studio import studio_app\n\napp = typer.Typer(\n    help = \"Command-line interface for Unsloth training, inference, and export.\",\n    context_settings = {\"help_option_names\": [\"-h\", \"--help\"]},\n)\n\napp.command()(train)\napp.command()(inference)\napp.command()(export)\napp.command(\"list-checkpoints\")(list_checkpoints)\napp.command()(ui)\napp.add_typer(studio_app, name = \"studio\", help = \"Unsloth Studio commands.\")\n"
  },
  {
    "path": "unsloth_cli/commands/__init__.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n"
  },
  {
    "path": "unsloth_cli/commands/export.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\n\nEXPORT_FORMATS = [\"merged-16bit\", \"merged-4bit\", \"gguf\", \"lora\"]\nGGUF_QUANTS = [\"q4_k_m\", \"q5_k_m\", \"q8_0\", \"f16\"]\n\n\ndef list_checkpoints(\n    outputs_dir: Path = typer.Option(\n        Path(\"./outputs\"), \"--outputs-dir\", help = \"Directory that holds training runs.\"\n    ),\n):\n    \"\"\"List checkpoints detected in the outputs directory.\"\"\"\n    from studio.backend.core.export import ExportBackend\n\n    backend = ExportBackend()\n    checkpoints = backend.scan_checkpoints(outputs_dir = str(outputs_dir))\n    if not checkpoints:\n        typer.echo(\"No checkpoints found.\")\n        raise typer.Exit()\n\n    for model_name, ckpt_list, metadata in checkpoints:\n        typer.echo(f\"\\n{model_name}:\")\n        for display, path, loss in ckpt_list:\n            loss_str = f\" (loss: {loss:.4f})\" if loss is not None else \"\"\n            typer.echo(f\"  {display}{loss_str}: {path}\")\n\n\ndef export(\n    checkpoint: Path = typer.Argument(..., help = \"Path to checkpoint directory.\"),\n    output_dir: Path = typer.Argument(..., help = \"Directory to save exported model.\"),\n    format: str = typer.Option(\n        \"merged-16bit\",\n        \"--format\",\n        \"-f\",\n        help = f\"Export format: {', '.join(EXPORT_FORMATS)}\",\n    ),\n    quantization: str = typer.Option(\n        \"q4_k_m\",\n        \"--quantization\",\n        \"-q\",\n        help = f\"GGUF quantization method: {', '.join(GGUF_QUANTS)}\",\n    ),\n    push_to_hub: bool = typer.Option(\n        False, \"--push-to-hub\", help = \"Push exported model to HuggingFace Hub.\"\n    ),\n    repo_id: Optional[str] = typer.Option(\n        None, \"--repo-id\", help = \"HuggingFace repo ID (username/model-name).\"\n    ),\n    hf_token: Optional[str] = typer.Option(\n        None, \"--hf-token\", envvar = \"HF_TOKEN\", help = \"HuggingFace token.\"\n    ),\n    private: bool = typer.Option(\n        False, \"--private\", help = \"Make the HuggingFace repo private.\"\n    ),\n    max_seq_length: int = typer.Option(2048, \"--max-seq-length\"),\n    load_in_4bit: bool = typer.Option(True, \"--load-in-4bit/--no-load-in-4bit\"),\n):\n    \"\"\"Export a checkpoint to various formats (merged, GGUF, LoRA adapter).\"\"\"\n    if format not in EXPORT_FORMATS:\n        typer.echo(\n            f\"Error: Invalid format '{format}'. 
Choose from: {', '.join(EXPORT_FORMATS)}\",\n            err = True,\n        )\n        raise typer.Exit(code = 2)\n\n    if push_to_hub and not repo_id:\n        typer.echo(\"Error: --repo-id required when using --push-to-hub\", err = True)\n        raise typer.Exit(code = 2)\n\n    from studio.backend.core.export import ExportBackend\n\n    backend = ExportBackend()\n\n    typer.echo(f\"Loading checkpoint: {checkpoint}\")\n    success, message = backend.load_checkpoint(\n        checkpoint_path = str(checkpoint),\n        max_seq_length = max_seq_length,\n        load_in_4bit = load_in_4bit,\n    )\n    if not success:\n        typer.echo(f\"Error: {message}\", err = True)\n        raise typer.Exit(code = 1)\n    typer.echo(message)\n\n    typer.echo(f\"Exporting as {format}...\")\n    if format == \"merged-16bit\":\n        success, message = backend.export_merged_model(\n            save_directory = str(output_dir),\n            format_type = \"16-bit (FP16)\",\n            push_to_hub = push_to_hub,\n            repo_id = repo_id,\n            hf_token = hf_token,\n            private = private,\n        )\n    elif format == \"merged-4bit\":\n        success, message = backend.export_merged_model(\n            save_directory = str(output_dir),\n            format_type = \"4-bit (FP4)\",\n            push_to_hub = push_to_hub,\n            repo_id = repo_id,\n            hf_token = hf_token,\n            private = private,\n        )\n    elif format == \"gguf\":\n        success, message = backend.export_gguf(\n            save_directory = str(output_dir),\n            quantization_method = quantization.upper(),\n            push_to_hub = push_to_hub,\n            repo_id = repo_id,\n            hf_token = hf_token,\n        )\n    elif format == \"lora\":\n        success, message = backend.export_lora_adapter(\n            save_directory = str(output_dir),\n            push_to_hub = push_to_hub,\n            repo_id = repo_id,\n            hf_token = hf_token,\n            private = private,\n        )\n\n    if not success:\n        typer.echo(f\"Error: {message}\", err = True)\n        raise typer.Exit(code = 1)\n\n    typer.echo(message)\n"
  },
  {
    "path": "unsloth_cli/commands/inference.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport sys\nfrom typing import Optional\n\nimport typer\n\n\ndef inference(\n    model: str = typer.Argument(..., help = \"HF model id or local path.\"),\n    prompt: str = typer.Argument(..., help = \"Prompt to send to the model.\"),\n    hf_token: Optional[str] = typer.Option(\n        None, \"--hf-token\", envvar = \"HF_TOKEN\", help = \"Hugging Face token if needed.\"\n    ),\n    temperature: float = typer.Option(0.7, \"--temperature\"),\n    top_p: float = typer.Option(0.9, \"--top-p\"),\n    top_k: int = typer.Option(40, \"--top-k\"),\n    max_new_tokens: int = typer.Option(256, \"--max-new-tokens\"),\n    repetition_penalty: float = typer.Option(1.1, \"--repetition-penalty\"),\n    system_prompt: str = typer.Option(\n        \"\",\n        \"--system-prompt\",\n        help = \"Optional system prompt to prepend.\",\n    ),\n    max_seq_length: int = typer.Option(2048, \"--max-seq-length\"),\n    load_in_4bit: bool = typer.Option(True, \"--load-in-4bit/--no-load-in-4bit\"),\n):\n    \"\"\"Run a single inference using the specified model.\"\"\"\n    from studio.backend.core import ModelConfig, get_inference_backend\n\n    inference_backend = get_inference_backend()\n    model_config = ModelConfig.from_ui_selection(\n        dropdown_value = model, search_value = None, hf_token = hf_token, is_lora = False\n    )\n    if not model_config:\n        typer.echo(\"Could not resolve model config\", err = True)\n        raise typer.Exit(code = 1)\n\n    if not inference_backend.load_model(\n        config = model_config,\n        max_seq_length = max_seq_length,\n        load_in_4bit = load_in_4bit,\n        hf_token = hf_token,\n    ):\n        typer.echo(\"Model load failed\", err = True)\n        raise typer.Exit(code = 1)\n\n    messages = [{\"role\": \"user\", \"content\": prompt}]\n    stream = inference_backend.generate_chat_response(\n        messages = messages,\n        system_prompt = system_prompt,\n        temperature = temperature,\n        top_p = top_p,\n        top_k = top_k,\n        max_new_tokens = max_new_tokens,\n        repetition_penalty = repetition_penalty,\n    )\n\n    typer.echo(\"Assistant:\", nl = True)\n    previous = \"\"\n    for chunk in stream:\n        delta = chunk[len(previous) :]\n        if delta:\n            sys.stdout.write(delta)\n            sys.stdout.flush()\n        previous = chunk\n    sys.stdout.write(\"\\n\")\n    sys.stdout.flush()\n"
  },
  {
    "path": "unsloth_cli/commands/studio.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport os\nimport platform\nimport subprocess\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\nimport typer\n\nstudio_app = typer.Typer(help = \"Unsloth Studio commands.\")\n\nSTUDIO_HOME = Path.home() / \".unsloth\" / \"studio\"\n\n# __file__ is unsloth_cli/commands/studio.py -- two parents up is the package root\n# (either site-packages or the repo root for editable installs).\n_PACKAGE_ROOT = Path(__file__).resolve().parent.parent.parent\n\n\ndef _studio_venv_python() -> Optional[Path]:\n    \"\"\"Return the studio venv Python binary, or None if not set up.\"\"\"\n    if platform.system() == \"Windows\":\n        p = STUDIO_HOME / \".venv\" / \"Scripts\" / \"python.exe\"\n    else:\n        p = STUDIO_HOME / \".venv\" / \"bin\" / \"python\"\n    return p if p.is_file() else None\n\n\ndef _find_run_py() -> Optional[Path]:\n    \"\"\"Find studio/backend/run.py.\n\n    No CWD dependency — works from any directory.\n    Since studio/ is now a proper package (has __init__.py), it lives in\n    site-packages after pip install, right next to unsloth_cli/.\n    \"\"\"\n    # 1. Relative to __file__ (site-packages or editable repo root)\n    run_py = _PACKAGE_ROOT / \"studio\" / \"backend\" / \"run.py\"\n    if run_py.is_file():\n        return run_py\n    # 2. Studio venv's site-packages (Linux + Windows layouts)\n    for pattern in (\n        \"lib/python*/site-packages/studio/backend/run.py\",\n        \"Lib/site-packages/studio/backend/run.py\",\n    ):\n        for match in (STUDIO_HOME / \".venv\").glob(pattern):\n            return match\n    return None\n\n\ndef _find_setup_script() -> Optional[Path]:\n    \"\"\"Find studio/setup.sh or studio/setup.ps1.\n\n    No CWD dependency — works from any directory.\n    \"\"\"\n    name = \"setup.ps1\" if platform.system() == \"Windows\" else \"setup.sh\"\n    # 1. Relative to __file__ (site-packages or editable repo root)\n    s = _PACKAGE_ROOT / \"studio\" / name\n    if s.is_file():\n        return s\n    # 2. Studio venv's site-packages\n    for pattern in (\n        f\"lib/python*/site-packages/studio/{name}\",\n        f\"Lib/site-packages/studio/{name}\",\n    ):\n        for match in (STUDIO_HOME / \".venv\").glob(pattern):\n            return match\n    return None\n\n\n# ── unsloth studio (server) ──────────────────────────────────────────\n\n\n@studio_app.callback(invoke_without_command = True)\ndef studio_default(\n    ctx: typer.Context,\n    port: int = typer.Option(8888, \"--port\", \"-p\"),\n    host: str = typer.Option(\"0.0.0.0\", \"--host\", \"-H\"),\n    frontend: Optional[Path] = typer.Option(None, \"--frontend\", \"-f\"),\n    silent: bool = typer.Option(False, \"--silent\", \"-q\"),\n):\n    \"\"\"Launch the Unsloth Studio server.\"\"\"\n    if ctx.invoked_subcommand is not None:\n        return\n\n    # Always use the studio venv if it exists and we're not already in it\n    studio_venv_dir = STUDIO_HOME / \".venv\"\n    in_studio_venv = sys.prefix.startswith(str(studio_venv_dir))\n\n    if not in_studio_venv:\n        studio_python = _studio_venv_python()\n        run_py = _find_run_py()\n        if studio_python and run_py:\n            if not silent:\n                typer.echo(\"Launching Unsloth Studio... 
Please wait...\")\n            args = [\n                str(studio_python),\n                str(run_py),\n                \"--host\",\n                host,\n                \"--port\",\n                str(port),\n            ]\n            if frontend:\n                args.extend([\"--frontend\", str(frontend)])\n            if silent:\n                args.append(\"--silent\")\n            # On Windows, os.execvp() spawns a child but the parent lingers,\n            # so Ctrl+C only kills the parent leaving the child orphaned.\n            # Use subprocess.run() on Windows so the parent waits for the child.\n            if sys.platform == \"win32\":\n                import subprocess as _sp\n\n                proc = _sp.Popen(args)\n                try:\n                    rc = proc.wait()\n                except KeyboardInterrupt:\n                    # Child has its own signal handler — let it finish\n                    rc = proc.wait()\n                raise typer.Exit(rc)\n            else:\n                os.execvp(str(studio_python), args)\n        else:\n            typer.echo(\"Studio not set up. Run 'unsloth studio setup' first.\")\n            raise typer.Exit(1)\n\n    from studio.backend.run import run_server\n\n    if not silent:\n        from studio.backend.run import _resolve_external_ip\n\n        display_host = _resolve_external_ip() if host == \"0.0.0.0\" else host\n        typer.echo(f\"Starting Unsloth Studio on http://{display_host}:{port}\")\n\n    run_server(\n        host = host,\n        port = port,\n        frontend_path = frontend,\n        silent = silent,\n    )\n\n    from studio.backend.run import _shutdown_event\n\n    try:\n        if _shutdown_event is not None:\n            # NOTE: Event.wait() without a timeout blocks at the C level\n            # on Linux, preventing Python from delivering SIGINT (Ctrl+C).\n            while not _shutdown_event.is_set():\n                _shutdown_event.wait(timeout = 1)\n        else:\n            while True:\n                time.sleep(1)\n    except KeyboardInterrupt:\n        from studio.backend.run import _graceful_shutdown, _server\n\n        _graceful_shutdown(_server)\n        typer.echo(\"\\nShutting down...\")\n\n\n# ── unsloth studio setup ─────────────────────────────────────────────\n\n\n@studio_app.command()\ndef setup():\n    \"\"\"Run one-time Studio environment setup.\"\"\"\n    script = _find_setup_script()\n    if not script:\n        typer.echo(\"Error: Could not find setup script (setup.sh / setup.ps1).\")\n        raise typer.Exit(1)\n\n    if platform.system() == \"Windows\":\n        result = subprocess.run(\n            [\"powershell\", \"-ExecutionPolicy\", \"Bypass\", \"-File\", str(script)],\n        )\n    else:\n        result = subprocess.run([\"bash\", str(script)])\n\n    if result.returncode != 0:\n        raise typer.Exit(result.returncode)\n\n\n# ── unsloth studio reset-password ────────────────────────────────────\n\n\n@studio_app.command(\"reset-password\")\ndef reset_password():\n    \"\"\"Reset the Studio admin password.\n\n    Deletes the auth database so that a fresh admin account with a new\n    random password is created on the next server start.  
The Studio\n    server must be restarted after running this command.\n    \"\"\"\n    auth_dir = STUDIO_HOME / \"auth\"\n    db_file = auth_dir / \"auth.db\"\n    pw_file = auth_dir / \".bootstrap_password\"\n\n    if not db_file.exists():\n        typer.echo(\"No auth database found -- nothing to reset.\")\n        raise typer.Exit(0)\n\n    db_file.unlink(missing_ok = True)\n    pw_file.unlink(missing_ok = True)\n\n    typer.echo(\"Auth database deleted. Restart Unsloth Studio to get a new password.\")\n"
  },
  {
    "path": "unsloth_cli/commands/train.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\nfrom unsloth_cli.config import Config, load_config\nfrom unsloth_cli.options import add_options_from_config\n\n\n@add_options_from_config(Config)\ndef train(\n    config: Optional[Path] = typer.Option(\n        None,\n        \"--config\",\n        \"-c\",\n        help = \"Path to YAML/JSON config file. CLI flags override config values.\",\n    ),\n    hf_token: Optional[str] = typer.Option(\n        None, \"--hf-token\", envvar = \"HF_TOKEN\", help = \"Hugging Face token if needed.\"\n    ),\n    wandb_token: Optional[str] = typer.Option(\n        None, \"--wandb-token\", envvar = \"WANDB_API_KEY\", help = \"Weights & Biases API key.\"\n    ),\n    dry_run: bool = typer.Option(\n        False,\n        \"--dry-run\",\n        help = \"Show resolved config and exit without training.\",\n    ),\n    config_overrides: dict = None,\n):\n    \"\"\"Launch training using the existing Unsloth training backend.\"\"\"\n    try:\n        cfg = load_config(config)\n    except FileNotFoundError as e:\n        typer.echo(f\"Error: {e}\", err = True)\n        raise typer.Exit(code = 2)\n\n    cfg.apply_overrides(**config_overrides)\n\n    # CLI/env tokens take precedence over config\n    # Handle case where typer.Option isn't resolved (decorator interaction)\n    from typer.models import OptionInfo\n\n    if isinstance(hf_token, OptionInfo):\n        hf_token = None\n    if isinstance(wandb_token, OptionInfo):\n        wandb_token = None\n    hf_token = hf_token or cfg.logging.hf_token\n    wandb_token = wandb_token or cfg.logging.wandb_token\n\n    if dry_run:\n        import yaml\n\n        data = cfg.model_dump()\n        data[\"training\"][\"output_dir\"] = str(data[\"training\"][\"output_dir\"])\n        typer.echo(yaml.dump(data, default_flow_style = False, sort_keys = False))\n        raise typer.Exit(code = 0)\n\n    if not cfg.model:\n        typer.echo(\"Error: provide --model or set model in --config\", err = True)\n        raise typer.Exit(code = 2)\n\n    if not cfg.data.dataset and not cfg.data.local_dataset:\n        typer.echo(\n            \"Error: provide --dataset or --local-dataset (or via --config)\", err = True\n        )\n        raise typer.Exit(code = 2)\n\n    # Check if the model path is a LoRA adapter (has adapter_config.json)\n    model_path = Path(cfg.model) if cfg.model else None\n    model_is_lora = (\n        model_path\n        and model_path.is_dir()\n        and (model_path / \"adapter_config.json\").exists()\n    )\n    use_lora = cfg.training.training_type.lower() == \"lora\"\n\n    if model_is_lora and not use_lora:\n        typer.echo(\n            \"Error: Cannot do full finetuning on a LoRA adapter. 
\"\n            \"Use --training-type lora or provide a base model.\",\n            err = True,\n        )\n        raise typer.Exit(code = 2)\n\n    from studio.backend.core.training.trainer import UnslothTrainer\n\n    trainer = UnslothTrainer()\n\n    # Load model (trainer.is_vlm is set after this)\n    if not trainer.load_model(\n        model_name = cfg.model,\n        max_seq_length = cfg.training.max_seq_length,\n        load_in_4bit = cfg.training.load_in_4bit if use_lora else False,\n        hf_token = hf_token,\n    ):\n        typer.echo(\"Model load failed\", err = True)\n        raise typer.Exit(code = 1)\n\n    is_vision = trainer.is_vlm\n\n    if not trainer.prepare_model_for_training(**cfg.model_kwargs(use_lora, is_vision)):\n        typer.echo(\"Model preparation failed\", err = True)\n        raise typer.Exit(code = 1)\n\n    result = trainer.load_and_format_dataset(\n        dataset_source = cfg.data.dataset or \"\",\n        format_type = cfg.data.format_type,\n        local_datasets = cfg.data.local_dataset,\n    )\n    if result is None:\n        typer.echo(\"Dataset load failed\", err = True)\n        raise typer.Exit(code = 1)\n\n    ds, eval_ds = result\n\n    training_kwargs = cfg.training_kwargs()\n    training_kwargs[\"wandb_token\"] = wandb_token  # CLI/env takes precedence\n    started = trainer.start_training(\n        dataset = ds, eval_dataset = eval_ds, **training_kwargs\n    )\n\n    if not started:\n        typer.echo(\"Training failed to start\", err = True)\n        raise typer.Exit(code = 1)\n\n    try:\n        while trainer.training_thread and trainer.training_thread.is_alive():\n            time.sleep(1)\n    except KeyboardInterrupt:\n        typer.echo(\"Stopping training (Ctrl+C detected)...\")\n        trainer.stop_training()\n    finally:\n        if trainer.training_thread:\n            trainer.training_thread.join()\n\n    final = trainer.get_training_progress()\n    if getattr(final, \"error\", None):\n        typer.echo(f\"Training error: {final.error}\", err = True)\n        raise typer.Exit(code = 1)\n"
  },
  {
    "path": "unsloth_cli/commands/ui.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nimport os\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\n\ndef ui(\n    port: int = typer.Option(\n        8888, \"--port\", \"-p\", help = \"Port to run the UI server on.\"\n    ),\n    host: str = typer.Option(\n        \"0.0.0.0\", \"--host\", \"-H\", help = \"Host address to bind to.\"\n    ),\n    frontend: Optional[Path] = typer.Option(\n        None, \"--frontend\", \"-f\", help = \"Path to frontend build directory.\"\n    ),\n    silent: bool = typer.Option(\n        False, \"--silent\", \"-q\", help = \"Suppress startup messages.\"\n    ),\n):\n    \"\"\"Launch the Unsloth web UI backend server (alias for 'unsloth studio').\"\"\"\n    from unsloth_cli.commands.studio import (\n        _studio_venv_python,\n        _find_run_py,\n        STUDIO_HOME,\n    )\n\n    # Re-execute in studio venv if available and not already inside it\n    studio_venv_dir = STUDIO_HOME / \".venv\"\n    in_studio_venv = sys.prefix.startswith(str(studio_venv_dir))\n\n    if not in_studio_venv:\n        studio_python = _studio_venv_python()\n        run_py = _find_run_py()\n        if studio_python and run_py:\n            if not silent:\n                typer.echo(\"Launching Unsloth Studio... Please wait...\")\n            args = [\n                str(studio_python),\n                str(run_py),\n                \"--host\",\n                host,\n                \"--port\",\n                str(port),\n            ]\n            if frontend:\n                args.extend([\"--frontend\", str(frontend)])\n            if silent:\n                args.append(\"--silent\")\n            # On Windows, os.execvp() spawns a child but the parent lingers,\n            # so Ctrl+C only kills the parent leaving the child orphaned.\n            # Use subprocess.run() on Windows so the parent waits for the child.\n            if sys.platform == \"win32\":\n                import subprocess as _sp\n\n                proc = _sp.Popen(args)\n                try:\n                    rc = proc.wait()\n                except KeyboardInterrupt:\n                    # Child has its own signal handler — let it finish\n                    rc = proc.wait()\n                raise typer.Exit(rc)\n            else:\n                os.execvp(str(studio_python), args)\n        else:\n            typer.echo(\"Studio not set up. Run 'unsloth studio setup' first.\")\n            raise typer.Exit(1)\n\n    from studio.backend.run import run_server\n\n    if not silent:\n        from studio.backend.run import _resolve_external_ip\n\n        display_host = _resolve_external_ip() if host == \"0.0.0.0\" else host\n        typer.echo(f\"Starting Unsloth Studio on http://{display_host}:{port}\")\n\n    run_server(\n        host = host,\n        port = port,\n        frontend_path = frontend,\n        silent = silent,\n    )\n\n    from studio.backend.run import _shutdown_event\n\n    try:\n        if _shutdown_event is not None:\n            _shutdown_event.wait()\n        else:\n            while True:\n                time.sleep(1)\n    except KeyboardInterrupt:\n        from studio.backend.run import _graceful_shutdown, _server\n\n        _graceful_shutdown(_server)\n        typer.echo(\"\\nShutting down...\")\n"
  },
  {
    "path": "unsloth_cli/config.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\nfrom pathlib import Path\nfrom typing import Literal, Optional, List\n\nimport yaml\nfrom pydantic import BaseModel, Field\n\n\nclass DataConfig(BaseModel):\n    dataset: Optional[str] = None\n    local_dataset: Optional[List[str]] = None\n    format_type: Literal[\"auto\", \"alpaca\", \"chatml\", \"sharegpt\"] = \"auto\"\n\n\nclass TrainingConfig(BaseModel):\n    training_type: Literal[\"lora\", \"full\"] = \"lora\"\n    max_seq_length: int = 2048\n    load_in_4bit: bool = True\n    output_dir: Path = Path(\"./outputs\")\n    num_epochs: int = 3\n    learning_rate: float = 2e-4\n    batch_size: int = 2\n    gradient_accumulation_steps: int = 4\n    warmup_steps: int = 5\n    max_steps: int = 0\n    save_steps: int = 0\n    weight_decay: float = 0.01\n    random_seed: int = 3407\n    packing: bool = False\n    train_on_completions: bool = False\n    gradient_checkpointing: Literal[\"unsloth\", \"true\", \"none\"] = \"unsloth\"\n\n\nclass LoraConfig(BaseModel):\n    lora_r: int = 64\n    lora_alpha: int = 16\n    lora_dropout: float = 0.0\n    target_modules: str = \"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj\"\n    vision_all_linear: bool = False\n    use_rslora: bool = False\n    use_loftq: bool = False\n    finetune_vision_layers: bool = True\n    finetune_language_layers: bool = True\n    finetune_attention_modules: bool = True\n    finetune_mlp_modules: bool = True\n\n\nclass LoggingConfig(BaseModel):\n    enable_wandb: bool = False\n    wandb_project: str = \"unsloth-training\"\n    wandb_token: Optional[str] = None\n    enable_tensorboard: bool = False\n    tensorboard_dir: str = \"runs\"\n    hf_token: Optional[str] = None\n\n\nclass Config(BaseModel):\n    model: Optional[str] = None\n    data: DataConfig = Field(default_factory = DataConfig)\n    training: TrainingConfig = Field(default_factory = TrainingConfig)\n    lora: LoraConfig = Field(default_factory = LoraConfig)\n    logging: LoggingConfig = Field(default_factory = LoggingConfig)\n\n    def apply_overrides(self, **kwargs):\n        \"\"\"Apply CLI overrides by matching arg names to config fields.\"\"\"\n        for key, value in kwargs.items():\n            if value is None:\n                continue\n            if hasattr(self, key):\n                setattr(self, key, value)\n            else:\n                for section in (self.data, self.training, self.lora, self.logging):\n                    if hasattr(section, key):\n                        setattr(section, key, value)\n                        break\n\n    def model_kwargs(self, use_lora: bool, is_vision: bool) -> dict:\n        \"\"\"Return kwargs for trainer.prepare_model_for_training().\"\"\"\n        # Determine target modules based on model type\n        if use_lora and is_vision:\n            # Vision models expect a string (e.g., \"all-linear\"); fall back to None to use trainer defaults\n            target_modules = \"all-linear\" if self.lora.vision_all_linear else None\n        else:\n            parsed = [\n                m.strip()\n                for m in str(self.lora.target_modules).split(\",\")\n                if m and m.strip()\n            ]\n            target_modules = parsed or None\n\n        return {\n            \"use_lora\": use_lora,\n            \"finetune_vision_layers\": self.lora.finetune_vision_layers,\n            \"finetune_language_layers\": 
self.lora.finetune_language_layers,\n            \"finetune_attention_modules\": self.lora.finetune_attention_modules,\n            \"finetune_mlp_modules\": self.lora.finetune_mlp_modules,\n            \"target_modules\": target_modules,\n            \"lora_r\": self.lora.lora_r,\n            \"lora_alpha\": self.lora.lora_alpha,\n            \"lora_dropout\": self.lora.lora_dropout,\n            \"use_gradient_checkpointing\": self.training.gradient_checkpointing,\n            \"use_rslora\": self.lora.use_rslora,\n            \"use_loftq\": self.lora.use_loftq,\n        }\n\n    def training_kwargs(self) -> dict:\n        \"\"\"Return kwargs for trainer.start_training().\"\"\"\n        return {\n            \"output_dir\": str(self.training.output_dir),\n            \"num_epochs\": self.training.num_epochs,\n            \"learning_rate\": self.training.learning_rate,\n            \"batch_size\": self.training.batch_size,\n            \"gradient_accumulation_steps\": self.training.gradient_accumulation_steps,\n            \"warmup_steps\": self.training.warmup_steps,\n            \"max_steps\": self.training.max_steps,\n            \"save_steps\": self.training.save_steps,\n            \"weight_decay\": self.training.weight_decay,\n            \"random_seed\": self.training.random_seed,\n            \"packing\": self.training.packing,\n            \"train_on_completions\": self.training.train_on_completions,\n            \"max_seq_length\": self.training.max_seq_length,\n            \"enable_wandb\": self.logging.enable_wandb,\n            \"wandb_project\": self.logging.wandb_project,\n            \"wandb_token\": self.logging.wandb_token,\n            \"enable_tensorboard\": self.logging.enable_tensorboard,\n            \"tensorboard_dir\": self.logging.tensorboard_dir,\n        }\n\n\ndef load_config(path: Optional[Path]) -> Config:\n    \"\"\"Load config from YAML/JSON file, or return defaults if no path given.\"\"\"\n    if not path:\n        return Config()\n\n    path = Path(path)\n    if not path.exists():\n        raise FileNotFoundError(f\"Config file not found: {path}\")\n\n    text = path.read_text(encoding = \"utf-8\")\n    if path.suffix.lower() in {\".yaml\", \".yml\"}:\n        data = yaml.safe_load(text) or {}\n    else:\n        import json\n\n        data = json.loads(text or \"{}\")\n\n    return Config(**data)\n"
  },
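  {
    "path": "unsloth_cli/config.example.yaml",
    "content": "# Example config for `unsloth train --config config.example.yaml` (illustrative only;\n# this file is not required by the CLI). Values mirror the defaults in unsloth_cli/config.py,\n# and the model/dataset ids below are placeholders; swap in your own.\n# CLI flags (e.g. --learning-rate) override anything set here.\n\nmodel: unsloth/Llama-3.2-1B-Instruct\n\ndata:\n  dataset: yahma/alpaca-cleaned        # or set local_dataset to a list of local files\n  format_type: auto                    # auto | alpaca | chatml | sharegpt\n\ntraining:\n  training_type: lora                  # lora | full\n  max_seq_length: 2048\n  load_in_4bit: true\n  output_dir: ./outputs\n  num_epochs: 3\n  learning_rate: 2.0e-4\n  batch_size: 2\n  gradient_accumulation_steps: 4\n  warmup_steps: 5\n  weight_decay: 0.01\n  random_seed: 3407\n  packing: false\n  gradient_checkpointing: unsloth      # unsloth | true | none\n\nlora:\n  lora_r: 64\n  lora_alpha: 16\n  lora_dropout: 0.0\n  target_modules: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj\n  use_rslora: false\n\nlogging:\n  enable_wandb: false\n  wandb_project: unsloth-training\n  enable_tensorboard: false\n  tensorboard_dir: runs\n"
  },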
  {
    "path": "unsloth_cli/options.py",
    "content": "# SPDX-License-Identifier: AGPL-3.0-only\n# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0\n\n\"\"\"Generate Typer CLI options from Pydantic models.\"\"\"\n\nimport functools\nimport inspect\nfrom pathlib import Path\nfrom typing import Any, Callable, Optional, get_args, get_origin\n\nimport typer\nfrom pydantic import BaseModel\n\n\ndef _python_name_to_cli_flag(name: str) -> str:\n    \"\"\"Convert python_name to --cli-flag.\"\"\"\n    return \"--\" + name.replace(\"_\", \"-\")\n\n\ndef _unwrap_optional(annotation: Any) -> Any:\n    \"\"\"Unwrap Optional[X] to X.\"\"\"\n    origin = get_origin(annotation)\n    if origin is not None:\n        args = get_args(annotation)\n        if type(None) in args:\n            non_none = [a for a in args if a is not type(None)]\n            if non_none:\n                return non_none[0]\n    return annotation\n\n\ndef _is_bool_field(annotation: Any) -> bool:\n    \"\"\"Check if field is a boolean (including Optional[bool]).\"\"\"\n    return _unwrap_optional(annotation) is bool\n\n\ndef _is_list_type(annotation: Any) -> bool:\n    \"\"\"Check if type is a List.\"\"\"\n    return get_origin(annotation) is list\n\n\ndef _get_python_type(annotation: Any) -> type:\n    \"\"\"Get the Python type for annotation.\"\"\"\n    unwrapped = _unwrap_optional(annotation)\n    if unwrapped in (str, int, float, bool, Path):\n        return unwrapped\n    return str\n\n\ndef _collect_config_fields(config_class: type[BaseModel]) -> list[tuple[str, Any]]:\n    \"\"\"\n    Collect all fields from a config class, flattening nested models. Returns list of\n    (name, field_info) tuples. Raises ValueError on duplicate field names.\n    \"\"\"\n    fields = []\n    seen_names: set[str] = set()\n\n    for name, field_info in config_class.model_fields.items():\n        annotation = field_info.annotation\n        # Skip nested models - recurse into them\n        if isinstance(annotation, type) and issubclass(annotation, BaseModel):\n            for nested_name, nested_field in annotation.model_fields.items():\n                if nested_name in seen_names:\n                    raise ValueError(f\"Duplicate field name '{nested_name}' in config\")\n                seen_names.add(nested_name)\n                fields.append((nested_name, nested_field))\n        else:\n            if name in seen_names:\n                raise ValueError(f\"Duplicate field name '{name}' in config\")\n            seen_names.add(name)\n            fields.append((name, field_info))\n    return fields\n\n\ndef add_options_from_config(config_class: type[BaseModel]) -> Callable:\n    \"\"\"\n    Decorator that adds CLI options for all fields in a Pydantic config model.\n\n    The decorated function should declare a `config_overrides: dict = None` parameter\n    which will receive a dict of all CLI-provided config values.\n    \"\"\"\n    fields = _collect_config_fields(config_class)\n    field_names = {\n        name for name, field_info in fields if not _is_list_type(field_info.annotation)\n    }\n\n    def decorator(func: Callable) -> Callable:\n        sig = inspect.signature(func)\n        original_params = list(sig.parameters.values())\n        original_param_names = {p.name for p in original_params}\n\n        # Build new parameters: config fields first, then original params\n        new_params = []\n\n        for field_name, field_info in fields:\n            # Skip fields already defined in function signature (e.g., with envvar)\n  
          if field_name in original_param_names:\n                continue\n            annotation = field_info.annotation\n            if _is_list_type(annotation):\n                continue\n\n            flag_name = _python_name_to_cli_flag(field_name)\n            help_text = field_info.description or \"\"\n\n            if _is_bool_field(annotation):\n                default = typer.Option(\n                    None,\n                    f\"{flag_name}/--no-{field_name.replace('_', '-')}\",\n                    help = help_text,\n                )\n                param = inspect.Parameter(\n                    field_name,\n                    inspect.Parameter.POSITIONAL_OR_KEYWORD,\n                    default = default,\n                    annotation = Optional[bool],\n                )\n            else:\n                py_type = _get_python_type(annotation)\n                default = typer.Option(None, flag_name, help = help_text)\n                param = inspect.Parameter(\n                    field_name,\n                    inspect.Parameter.POSITIONAL_OR_KEYWORD,\n                    default = default,\n                    annotation = Optional[py_type],\n                )\n            new_params.append(param)\n\n        # Add original params, excluding config_overrides (will be injected)\n        for param in original_params:\n            if param.name != \"config_overrides\":\n                new_params.append(param)\n\n        new_sig = sig.replace(parameters = new_params)\n\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            config_overrides = {}\n            for key in list(kwargs.keys()):\n                if key in field_names:\n                    if kwargs[key] is not None:\n                        config_overrides[key] = kwargs[key]\n                    # Only delete if not an explicitly declared parameter\n                    if key not in original_param_names:\n                        del kwargs[key]\n\n            kwargs[\"config_overrides\"] = config_overrides\n            return func(*args, **kwargs)\n\n        wrapper.__signature__ = new_sig\n        return wrapper\n\n    return decorator\n"
  }
]